Update baselines

This commit is contained in:
D-X-Y
2021-03-28 10:57:20 +00:00
parent 0055511829
commit 53cb5f1fdd
4 changed files with 24 additions and 10 deletions

View File

@@ -67,6 +67,18 @@ def extend_transformer_settings(alg2configs, name):
return alg2configs
def remove_PortAnaRecord(alg2configs):
alg2configs = copy.deepcopy(alg2configs)
for key, config in alg2configs.items():
xlist = config["task"]["record"]
new_list = []
for x in xlist:
if x["class"] != "PortAnaRecord":
new_list.append(x)
config["task"]["record"] = new_list
return alg2configs
def retrieve_configs():
# https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
@@ -105,6 +117,12 @@ def retrieve_configs():
)
)
alg2configs = extend_transformer_settings(alg2configs, "TSF")
alg2configs = remove_PortAnaRecord(alg2configs)
print(
"There are {:} algorithms : {:}".format(
len(alg2configs), list(alg2configs.keys())
)
)
return alg2configs

View File

@@ -223,7 +223,7 @@ if __name__ == "__main__":
info_dict["heads"],
info_dict["values"],
info_dict["names"],
space=12,
space=14,
verbose=True,
sort_key=True,
)

View File

@@ -8,16 +8,12 @@ import os, math, random
from collections import OrderedDict
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from functools import partial
from typing import Optional, Text
from qlib.utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
import torch
@@ -308,10 +304,10 @@ class QuantTransformer(Model):
torch.cuda.empty_cache()
self.fitted = True
def predict(self, dataset, segment="test"):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("The model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()