Add int search space
This commit is contained in:
@@ -65,7 +65,11 @@ def retrieve_configs():
|
||||
path = config_dir / name
|
||||
assert path.exists(), "{:} does not exist.".format(path)
|
||||
alg2paths[alg] = str(path)
|
||||
print("The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(idx, len(alg2names), alg, path))
|
||||
print(
|
||||
"The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(
|
||||
idx, len(alg2names), alg, path
|
||||
)
|
||||
)
|
||||
return alg2paths
|
||||
|
||||
|
||||
@@ -100,13 +104,30 @@ if __name__ == "__main__":
|
||||
alg2paths = retrieve_configs()
|
||||
|
||||
parser = argparse.ArgumentParser("Baselines")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.")
|
||||
parser.add_argument(
|
||||
"--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/qlib-baselines",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--market",
|
||||
type=str,
|
||||
default="all",
|
||||
choices=["csi100", "csi300", "all"],
|
||||
help="The market indicator.",
|
||||
)
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.")
|
||||
parser.add_argument(
|
||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alg",
|
||||
type=str,
|
||||
choices=list(alg2paths.keys()),
|
||||
required=True,
|
||||
help="The algorithm name.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args, alg2paths[args.alg])
|
||||
|
@@ -55,7 +55,13 @@ class QResult:
|
||||
new_dict[xkey] = values
|
||||
return new_dict
|
||||
|
||||
def info(self, keys: List[Text], separate: Text = "& ", space: int = 25, verbose: bool = True):
|
||||
def info(
|
||||
self,
|
||||
keys: List[Text],
|
||||
separate: Text = "& ",
|
||||
space: int = 25,
|
||||
verbose: bool = True,
|
||||
):
|
||||
avaliable_keys = []
|
||||
for key in keys:
|
||||
if key not in self.result:
|
||||
@@ -89,7 +95,10 @@ def compare_results(heads, values, names, space=10, verbose=True, sort_key=False
|
||||
if verbose:
|
||||
print(info_str_dict["head"])
|
||||
if sort_key:
|
||||
lines = sorted(list(zip(values, info_str_dict["lines"])), key=lambda x: float(x[0].split(" ")[0]))
|
||||
lines = sorted(
|
||||
list(zip(values, info_str_dict["lines"])),
|
||||
key=lambda x: float(x[0].split(" ")[0]),
|
||||
)
|
||||
lines = [x[1] for x in lines]
|
||||
else:
|
||||
lines = info_str_dict["lines"]
|
||||
@@ -136,7 +145,11 @@ def query_info(save_dir, verbose):
|
||||
if verbose:
|
||||
print(
|
||||
"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
|
||||
idx + 1, len(experiments), experiment.name, len(recorders), len(recorders) + not_finished
|
||||
idx + 1,
|
||||
len(experiments),
|
||||
experiment.name,
|
||||
len(recorders),
|
||||
len(recorders) + not_finished,
|
||||
)
|
||||
)
|
||||
result = QResult()
|
||||
@@ -149,7 +162,9 @@ def query_info(save_dir, verbose):
|
||||
head_strs.append(head_str)
|
||||
value_strs.append(value_str)
|
||||
names.append(experiment.name)
|
||||
info_str_dict = compare_results(head_strs, value_strs, names, space=10, verbose=verbose)
|
||||
info_str_dict = compare_results(
|
||||
head_strs, value_strs, names, space=10, verbose=verbose
|
||||
)
|
||||
info_value_dict = dict(heads=head_strs, values=value_strs, names=names)
|
||||
return info_str_dict, info_value_dict
|
||||
|
||||
@@ -169,9 +184,18 @@ if __name__ == "__main__":
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, nargs="+", default=["./outputs/qlib-baselines"], help="The checkpoint directory."
|
||||
"--save_dir",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["./outputs/qlib-baselines"],
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Print detailed log information or not.",
|
||||
)
|
||||
parser.add_argument("--verbose", type=str2bool, default=False, help="Print detailed log information or not.")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Show results of {:}".format(args.save_dir))
|
||||
@@ -184,4 +208,11 @@ if __name__ == "__main__":
|
||||
_, info_dict = query_info(save_dir, args.verbose)
|
||||
all_info_dict.append(info_dict)
|
||||
info_dict = QResult.merge_dict(all_info_dict)
|
||||
compare_results(info_dict["heads"], info_dict["values"], info_dict["names"], space=10, verbose=True, sort_key=True)
|
||||
compare_results(
|
||||
info_dict["heads"],
|
||||
info_dict["values"],
|
||||
info_dict["names"],
|
||||
space=10,
|
||||
verbose=True,
|
||||
sort_key=True,
|
||||
)
|
||||
|
@@ -39,7 +39,10 @@ def main(xargs):
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": xargs.market,
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
|
||||
{
|
||||
"class": "RobustZScoreNorm",
|
||||
"kwargs": {"fields_group": "feature", "clip_outlier": True},
|
||||
},
|
||||
{"class": "Fillna", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
@@ -90,7 +93,11 @@ def main(xargs):
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": dict(),
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
@@ -111,18 +118,37 @@ def main(xargs):
|
||||
for irun in range(xargs.times):
|
||||
xmodel_config = model_config.copy()
|
||||
xmodel_config = update_gpu(xmodel_config, xargs.gpu)
|
||||
task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
|
||||
task_config = dict(
|
||||
model=xmodel_config, dataset=dataset_config, record=record_config
|
||||
)
|
||||
|
||||
run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)
|
||||
run_exp(
|
||||
task_config,
|
||||
dataset,
|
||||
xargs.name,
|
||||
"recorder-{:02d}-{:02d}".format(irun, xargs.times),
|
||||
save_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
|
||||
parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
|
||||
parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./outputs/vtt-runs",
|
||||
help="The checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="Transformer", help="The experiment name."
|
||||
)
|
||||
parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
|
||||
parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
|
||||
parser.add_argument("--market", type=str, default="all", help="The market indicator.")
|
||||
parser.add_argument(
|
||||
"--gpu", type=int, default=0, help="The GPU ID used for train / test."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--market", type=str, default="all", help="The market indicator."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user