Update baselines

This commit is contained in:
D-X-Y
2021-03-28 10:57:20 +00:00
parent 0055511829
commit 53cb5f1fdd
4 changed files with 24 additions and 10 deletions

View File

@@ -8,16 +8,12 @@ import os, math, random
from collections import OrderedDict
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from functools import partial
from typing import Optional, Text
from qlib.utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
import torch
@@ -308,10 +304,10 @@ class QuantTransformer(Model):
torch.cuda.empty_cache()
self.fitted = True
def predict(self, dataset, segment="test"):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("The model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()