Add name filters for exp-org

This commit is contained in:
D-X-Y
2021-03-28 22:26:06 -07:00
parent b51320dfb1
commit 62bedaa094
3 changed files with 17 additions and 11 deletions

View File

@@ -100,9 +100,9 @@ def run_exp(
# Train model
try:
if hasattr(model, "to"): # Recoverable model
device = model.device
ori_device = model.device
model = R.load_object(model_obj_name)
model.to(device)
model.to(ori_device)
else:
model = R.load_object(model_obj_name)
logger.info("[Find existing object from {:}]".format(model_obj_name))
@@ -119,9 +119,10 @@ def run_exp(
model.fit(**model_fit_kwargs)
# remove model to CPU for saving
if hasattr(model, "to"):
old_device = model.device
model.to("cpu")
R.save_objects(**{model_obj_name: model})
model.to()
model.to(old_device)
else:
R.save_objects(**{model_obj_name: model})
except Exception as e: