Fix small bugs
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
||||
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
|
||||
#####################################################
|
||||
import yaml
|
||||
import argparse
|
||||
|
||||
from xautodl.procedures.q_exps import update_gpu
|
||||
@@ -57,7 +58,7 @@ def main(xargs):
|
||||
|
||||
model_config = {
|
||||
"class": "QuantTransformer",
|
||||
"module_path": "trade_models",
|
||||
"module_path": "xautodl.trade_models.quant_transformer",
|
||||
"kwargs": {
|
||||
"net_config": None,
|
||||
"opt_config": None,
|
||||
@@ -108,6 +109,62 @@ def main(xargs):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
xconfig = """
|
||||
model:
|
||||
class: SFM
|
||||
module_path: qlib.contrib.model.pytorch_sfm
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
output_dim: 32
|
||||
freq_dim: 25
|
||||
dropout_W: 0.5
|
||||
dropout_U: 0.5
|
||||
n_epochs: 20
|
||||
lr: 1e-3
|
||||
batch_size: 1600
|
||||
early_stop: 20
|
||||
eval_steps: 5
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
GPU: 0
|
||||
"""
|
||||
xconfig = """
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
"""
|
||||
xconfig = """
|
||||
model:
|
||||
class: GRU
|
||||
module_path: qlib.contrib.model.pytorch_gru
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 4
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 0.001
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
"""
|
||||
xconfig = yaml.safe_load(xconfig)
|
||||
model = init_instance_by_config(xconfig["model"])
|
||||
from xautodl.utils.flop_benchmark import count_parameters_in_MB
|
||||
|
||||
# print(count_parameters_in_MB(model.tabnet_model))
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
for irun in range(xargs.times):
|
||||
|
Reference in New Issue
Block a user