Update models
This commit is contained in:
@@ -67,14 +67,26 @@ def extend_transformer_settings(alg2configs, name):
|
||||
return alg2configs
|
||||
|
||||
|
||||
def remove_PortAnaRecord(alg2configs):
|
||||
def refresh_record(alg2configs):
|
||||
alg2configs = copy.deepcopy(alg2configs)
|
||||
for key, config in alg2configs.items():
|
||||
xlist = config["task"]["record"]
|
||||
new_list = []
|
||||
for x in xlist:
|
||||
if x["class"] != "PortAnaRecord":
|
||||
# remove PortAnaRecord and SignalMseRecord
|
||||
if x["class"] != "PortAnaRecord" and x["class"] != "SignalMseRecord":
|
||||
new_list.append(x)
|
||||
## add MultiSegRecord
|
||||
new_list.append(
|
||||
{
|
||||
"class": "MultiSegRecord",
|
||||
"module_path": "qlib.contrib.workflow",
|
||||
"generate_kwargs": {
|
||||
"segments": {"train": "train", "valid": "valid", "test": "test"},
|
||||
"save": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
config["task"]["record"] = new_list
|
||||
return alg2configs
|
||||
|
||||
@@ -117,7 +129,7 @@ def retrieve_configs():
|
||||
)
|
||||
)
|
||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||
alg2configs = remove_PortAnaRecord(alg2configs)
|
||||
alg2configs = refresh_record(alg2configs)
|
||||
print(
|
||||
"There are {:} algorithms : {:}".format(
|
||||
len(alg2configs), list(alg2configs.keys())
|
||||
|
Reference in New Issue
Block a user