Update models

This commit is contained in:
D-X-Y
2021-03-29 05:04:24 +00:00
parent e637cddc39
commit b51320dfb1
7 changed files with 85 additions and 14 deletions

View File

@@ -99,7 +99,12 @@ def run_exp(
# Train model
try:
model = R.load_object(model_obj_name)
if hasattr(model, "to"): # Recoverable model
device = model.device
model = R.load_object(model_obj_name)
model.to(device)
else:
model = R.load_object(model_obj_name)
logger.info("[Find existing object from {:}]".format(model_obj_name))
except OSError:
R.log_params(**flatten_dict(task_config))
@@ -112,16 +117,29 @@ def run_exp(
recorder_root_dir, "model-ckps"
)
model.fit(**model_fit_kwargs)
R.save_objects(**{model_obj_name: model})
except:
raise ValueError("Something wrong.")
# remove model to CPU for saving
if hasattr(model, "to"):
model.to("cpu")
R.save_objects(**{model_obj_name: model})
model.to()
else:
R.save_objects(**{model_obj_name: model})
except Exception as e:
import pdb
pdb.set_trace()
raise ValueError("Something wrong: {:}".format(e))
# Get the recorder
recorder = R.get_recorder()
# Generate records: prediction, backtest, and analysis
for record in task_config["record"]:
record = deepcopy(record)
if record["class"] == "SignalRecord":
if record["class"] == "MultiSegRecord":
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
sr = init_instance_by_config(record)
sr.generate(**record["generate_kwargs"])
elif record["class"] == "SignalRecord":
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)

View File

@@ -112,6 +112,12 @@ class QuantTransformer(Model):
def use_gpu(self):
return self.device != torch.device("cpu")
def to(self, device):
if device is None:
self.model.to(self.device)
else:
self.model.to("cpu")
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.opt_config["loss"] == "mse":