Update GeMOSA v4

This commit is contained in:
D-X-Y
2021-05-27 19:27:29 +08:00
parent 16861f0f3d
commit 08337138f1
3 changed files with 130 additions and 39 deletions

View File

@@ -98,21 +98,53 @@ class ComposeMetric(Metric):
class MSEMetric(Metric):
"""The metric for mse."""
def __init__(self, ignore_batch):
super(MSEMetric, self).__init__()
self._ignore_batch = ignore_batch
def reset(self):
self._mse = AverageMeter()
def __call__(self, predictions, targets):
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
batch = predictions.shape[0]
loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
loss = loss.item()
self._mse.update(loss, batch)
loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
if self._ignore_batch:
self._mse.update(loss, 1)
else:
self._mse.update(loss, predictions.shape[0])
return loss
else:
raise NotImplementedError
def get_info(self):
return {"mse": self._mse.avg}
return {"mse": self._mse.avg, "score": self._mse.avg}
class Top1AccMetric(Metric):
"""The metric for the top-1 accuracy."""
def __init__(self, ignore_batch):
super(Top1AccMetric, self).__init__()
self._ignore_batch = ignore_batch
def reset(self):
self._accuracy = AverageMeter()
def __call__(self, predictions, targets):
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
max_prob_indexes = torch.argmax(predictions, dim=-1)
corrects = torch.eq(max_prob_indexes, targets)
accuracy = corrects.float().mean().float()
if self._ignore_batch:
self._accuracy.update(accuracy, 1)
else: # [TODO] for 3-d tensor
self._accuracy.update(accuracy, predictions.shape[0])
return accuracy
else:
raise NotImplementedError
def get_info(self):
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
class SaveMetric(Metric):