Temp / 0.5

This commit is contained in:
D-X-Y
2021-03-05 13:50:30 +00:00
parent 2fa358fdf6
commit cc28e1589e
4 changed files with 35 additions and 10 deletions

View File

@@ -9,7 +9,6 @@ import numpy as np
import pandas as pd
import copy
from functools import partial
from sklearn.metrics import roc_auc_score, mean_squared_error
from typing import Optional
import logging
@@ -23,10 +22,11 @@ from qlib.log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import layers as xlayers
from utils import count_parameters_in_MB
from utils import count_parameters
from qlib.model.base import Model
from qlib.data.dataset import DatasetH
@@ -137,9 +137,11 @@ class QuantTransformer(Model):
mask = ~torch.isnan(label)
if self.loss == "mse":
import pdb; pdb.set_trace()
print('--')
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
else:
raise ValueError("unknown loss `{:}`".format(self.loss))
def metric_fn(self, pred, label):
@@ -147,8 +149,8 @@ class QuantTransformer(Model):
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
else:
raise ValueError("unknown metric `{:}`".format(self.metric))
def train_epoch(self, x_train, y_train):