Move to xautodl
4 xautodl/trade_models/__init__.py Normal file
@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .transformers import get_transformer
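Aside, not part of the commit: a minimal sketch of the package-level export, assuming the repository root is on `PYTHONPATH` so that `xautodl` is importable as a package.

```python
# The package re-exports get_transformer from .transformers, so both
# import paths below resolve to the same function.
from xautodl.trade_models import get_transformer
from xautodl.trade_models.transformers import get_transformer as direct

assert get_transformer is direct
model = get_transformer(None)  # None falls back to the default SuperTransformer
```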
102 xautodl/trade_models/naive_v1_model.py Normal file
@@ -0,0 +1,102 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
##################################################
# Use noise as prediction                        #
##################################################
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import pandas as pd

from qlib.log import get_module_logger

from qlib.model.base import Model
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP


class NAIVE_V1(Model):
    """NAIVE Version 1 Quant Model"""

    def __init__(self, d_feat=6, seed=None, **kwargs):
        # Set logger.
        self.logger = get_module_logger("NAIVE")
        self.logger.info("NAIVE 1st version: random noise ...")

        # set hyper-parameters.
        self.d_feat = d_feat
        self.seed = seed

        self.logger.info(
            "NAIVE-V1 parameters setting: d_feat={:}, seed={:}".format(
                self.d_feat, self.seed
            )
        )

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
        self._mean = None
        self._std = None
        self.fitted = False

    def process_data(self, features):
        features = features.reshape(len(features), self.d_feat, -1)
        features = features.transpose((0, 2, 1))
        return features[:, :59, 0]

    def mse(self, preds, labels):
        masks = ~np.isnan(labels)
        masked_preds = preds[masks]
        masked_labels = labels[masks]
        return np.square(masked_preds - masked_labels).mean()

    def model(self, x):
        num = len(x)
        return np.random.normal(loc=self._mean, scale=self._std, size=num).astype(
            x.dtype
        )

    def fit(self, dataset: DatasetH):
        def _prepare_dataset(df_data):
            features = df_data["feature"].values
            features = self.process_data(features)
            labels = df_data["label"].values.squeeze()
            return dict(features=features, labels=labels)

        df_train, df_valid, df_test = dataset.prepare(
            ["train", "valid", "test"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        train_dataset, valid_dataset, test_dataset = (
            _prepare_dataset(df_train),
            _prepare_dataset(df_valid),
            _prepare_dataset(df_test),
        )
        # df_train['feature']['CLOSE1'].values
        # train_dataset['features'][:, -1]
        masks = ~np.isnan(train_dataset["labels"])
        self._mean, self._std = np.mean(train_dataset["labels"][masks]), np.std(
            train_dataset["labels"][masks]
        )
        train_mse_loss = self.mse(
            self.model(train_dataset["features"]), train_dataset["labels"]
        )
        valid_mse_loss = self.mse(
            self.model(valid_dataset["features"]), valid_dataset["labels"]
        )
        self.logger.info("Training MSE loss: {:}".format(train_mse_loss))
        self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss))
        self.fitted = True

    def predict(self, dataset):
        if not self.fitted:
            raise ValueError("The model is not fitted yet!")
        x_test = dataset.prepare("test", col_set="feature")
        index = x_test.index

        preds = self.model(self.process_data(x_test.values))
        return pd.Series(preds, index=index)
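Aside, not part of the commit: `process_data` assumes qlib Alpha360-style flattened features, one row per sample laid out indicator-major over a lookback window; it recovers the time axis and keeps the first indicator over the first 59 steps. A self-contained numpy sketch (the 6x60 layout is an assumption based on `d_feat=6` and the 59-step slice):

```python
import numpy as np

d_feat, seq_len, n_samples = 6, 60, 4
# one row per sample: indicator 0 for t=0..59, then indicator 1, ...
features = np.arange(n_samples * d_feat * seq_len, dtype=np.float32)
features = features.reshape(n_samples, d_feat * seq_len)

x = features.reshape(len(features), d_feat, -1)  # (N, d_feat, seq_len)
x = x.transpose((0, 2, 1))                       # (N, seq_len, d_feat)
out = x[:, :59, 0]                               # first indicator, first 59 steps
print(out.shape)                                 # (4, 59)
```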
103 xautodl/trade_models/naive_v2_model.py Normal file
@@ -0,0 +1,103 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
##################################################
# A simple model that reuses the prices of the last day
##################################################
from __future__ import division
from __future__ import print_function

import random
import numpy as np
import pandas as pd

from qlib.log import get_module_logger

from qlib.model.base import Model
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP


class NAIVE_V2(Model):
    """NAIVE Version 2 Quant Model"""

    def __init__(self, d_feat=6, seed=None, **kwargs):
        # Set logger.
        self.logger = get_module_logger("NAIVE")
        self.logger.info("NAIVE 2nd version: reuse the last day's price ...")

        # set hyper-parameters.
        self.d_feat = d_feat
        self.seed = seed

        self.logger.info(
            "NAIVE-V2 parameters setting: d_feat={:}, seed={:}".format(
                self.d_feat, self.seed
            )
        )

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)

        self.fitted = False

    def process_data(self, features):
        features = features.reshape(len(features), self.d_feat, -1)
        features = features.transpose((0, 2, 1))
        return features[:, :59, 0]

    def mse(self, preds, labels):
        masks = ~np.isnan(labels)
        masked_preds = preds[masks]
        masked_labels = labels[masks]
        return np.square(masked_preds - masked_labels).mean()

    def model(self, x):
        # Convert each price ratio into a return, then predict the most
        # recent non-NaN value per row (0 when the whole row is NaN).
        x = 1 / x - 1
        masks = ~np.isnan(x)
        results = []
        for rowd, rowm in zip(x, masks):
            if rowm.any():
                results.append(float(rowd[rowm][-1]))
            else:
                results.append(0)
        return np.array(results, dtype=x.dtype)

    def fit(self, dataset: DatasetH):
        def _prepare_dataset(df_data):
            features = df_data["feature"].values
            features = self.process_data(features)
            labels = df_data["label"].values.squeeze()
            return dict(features=features, labels=labels)

        df_train, df_valid, df_test = dataset.prepare(
            ["train", "valid", "test"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        train_dataset, valid_dataset, test_dataset = (
            _prepare_dataset(df_train),
            _prepare_dataset(df_valid),
            _prepare_dataset(df_test),
        )
        # df_train['feature']['CLOSE1'].values
        # train_dataset['features'][:, -1]
        train_mse_loss = self.mse(
            self.model(train_dataset["features"]), train_dataset["labels"]
        )
        valid_mse_loss = self.mse(
            self.model(valid_dataset["features"]), valid_dataset["labels"]
        )
        self.logger.info("Training MSE loss: {:}".format(train_mse_loss))
        self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss))
        self.fitted = True

    def predict(self, dataset):
        if not self.fitted:
            raise ValueError("The model is not fitted yet!")
        x_test = dataset.prepare("test", col_set="feature")
        index = x_test.index

        preds = self.model(self.process_data(x_test.values))
        return pd.Series(preds, index=index)
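Aside, not part of the commit: `NAIVE_V2.model` converts each price ratio into a return via `1/x - 1` and predicts the most recent non-NaN return per row, falling back to 0 for all-NaN rows. The same logic on a toy array:

```python
import numpy as np

x = np.array(
    [[0.8, np.nan, 1.25],        # last valid ratio 1.25 -> 1/1.25 - 1 = -0.2
     [np.nan, np.nan, np.nan]],  # all NaN -> fallback prediction 0
    dtype=np.float32,
)
r = 1 / x - 1
masks = ~np.isnan(r)
preds = np.array(
    [float(row[m][-1]) if m.any() else 0 for row, m in zip(r, masks)],
    dtype=x.dtype,
)
print(preds)  # [-0.2  0. ]
```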
358 xautodl/trade_models/quant_transformer.py Normal file
@@ -0,0 +1,358 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
##################################################
from __future__ import division
from __future__ import print_function

import os, random
import copy
from collections import OrderedDict
from typing import Optional, Text, Union

import numpy as np
import pandas as pd

from qlib.utils import get_or_create_path
from qlib.log import get_module_logger

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as th_data

from log_utils import AverageMeter
from utils import count_parameters

from xlayers import super_core
from .transformers import DEFAULT_NET_CONFIG
from .transformers import get_transformer


from qlib.model.base import Model
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP


DEFAULT_OPT_CONFIG = dict(
    epochs=200,
    lr=0.001,
    batch_size=2000,
    early_stop=20,
    loss="mse",
    optimizer="adam",
    num_workers=4,
)


def train_or_test_epoch(
    xloader, model, loss_fn, metric_fn, is_train, optimizer, device
):
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats, labels = feats.to(device), labels.to(device)
        # forward the network
        preds = model(feats)
        loss = loss_fn(preds, labels)
        with torch.no_grad():
            score = metric_fn(preds, labels)
        loss_meter.update(loss.item(), feats.size(0))
        score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg


class QuantTransformer(Model):
    """Transformer-based Quant Model"""

    def __init__(
        self, net_config=None, opt_config=None, metric="", GPU=0, seed=None, **kwargs
    ):
        # Set logger.
        self.logger = get_module_logger("QuantTransformer")
        self.logger.info("QuantTransformer PyTorch version...")

        # set hyper-parameters.
        self.net_config = net_config or DEFAULT_NET_CONFIG
        self.opt_config = opt_config or DEFAULT_OPT_CONFIG
        self.metric = metric
        self.device = torch.device(
            "cuda:{:}".format(GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
        )
        self.seed = seed

        self.logger.info(
            "Transformer parameters setting:"
            "\nnet_config : {:}"
            "\nopt_config : {:}"
            "\nmetric     : {:}"
            "\ndevice     : {:}"
            "\nseed       : {:}".format(
                self.net_config,
                self.opt_config,
                self.metric,
                self.device,
                self.seed,
            )
        )

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            if self.use_gpu:
                torch.cuda.manual_seed(self.seed)
                torch.cuda.manual_seed_all(self.seed)

        self.model = get_transformer(self.net_config)
        self.model.set_super_run_type(super_core.SuperRunMode.FullModel)
        self.logger.info("model: {:}".format(self.model))
        self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))

        if self.opt_config["optimizer"] == "adam":
            self.train_optimizer = optim.Adam(
                self.model.parameters(), lr=self.opt_config["lr"]
            )
        elif self.opt_config["optimizer"] == "sgd":
            self.train_optimizer = optim.SGD(
                self.model.parameters(), lr=self.opt_config["lr"]
            )
        else:
            raise NotImplementedError(
                "optimizer {:} is not supported!".format(self.opt_config["optimizer"])
            )

        self.fitted = False
        self.model.to(self.device)

    @property
    def use_gpu(self):
        return self.device != torch.device("cpu")

    def to(self, device):
        if device is None:
            device = "cpu"
        self.device = device
        self.model.to(self.device)
        # move the optimizer state to the target device as well
        for param in self.train_optimizer.state.values():
            # Not sure there are any global tensors in the state dict
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(device)

    def loss_fn(self, pred, label):
        mask = ~torch.isnan(label)
        if self.opt_config["loss"] == "mse":
            return F.mse_loss(pred[mask], label[mask])
        else:
            raise ValueError("unknown loss `{:}`".format(self.opt_config["loss"]))

    def metric_fn(self, pred, label):
        # the metric score : higher is better
        if self.metric == "" or self.metric == "loss":
            return -self.loss_fn(pred, label)
        else:
            raise ValueError("unknown metric `{:}`".format(self.metric))

    def fit(
        self,
        dataset: DatasetH,
        save_dir: Optional[Text] = None,
    ):
        def _prepare_dataset(df_data):
            return th_data.TensorDataset(
                torch.from_numpy(df_data["feature"].values).float(),
                torch.from_numpy(df_data["label"].values).squeeze().float(),
            )

        def _prepare_loader(dataset, shuffle):
            return th_data.DataLoader(
                dataset,
                batch_size=self.opt_config["batch_size"],
                drop_last=False,
                pin_memory=True,
                num_workers=self.opt_config["num_workers"],
                shuffle=shuffle,
            )

        df_train, df_valid, df_test = dataset.prepare(
            ["train", "valid", "test"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        train_dataset, valid_dataset, test_dataset = (
            _prepare_dataset(df_train),
            _prepare_dataset(df_valid),
            _prepare_dataset(df_test),
        )
        train_loader, valid_loader, test_loader = (
            _prepare_loader(train_dataset, True),
            _prepare_loader(valid_dataset, False),
            _prepare_loader(test_dataset, False),
        )

        save_dir = get_or_create_path(save_dir, return_dir=True)
        self.logger.info(
            "Fit procedure for [{:}] with save path={:}".format(
                self.__class__.__name__, save_dir
            )
        )

        def _internal_test(ckp_epoch=None, results_dict=None):
            with torch.no_grad():
                shared_kwargs = {
                    "model": self.model,
                    "loss_fn": self.loss_fn,
                    "metric_fn": self.metric_fn,
                    "is_train": False,
                    "optimizer": None,
                    "device": self.device,
                }
                train_loss, train_score = train_or_test_epoch(
                    train_loader, **shared_kwargs
                )
                valid_loss, valid_score = train_or_test_epoch(
                    valid_loader, **shared_kwargs
                )
                test_loss, test_score = train_or_test_epoch(
                    test_loader, **shared_kwargs
                )
                xstr = (
                    "train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format(
                        train_score, valid_score, test_score
                    )
                )
                if ckp_epoch is not None and isinstance(results_dict, dict):
                    results_dict["train"][ckp_epoch] = train_score
                    results_dict["valid"][ckp_epoch] = valid_score
                    results_dict["test"][ckp_epoch] = test_score
                return dict(train=train_score, valid=valid_score, test=test_score), xstr

        # Pre-fetch the potential checkpoints
        ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
        if os.path.exists(ckp_path):
            ckp_data = torch.load(ckp_path, map_location=self.device)
            stop_steps, best_score, best_epoch = (
                ckp_data["stop_steps"],
                ckp_data["best_score"],
                ckp_data["best_epoch"],
            )
            start_epoch, best_param = ckp_data["start_epoch"], ckp_data["best_param"]
            results_dict = ckp_data["results_dict"]
            self.model.load_state_dict(ckp_data["net_state_dict"])
            self.train_optimizer.load_state_dict(ckp_data["opt_state_dict"])
            self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
        else:
            stop_steps, best_score, best_epoch = 0, -np.inf, -1
            start_epoch, best_param = 0, None
            results_dict = dict(
                train=OrderedDict(), valid=OrderedDict(), test=OrderedDict()
            )
            _, eval_str = _internal_test(-1, results_dict)
            self.logger.info(
                "Training from scratch, metrics@start: {:}".format(eval_str)
            )

        for iepoch in range(start_epoch, self.opt_config["epochs"]):
            self.logger.info(
                "Epoch={:03d}/{:03d} ::==>> Best valid @{:03d} ({:.6f})".format(
                    iepoch, self.opt_config["epochs"], best_epoch, best_score
                )
            )
            train_loss, train_score = train_or_test_epoch(
                train_loader,
                self.model,
                self.loss_fn,
                self.metric_fn,
                True,
                self.train_optimizer,
                self.device,
            )
            self.logger.info(
                "Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score)
            )

            current_eval_scores, eval_str = _internal_test(iepoch, results_dict)
            self.logger.info("Evaluating :: {:}".format(eval_str))

            if current_eval_scores["valid"] > best_score:
                stop_steps, best_epoch, best_score = (
                    0,
                    iepoch,
                    current_eval_scores["valid"],
                )
                best_param = copy.deepcopy(self.model.state_dict())
            else:
                stop_steps += 1
                if stop_steps >= self.opt_config["early_stop"]:
                    self.logger.info(
                        "early stop at {:}-th epoch, where the best is @{:}".format(
                            iepoch, best_epoch
                        )
                    )
                    break
            save_info = dict(
                net_config=self.net_config,
                opt_config=self.opt_config,
                net_state_dict=self.model.state_dict(),
                opt_state_dict=self.train_optimizer.state_dict(),
                best_param=best_param,
                stop_steps=stop_steps,
                best_score=best_score,
                best_epoch=best_epoch,
                results_dict=results_dict,
                start_epoch=iepoch + 1,
            )
            torch.save(save_info, ckp_path)
        self.logger.info(
            "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch)
        )
        self.model.load_state_dict(best_param)
        _, eval_str = _internal_test("final", results_dict)
        self.logger.info("Reload the best parameter :: {:}".format(eval_str))

        if self.use_gpu:
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()
        self.fitted = True

    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
        if not self.fitted:
            raise ValueError("The model is not fitted yet!")
        x_test = dataset.prepare(
            segment, col_set="feature", data_key=DataHandlerLP.DK_I
        )
        index = x_test.index

        with torch.no_grad():
            self.model.eval()
            x_values = x_test.values
            sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
            preds = []
            for begin in range(0, sample_num, batch_size):
                end = min(begin + batch_size, sample_num)
                x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
                pred = self.model(x_batch).detach().cpu().numpy()
                preds.append(pred)
        return pd.Series(np.concatenate(preds), index=index)
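Aside, not part of the commit: a hedged usage sketch for `QuantTransformer`. The handler config and date segments below are illustrative placeholders (the Alpha360 kwargs are abbreviated and may need start/end dates for your data), not values taken from this repository; any qlib `DatasetH` that yields "feature"/"label" columns should work.

```python
import qlib
from qlib.data.dataset import DatasetH
from xautodl.trade_models.quant_transformer import QuantTransformer

qlib.init()  # assumes qlib data has been prepared locally
dataset = DatasetH(
    handler={
        "class": "Alpha360",
        "module_path": "qlib.contrib.data.handler",
        "kwargs": {"instruments": "csi300"},  # placeholder universe
    },
    segments={
        "train": ("2008-01-01", "2014-12-31"),
        "valid": ("2015-01-01", "2016-12-31"),
        "test": ("2017-01-01", "2020-08-01"),
    },
)
model = QuantTransformer(GPU=0, seed=42)  # net/opt configs fall back to defaults
model.fit(dataset, save_dir="./checkpoints")
preds = model.predict(dataset)  # pd.Series indexed like the test features
```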
198 xautodl/trade_models/transformers.py Normal file
@@ -0,0 +1,198 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from __future__ import division
from __future__ import print_function

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

import spaces
from xlayers import trunc_normal_
from xlayers import super_core


__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]


def _get_mul_specs(candidates, num):
    results = []
    for i in range(num):
        results.append(spaces.Categorical(*candidates))
    return results


def _get_list_mul(num, multiplier):
    results = []
    for i in range(1, num + 1):
        results.append(i * multiplier)
    return results


def _assert_types(x, expected_types):
    if not isinstance(x, expected_types):
        raise TypeError(
            "The type [{:}] is expected to be {:}.".format(type(x), expected_types)
        )


DEFAULT_NET_CONFIG = None
_default_max_depth = 5
DefaultSearchSpace = dict(
    d_feat=6,
    embed_dim=spaces.Categorical(*_get_list_mul(8, 16)),
    num_heads=_get_mul_specs((1, 2, 4, 8), _default_max_depth),
    mlp_hidden_multipliers=_get_mul_specs((0.5, 1, 2, 4, 8), _default_max_depth),
    qkv_bias=True,
    pos_drop=0.0,
    other_drop=0.0,
)


class SuperTransformer(super_core.SuperModule):
    """The super model for transformer."""

    def __init__(
        self,
        d_feat: int = 6,
        embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"],
        num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"],
        mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[
            "mlp_hidden_multipliers"
        ],
        qkv_bias: bool = DefaultSearchSpace["qkv_bias"],
        pos_drop: float = DefaultSearchSpace["pos_drop"],
        other_drop: float = DefaultSearchSpace["other_drop"],
        max_seq_len: int = 65,
    ):
        super(SuperTransformer, self).__init__()
        self._embed_dim = embed_dim
        self._num_heads = num_heads
        self._mlp_hidden_multipliers = mlp_hidden_multipliers

        # the stem part
        self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = super_core.SuperPositionalEncoder(
            d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop
        )
        # build the transformer encode layers -->> check params
        _assert_types(num_heads, (tuple, list))
        _assert_types(mlp_hidden_multipliers, (tuple, list))
        assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format(
            len(num_heads), len(mlp_hidden_multipliers)
        )
        # build the transformer encode layers -->> backbone
        layers = []
        for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers):
            layer = super_core.SuperTransformerEncoderLayer(
                embed_dim,
                num_head,
                qkv_bias,
                mlp_hidden_multiplier,
                other_drop,
            )
            layers.append(layer)
        self.backbone = super_core.SuperSequential(*layers)

        # the regression head
        self.head = super_core.SuperSequential(
            super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
        )
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    @property
    def embed_dim(self):
        return spaces.get_max(self._embed_dim)

    @property
    def abstract_search_space(self):
        root_node = spaces.VirtualNode(id(self))
        if not spaces.is_determined(self._embed_dim):
            root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True))
        xdict = dict(
            input_embed=self.input_embed.abstract_search_space,
            pos_embed=self.pos_embed.abstract_search_space,
            backbone=self.backbone.abstract_search_space,
            head=self.head.abstract_search_space,
        )
        for key, space in xdict.items():
            if not spaces.is_determined(space):
                root_node.append(key, space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        super(SuperTransformer, self).apply_candidate(abstract_child)
        xkeys = ("input_embed", "pos_embed", "backbone", "head")
        for key in xkeys:
            if key in abstract_child:
                getattr(self, key).apply_candidate(abstract_child[key])

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, super_core.SuperLinear):
            trunc_normal_(m._super_weight, std=0.02)
            if m._super_bias is not None:
                nn.init.constant_(m._super_bias, 0)
        elif isinstance(m, super_core.SuperLayerNorm1D):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        batch, flatten_size = input.shape
        feats = self.input_embed(input)  # batch * 60 * 64
        if not spaces.is_determined(self._embed_dim):
            embed_dim = self.abstract_child["_embed_dim"].value
        else:
            embed_dim = spaces.get_determined_value(self._embed_dim)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        cls_tokens = F.interpolate(
            cls_tokens, size=embed_dim, mode="linear", align_corners=True
        )
        feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
        feats_w_tp = self.pos_embed(feats_w_ct)
        xfeats = self.backbone(feats_w_tp)
        xfeats = xfeats[:, 0, :]  # use the feature for the first token
        predicts = self.head(xfeats).squeeze(-1)
        return predicts

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        batch, flatten_size = input.shape
        feats = self.input_embed(input)  # batch * 60 * 64
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
        feats_w_tp = self.pos_embed(feats_w_ct)
        xfeats = self.backbone(feats_w_tp)
        xfeats = xfeats[:, 0, :]  # use the feature for the first token
        predicts = self.head(xfeats).squeeze(-1)
        return predicts


def get_transformer(config):
    if config is None:
        return SuperTransformer(6)
    if not isinstance(config, dict):
        raise ValueError("Invalid Configuration: {:}".format(config))
    name = config.get("name", "basic")
    if name == "basic":
        model = SuperTransformer(
            d_feat=config.get("d_feat"),
            embed_dim=config.get("embed_dim"),
            num_heads=config.get("num_heads"),
            mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"),
            qkv_bias=config.get("qkv_bias"),
            pos_drop=config.get("pos_drop"),
            other_drop=config.get("other_drop"),
        )
    else:
        raise ValueError("Unknown model name: {:}".format(name))
    return model
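Aside, not part of the commit: a smoke-test sketch for the supernet. It assumes the repo's `spaces`/`xlayers` packages are importable, and that calling the module in `FullModel` mode routes to `forward_raw`; the input shape follows the `batch * 60 * 64` comment above (flattened 6x60 features, one scalar prediction per sample).

```python
import torch
from xlayers import super_core  # import path as used in this file
from xautodl.trade_models.transformers import get_transformer

model = get_transformer(None)  # default SuperTransformer with d_feat=6
model.set_super_run_type(super_core.SuperRunMode.FullModel)
model.eval()

x = torch.randn(8, 6 * 60)  # batch of 8 flattened feature rows
with torch.no_grad():
    preds = model(x)
print(preds.shape)  # expected: torch.Size([8])
```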