Update models

This commit is contained in:
D-X-Y
2021-03-29 05:04:24 +00:00
parent e637cddc39
commit b51320dfb1
7 changed files with 85 additions and 14 deletions

View File

@@ -99,7 +99,12 @@ def run_exp(
# Train model
try:
model = R.load_object(model_obj_name)
if hasattr(model, "to"): # Recoverable model
device = model.device
model = R.load_object(model_obj_name)
model.to(device)
else:
model = R.load_object(model_obj_name)
logger.info("[Find existing object from {:}]".format(model_obj_name))
except OSError:
R.log_params(**flatten_dict(task_config))
@@ -112,16 +117,29 @@ def run_exp(
recorder_root_dir, "model-ckps"
)
model.fit(**model_fit_kwargs)
R.save_objects(**{model_obj_name: model})
except:
raise ValueError("Something wrong.")
# remove model to CPU for saving
if hasattr(model, "to"):
model.to("cpu")
R.save_objects(**{model_obj_name: model})
model.to()
else:
R.save_objects(**{model_obj_name: model})
except Exception as e:
import pdb
pdb.set_trace()
raise ValueError("Something wrong: {:}".format(e))
# Get the recorder
recorder = R.get_recorder()
# Generate records: prediction, backtest, and analysis
for record in task_config["record"]:
record = deepcopy(record)
if record["class"] == "SignalRecord":
if record["class"] == "MultiSegRecord":
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
sr = init_instance_by_config(record)
sr.generate(**record["generate_kwargs"])
elif record["class"] == "SignalRecord":
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)