first commit
This commit is contained in:
33
NAS-Bench-201/main_exp/utils.py
Normal file
33
NAS-Bench-201/main_exp/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
|
||||
|
||||
|
||||
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
Reference in New Issue
Block a user