Fix CUDA memory issues
This commit is contained in:
@@ -45,6 +45,32 @@ DEFAULT_OPT_CONFIG = dict(
|
||||
)
|
||||
|
||||
|
||||
def train_or_test_epoch(
    xloader, model, loss_fn, metric_fn, is_train, optimizer, device
):
    """Run one epoch over ``xloader``, training when ``is_train`` else evaluating.

    Args:
        xloader: iterable yielding ``(feats, labels)`` batches.
        model: network to run; switched to ``train()``/``eval()`` mode here.
        loss_fn: callable ``(preds, labels) -> scalar loss tensor``.
        metric_fn: callable ``(preds, labels) -> scalar score tensor``.
        is_train: when True, back-propagate and step ``optimizer``.
        optimizer: torch optimizer; only used when ``is_train`` is True.
        device: torch device the batches are moved to.

    Returns:
        ``(avg_loss, avg_score)`` averaged over all samples in the epoch.
    """
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats, labels = feats.to(device), labels.to(device)
        # Build the autograd graph only when we will actually call backward():
        # during evaluation, retaining activations for gradients needlessly
        # holds (CUDA) memory — the exact leak this function is prone to.
        with torch.set_grad_enabled(is_train and optimizer is not None):
            # forward the network
            preds = model(feats)
            loss = loss_fn(preds, labels)
        with torch.no_grad():
            score = metric_fn(preds, labels)
            # .item() detaches to Python floats so the meters never pin tensors
            loss_meter.update(loss.item(), feats.size(0))
            score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            # clip by value (not norm) at 3.0, matching the original setup
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg
|
||||
|
||||
|
||||
class QuantTransformer(Model):
|
||||
"""Transformer-based Quant Model"""
|
||||
|
||||
@@ -132,32 +158,6 @@ class QuantTransformer(Model):
|
||||
else:
|
||||
raise ValueError("unknown metric `{:}`".format(self.metric))
|
||||
|
||||
def train_or_test_epoch(
    self, xloader, model, loss_fn, metric_fn, is_train, optimizer=None
):
    """Run one epoch over ``xloader``, training when ``is_train`` else evaluating.

    Args:
        xloader: iterable yielding ``(feats, labels)`` batches.
        model: network to run; switched to ``train()``/``eval()`` mode here.
        loss_fn: callable ``(preds, labels) -> scalar loss tensor``.
        metric_fn: callable ``(preds, labels) -> scalar score tensor``.
        is_train: when True, back-propagate and step ``optimizer``.
        optimizer: torch optimizer; required when ``is_train`` is True.

    Returns:
        ``(avg_loss, avg_score)`` averaged over all samples in the epoch.
    """
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats = feats.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        # Build the autograd graph only when we will actually call backward();
        # keeping gradients during evaluation needlessly holds CUDA memory.
        with torch.set_grad_enabled(is_train and optimizer is not None):
            # forward the network
            preds = model(feats)
            loss = loss_fn(preds, labels)
        with torch.no_grad():
            # Fix: use the metric_fn argument the caller supplied — the original
            # ignored it and always used self.metric_fn, making the parameter dead.
            score = metric_fn(preds, labels)
            loss_meter.update(loss.item(), feats.size(0))
            score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
@@ -204,14 +204,22 @@ class QuantTransformer(Model):
|
||||
|
||||
def _internal_test(ckp_epoch=None, results_dict=None):
|
||||
with torch.no_grad():
|
||||
train_loss, train_score = self.train_or_test_epoch(
|
||||
train_loader, self.model, self.loss_fn, self.metric_fn, False, None
|
||||
shared_kwards = {
|
||||
"model": self.model,
|
||||
"loss_fn": self.loss_fn,
|
||||
"metric_fn": self.metric_fn,
|
||||
"is_train": False,
|
||||
"optimizer": None,
|
||||
"device": self.device,
|
||||
}
|
||||
train_loss, train_score = train_or_test_epoch(
|
||||
train_loader, **shared_kwards
|
||||
)
|
||||
valid_loss, valid_score = self.train_or_test_epoch(
|
||||
valid_loader, self.model, self.loss_fn, self.metric_fn, False, None
|
||||
valid_loss, valid_score = train_or_test_epoch(
|
||||
valid_loader, **shared_kwards
|
||||
)
|
||||
test_loss, test_score = self.train_or_test_epoch(
|
||||
test_loader, self.model, self.loss_fn, self.metric_fn, False, None
|
||||
test_loss, test_score = train_or_test_epoch(
|
||||
test_loader, **shared_kwards
|
||||
)
|
||||
xstr = (
|
||||
"train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format(
|
||||
@@ -255,13 +263,14 @@ class QuantTransformer(Model):
|
||||
iepoch, self.opt_config["epochs"], best_epoch, best_score
|
||||
)
|
||||
)
|
||||
train_loss, train_score = self.train_or_test_epoch(
|
||||
train_loss, train_score = train_or_test_epoch(
|
||||
train_loader,
|
||||
self.model,
|
||||
self.loss_fn,
|
||||
self.metric_fn,
|
||||
True,
|
||||
self.train_optimizer,
|
||||
self.device,
|
||||
)
|
||||
self.logger.info(
|
||||
"Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score)
|
||||
@@ -307,7 +316,8 @@ class QuantTransformer(Model):
|
||||
self.logger.info("Reload the best parameter :: {:}".format(eval_str))
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
with torch.cuda.device(self.device):
|
||||
torch.cuda.empty_cache()
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
|
Reference in New Issue
Block a user