Add int search space
This commit is contained in:
@@ -39,7 +39,10 @@ def main(xargs):
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": xargs.market,
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
|
||||
{
|
||||
"class": "RobustZScoreNorm",
|
||||
"kwargs": {"fields_group": "feature", "clip_outlier": True},
|
||||
},
|
||||
{"class": "Fillna", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
@@ -90,7 +93,11 @@ def main(xargs):
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": dict(),
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
@@ -111,18 +118,37 @@ def main(xargs):
|
||||
for irun in range(xargs.times):
|
||||
xmodel_config = model_config.copy()
|
||||
xmodel_config = update_gpu(xmodel_config, xargs.gpu)
|
||||
task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
|
||||
task_config = dict(
|
||||
model=xmodel_config, dataset=dataset_config, record=record_config
|
||||
)
|
||||
|
||||
run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)
|
||||
run_exp(
|
||||
task_config,
|
||||
dataset,
|
||||
xargs.name,
|
||||
"recorder-{:02d}-{:02d}".format(irun, xargs.times),
|
||||
save_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
|
||||
parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/vtt-runs",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="Transformer", help="The experiment name."
|
||||
)
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--market", type=str, default="all", help="The market indicator.")
|
||||
parser.add_argument(
|
||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--market", type=str, default="all", help="The market indicator."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user