Add name filters for exp-org
This commit is contained in:
@@ -114,9 +114,9 @@ class QuantTransformer(Model):
|
||||
|
||||
def to(self, device):
|
||||
if device is None:
|
||||
self.model.to(self.device)
|
||||
else:
|
||||
self.model.to("cpu")
|
||||
device = "cpu"
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
@@ -227,7 +227,7 @@ class QuantTransformer(Model):
|
||||
# Pre-fetch the potential checkpoints
|
||||
ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
|
||||
if os.path.exists(ckp_path):
|
||||
ckp_data = torch.load(ckp_path)
|
||||
ckp_data = torch.load(ckp_path, map_location=self.device)
|
||||
stop_steps, best_score, best_epoch = (
|
||||
ckp_data["stop_steps"],
|
||||
ckp_data["best_score"],
|
||||
@@ -298,7 +298,7 @@ class QuantTransformer(Model):
|
||||
results_dict=results_dict,
|
||||
start_epoch=iepoch + 1,
|
||||
)
|
||||
torch.save(save_info, ckp_path)
|
||||
torch.save(save_info, ckp_path, map_location="cpu")
|
||||
self.logger.info(
|
||||
"The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch)
|
||||
)
|
||||
|
Reference in New Issue
Block a user