Update models
This commit is contained in:
@@ -99,7 +99,12 @@ def run_exp(
|
||||
|
||||
# Train model
|
||||
try:
|
||||
model = R.load_object(model_obj_name)
|
||||
if hasattr(model, "to"): # Recoverable model
|
||||
device = model.device
|
||||
model = R.load_object(model_obj_name)
|
||||
model.to(device)
|
||||
else:
|
||||
model = R.load_object(model_obj_name)
|
||||
logger.info("[Find existing object from {:}]".format(model_obj_name))
|
||||
except OSError:
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
@@ -112,16 +117,29 @@ def run_exp(
|
||||
recorder_root_dir, "model-ckps"
|
||||
)
|
||||
model.fit(**model_fit_kwargs)
|
||||
R.save_objects(**{model_obj_name: model})
|
||||
except:
|
||||
raise ValueError("Something wrong.")
|
||||
# remove model to CPU for saving
|
||||
if hasattr(model, "to"):
|
||||
model.to("cpu")
|
||||
R.save_objects(**{model_obj_name: model})
|
||||
model.to()
|
||||
else:
|
||||
R.save_objects(**{model_obj_name: model})
|
||||
except Exception as e:
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
raise ValueError("Something wrong: {:}".format(e))
|
||||
# Get the recorder
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# Generate records: prediction, backtest, and analysis
|
||||
for record in task_config["record"]:
|
||||
record = deepcopy(record)
|
||||
if record["class"] == "SignalRecord":
|
||||
if record["class"] == "MultiSegRecord":
|
||||
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate(**record["generate_kwargs"])
|
||||
elif record["class"] == "SignalRecord":
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
|
@@ -112,6 +112,12 @@ class QuantTransformer(Model):
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def to(self, device):
|
||||
if device is None:
|
||||
self.model.to(self.device)
|
||||
else:
|
||||
self.model.to("cpu")
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if self.opt_config["loss"] == "mse":
|
||||
|
Reference in New Issue
Block a user