Update models

This commit is contained in:
D-X-Y
2021-03-29 05:04:24 +00:00
parent e637cddc39
commit b51320dfb1
7 changed files with 85 additions and 14 deletions

View File

@@ -112,6 +112,12 @@ class QuantTransformer(Model):
def use_gpu(self):
return self.device != torch.device("cpu")
def to(self, device):
if device is None:
self.model.to(self.device)
else:
self.model.to("cpu")
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.opt_config["loss"] == "mse":