Temp / 0.5
This commit is contained in:
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from functools import partial
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
@@ -23,10 +22,11 @@ from qlib.log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
import layers as xlayers
|
||||
from utils import count_parameters_in_MB
|
||||
from utils import count_parameters
|
||||
|
||||
from qlib.model.base import Model
|
||||
from qlib.data.dataset import DatasetH
|
||||
@@ -137,9 +137,11 @@ class QuantTransformer(Model):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
import pdb; pdb.set_trace()
|
||||
print('--')
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
else:
|
||||
raise ValueError("unknown loss `{:}`".format(self.loss))
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
@@ -147,8 +149,8 @@ class QuantTransformer(Model):
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
else:
|
||||
raise ValueError("unknown metric `{:}`".format(self.metric))
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
|
Reference in New Issue
Block a user