Fix small bugs

This commit is contained in:
D-X-Y
2021-08-14 16:01:07 -07:00
parent 58733c18be
commit d04edcd211
12 changed files with 95 additions and 18 deletions

View File

@@ -6,6 +6,7 @@
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
#####################################################
import yaml
import argparse
from xautodl.procedures.q_exps import update_gpu
@@ -57,7 +58,7 @@ def main(xargs):
model_config = {
"class": "QuantTransformer",
"module_path": "trade_models",
"module_path": "xautodl.trade_models.quant_transformer",
"kwargs": {
"net_config": None,
"opt_config": None,
@@ -108,6 +109,62 @@ def main(xargs):
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(provider_uri=provider_uri, region=REG_CN)
from qlib.utils import init_instance_by_config
xconfig = """
model:
class: SFM
module_path: qlib.contrib.model.pytorch_sfm
kwargs:
d_feat: 6
hidden_size: 64
output_dim: 32
freq_dim: 25
dropout_W: 0.5
dropout_U: 0.5
n_epochs: 20
lr: 1e-3
batch_size: 1600
early_stop: 20
eval_steps: 5
loss: mse
optimizer: adam
GPU: 0
"""
xconfig = """
model:
class: TabnetModel
module_path: qlib.contrib.model.pytorch_tabnet
kwargs:
d_feat: 360
pretrain: True
"""
xconfig = """
model:
class: GRU
module_path: qlib.contrib.model.pytorch_gru
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 4
dropout: 0.0
n_epochs: 200
lr: 0.001
early_stop: 20
batch_size: 800
metric: loss
loss: mse
GPU: 0
"""
xconfig = yaml.safe_load(xconfig)
model = init_instance_by_config(xconfig["model"])
from xautodl.utils.flop_benchmark import count_parameters_in_MB
# print(count_parameters_in_MB(model.tabnet_model))
import pdb
pdb.set_trace()
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
dataset = init_instance_by_config(dataset_config)
for irun in range(xargs.times):