Update tests for torch/cuda

This commit is contained in:
D-X-Y
2021-03-30 12:05:52 +00:00
parent c2270fd153
commit e5ec43e04a
12 changed files with 386 additions and 248 deletions

View File

@@ -141,26 +141,25 @@ def retrieve_configs():
return alg2configs
def main(xargs, config):
def main(alg_name, market, config, times, save_dir, gpu):
pprint("Run {:}".format(xargs.alg))
config = update_market(config, xargs.market)
config = update_gpu(config, xargs.gpu)
pprint("Run {:}".format(alg_name))
config = update_market(config, market)
config = update_gpu(config, gpu)
qlib.init(**config.get("qlib_init"))
dataset_config = config.get("task").get("dataset")
dataset = init_instance_by_config(dataset_config)
pprint("args: {:}".format(xargs))
pprint(dataset_config)
pprint(dataset)
for irun in range(xargs.times):
for irun in range(times):
run_exp(
config.get("task"),
dataset,
xargs.alg,
"recorder-{:02d}-{:02d}".format(irun, xargs.times),
"{:}-{:}".format(xargs.save_dir, xargs.market),
alg_name,
"recorder-{:02d}-{:02d}".format(irun, times),
"{:}-{:}".format(save_dir, market),
)
@@ -203,6 +202,13 @@ if __name__ == "__main__":
args = parser.parse_args()
if len(args.alg) == 1:
main(args, alg2configs[args.alg[0]])
main(
args.alg[0],
args.market,
alg2configs[args.alg[0]],
args.times,
args.save_dir,
args.gpu,
)
else:
print("-")