Sync qlib

This commit is contained in:
D-X-Y
2021-03-17 09:10:45 +00:00
parent 1ba1585f20
commit a9093e41e1
4 changed files with 20 additions and 12 deletions

View File

@@ -138,7 +138,7 @@ class QuantTransformer(Model):
def fit(
self,
dataset: DatasetH,
save_path: Optional[Text] = None,
save_dir: Optional[Text] = None,
):
def _prepare_dataset(df_data):
return th_data.TensorDataset(
@@ -172,8 +172,8 @@ class QuantTransformer(Model):
_prepare_loader(test_dataset, False),
)
save_path = get_or_create_path(save_path, return_dir=True)
self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
save_dir = get_or_create_path(save_dir, return_dir=True)
self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_dir))
def _internal_test(ckp_epoch=None, results_dict=None):
with torch.no_grad():
@@ -196,15 +196,18 @@ class QuantTransformer(Model):
return dict(train=train_score, valid=valid_score, test=test_score), xstr
# Pre-fetch the potential checkpoints
ckp_path = os.path.join(save_path, "{:}.pth".format(self.__class__.__name__))
ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
if os.path.exists(ckp_path):
ckp_data = torch.load(ckp_path)
import pdb
pdb.set_trace()
stop_steps, best_score, best_epoch = ckp_data['stop_steps'], ckp_data['best_score'], ckp_data['best_epoch']
start_epoch, best_param = ckp_data['start_epoch'], ckp_data['best_param']
results_dict = ckp_data['results_dict']
self.model.load_state_dict(ckp_data['net_state_dict'])
self.train_optimizer.load_state_dict(ckp_data['opt_state_dict'])
self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
else:
stop_steps, best_score, best_epoch = 0, -np.inf, -1
start_epoch = 0
start_epoch, best_param = 0, None
results_dict = dict(train=OrderedDict(), valid=OrderedDict(), test=OrderedDict())
_, eval_str = _internal_test(-1, results_dict)
self.logger.info("Training from scratch, metrics@start: {:}".format(eval_str))
@@ -215,7 +218,6 @@ class QuantTransformer(Model):
iepoch, self.opt_config["epochs"], best_epoch, best_score
)
)
train_loss, train_score = self.train_or_test_epoch(
train_loader, self.model, self.loss_fn, self.metric_fn, True, self.train_optimizer
)
@@ -241,11 +243,14 @@ class QuantTransformer(Model):
stop_steps=stop_steps,
best_score=best_score,
best_epoch=best_epoch,
results_dict=results_dict,
start_epoch=iepoch + 1,
)
torch.save(save_info, ckp_path)
self.logger.info("The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch))
self.model.load_state_dict(best_param)
_, eval_str = _internal_test('final', results_dict)
self.logger.info("Reload the best parameter :: {:}".format(eval_str))
if self.use_gpu:
torch.cuda.empty_cache()