Update tests for torch/cuda
This commit is contained in:
@@ -141,26 +141,25 @@ def retrieve_configs():
|
||||
return alg2configs
|
||||
|
||||
|
||||
def main(xargs, config):
|
||||
def main(alg_name, market, config, times, save_dir, gpu):
|
||||
|
||||
pprint("Run {:}".format(xargs.alg))
|
||||
config = update_market(config, xargs.market)
|
||||
config = update_gpu(config, xargs.gpu)
|
||||
pprint("Run {:}".format(alg_name))
|
||||
config = update_market(config, market)
|
||||
config = update_gpu(config, gpu)
|
||||
|
||||
qlib.init(**config.get("qlib_init"))
|
||||
dataset_config = config.get("task").get("dataset")
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
pprint("args: {:}".format(xargs))
|
||||
pprint(dataset_config)
|
||||
pprint(dataset)
|
||||
|
||||
for irun in range(xargs.times):
|
||||
for irun in range(times):
|
||||
run_exp(
|
||||
config.get("task"),
|
||||
dataset,
|
||||
xargs.alg,
|
||||
"recorder-{:02d}-{:02d}".format(irun, xargs.times),
|
||||
"{:}-{:}".format(xargs.save_dir, xargs.market),
|
||||
alg_name,
|
||||
"recorder-{:02d}-{:02d}".format(irun, times),
|
||||
"{:}-{:}".format(save_dir, market),
|
||||
)
|
||||
|
||||
|
||||
@@ -203,6 +202,13 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.alg) == 1:
|
||||
main(args, alg2configs[args.alg[0]])
|
||||
main(
|
||||
args.alg[0],
|
||||
args.market,
|
||||
alg2configs[args.alg[0]],
|
||||
args.times,
|
||||
args.save_dir,
|
||||
args.gpu,
|
||||
)
|
||||
else:
|
||||
print("-")
|
||||
|
Reference in New Issue
Block a user