Update scripts
This commit is contained in:
@@ -308,10 +308,10 @@ class QuantTransformer(Model):
|
||||
torch.cuda.empty_cache()
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset, segment="test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
index = x_test.index
|
||||
|
||||
self.model.eval()
|
||||
|
Reference in New Issue
Block a user