Update scripts

This commit is contained in:
D-X-Y
2021-03-28 00:34:21 -07:00
parent 92d0df0926
commit 0055511829
5 changed files with 9 additions and 363 deletions

View File

@@ -308,10 +308,10 @@ class QuantTransformer(Model):
torch.cuda.empty_cache()
self.fitted = True
def predict(self, dataset):
def predict(self, dataset, segment="test"):
if not self.fitted:
raise ValueError("The model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature")
index = x_test.index
self.model.eval()