Fix bugs
This commit is contained in:
@@ -23,10 +23,12 @@ if str(lib_dir) not in sys.path:
|
||||
import qlib
|
||||
from qlib import config as qconfig
|
||||
from qlib.workflow import R
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
|
||||
|
||||
from utils.qlib_utils import QResult
|
||||
|
||||
|
||||
def filter_finished(recorders):
|
||||
returned_recorders = dict()
|
||||
not_finished = 0
|
||||
@@ -41,9 +43,10 @@ def filter_finished(recorders):
|
||||
def add_to_dict(xdict, timestamp, value):
|
||||
date = timestamp.date().strftime("%Y-%m-%d")
|
||||
if date in xdict:
|
||||
raise ValueError("This date [{:}] is already in the dict".format(date))
|
||||
raise ValueError("This date [{:}] is already in the dict".format(date))
|
||||
xdict[date] = value
|
||||
|
||||
|
||||
def query_info(save_dir, verbose, name_filter, key_map):
|
||||
if isinstance(save_dir, list):
|
||||
results = []
|
||||
@@ -61,7 +64,10 @@ def query_info(save_dir, verbose, name_filter, key_map):
|
||||
for idx, (key, experiment) in enumerate(experiments.items()):
|
||||
if experiment.id == "0":
|
||||
continue
|
||||
if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:
|
||||
if (
|
||||
name_filter is not None
|
||||
and re.fullmatch(name_filter, experiment.name) is None
|
||||
):
|
||||
continue
|
||||
recorders = experiment.list_recorders()
|
||||
recorders, not_finished = filter_finished(recorders)
|
||||
@@ -77,10 +83,10 @@ def query_info(save_dir, verbose, name_filter, key_map):
|
||||
)
|
||||
result = QResult(experiment.name)
|
||||
for recorder_id, recorder in recorders.items():
|
||||
file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl']
|
||||
file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"]
|
||||
date2IC = OrderedDict()
|
||||
for file_name in file_names:
|
||||
xtemp = recorder.load_object(file_name)['all-IC']
|
||||
xtemp = recorder.load_object(file_name)["all-IC"]
|
||||
timestamps, values = xtemp.index.tolist(), xtemp.tolist()
|
||||
for timestamp, value in zip(timestamps, values):
|
||||
add_to_dict(date2IC, timestamp, value)
|
||||
@@ -104,7 +110,7 @@ def query_info(save_dir, verbose, name_filter, key_map):
|
||||
|
||||
|
||||
##
|
||||
paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']
|
||||
paths = [root_dir / "outputs" / "qlib-baselines-csi300"]
|
||||
paths = [path.resolve() for path in paths]
|
||||
print(paths)
|
||||
|
||||
@@ -112,12 +118,12 @@ key_map = dict()
|
||||
for xset in ("train", "valid", "test"):
|
||||
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
|
||||
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
|
||||
qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map)
|
||||
print('Find {:} results'.format(len(qresults)))
|
||||
qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
|
||||
print("Find {:} results".format(len(qresults)))
|
||||
times = []
|
||||
for qresult in qresults:
|
||||
times.append(qresult.name.split('0_0s')[-1])
|
||||
times.append(qresult.name.split("0_0s")[-1])
|
||||
print(times)
|
||||
save_path = os.path.join(note_dir, 'temp-time-x.pth')
|
||||
save_path = os.path.join(note_dir, "temp-time-x.pth")
|
||||
torch.save(qresults, save_path)
|
||||
print(save_path)
|
||||
|
@@ -24,38 +24,38 @@ from qlib.model.base import Model
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha360",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha360",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
},
|
||||
}
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
pprint.pprint(dataset_config)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
model = get_transformer(None)
|
||||
print(model)
|
||||
|
||||
@@ -72,4 +72,5 @@ label = labels[batch][mask]
|
||||
loss = torch.nn.functional.mse_loss(pred, label)
|
||||
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
mse_loss = mean_squared_error(pred.numpy(), label.numpy())
|
||||
|
Reference in New Issue
Block a user