Update baselines
This commit is contained in:
@@ -8,16 +8,12 @@ import os, math, random
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional, Text
|
||||
|
||||
from qlib.utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -308,10 +304,10 @@ class QuantTransformer(Model):
|
||||
torch.cuda.empty_cache()
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset, segment="test"):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
|
||||
self.model.eval()
|
||||
|
Reference in New Issue
Block a user