Update organize
This commit is contained in:
@@ -307,25 +307,23 @@ class QuantTransformer(Model):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(
|
||||
segment, col_set="feature", data_key=DataHandlerLP.DK_I
|
||||
)
|
||||
index = x_test.index
|
||||
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[::batch_size]:
|
||||
|
||||
if sample_num - begin < batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
with torch.no_grad():
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
|
||||
preds = []
|
||||
for begin in range(sample_num)[::batch_size]:
|
||||
if sample_num - begin < batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
Reference in New Issue
Block a user