Fix bugs in .to(cpu)
This commit is contained in:
@@ -143,6 +143,19 @@ class QuantTransformer(Model):
|
||||
device = "cpu"
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
# move the optimizer
|
||||
for param in self.train_optimizer.state.values():
|
||||
# Not sure there are any global tensors in the state dict
|
||||
if isinstance(param, torch.Tensor):
|
||||
param.data = param.data.to(device)
|
||||
if param._grad is not None:
|
||||
param._grad.data = param._grad.data.to(device)
|
||||
elif isinstance(param, dict):
|
||||
for subparam in param.values():
|
||||
if isinstance(subparam, torch.Tensor):
|
||||
subparam.data = subparam.data.to(device)
|
||||
if subparam._grad is not None:
|
||||
subparam._grad.data = subparam._grad.data.to(device)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
Reference in New Issue
Block a user