Update models

This commit is contained in:
D-X-Y
2021-03-29 05:04:24 +00:00
parent e637cddc39
commit b51320dfb1
7 changed files with 85 additions and 14 deletions

View File

@@ -67,14 +67,26 @@ def extend_transformer_settings(alg2configs, name):
return alg2configs
def remove_PortAnaRecord(alg2configs):
def refresh_record(alg2configs):
alg2configs = copy.deepcopy(alg2configs)
for key, config in alg2configs.items():
xlist = config["task"]["record"]
new_list = []
for x in xlist:
if x["class"] != "PortAnaRecord":
# remove PortAnaRecord and SignalMseRecord
if x["class"] != "PortAnaRecord" and x["class"] != "SignalMseRecord":
new_list.append(x)
## add MultiSegRecord
new_list.append(
{
"class": "MultiSegRecord",
"module_path": "qlib.contrib.workflow",
"generate_kwargs": {
"segments": {"train": "train", "valid": "valid", "test": "test"},
"save": True,
},
}
)
config["task"]["record"] = new_list
return alg2configs
@@ -117,7 +129,7 @@ def retrieve_configs():
)
)
alg2configs = extend_transformer_settings(alg2configs, "TSF")
alg2configs = remove_PortAnaRecord(alg2configs)
alg2configs = refresh_record(alg2configs)
print(
"There are {:} algorithms : {:}".format(
len(alg2configs), list(alg2configs.keys())