Update GeMOSA v4
This commit is contained in:
@@ -98,21 +98,53 @@ class ComposeMetric(Metric):
|
||||
class MSEMetric(Metric):
|
||||
"""The metric for mse."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(MSEMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._mse = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
batch = predictions.shape[0]
|
||||
loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
|
||||
loss = loss.item()
|
||||
self._mse.update(loss, batch)
|
||||
loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
|
||||
if self._ignore_batch:
|
||||
self._mse.update(loss, 1)
|
||||
else:
|
||||
self._mse.update(loss, predictions.shape[0])
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"mse": self._mse.avg}
|
||||
return {"mse": self._mse.avg, "score": self._mse.avg}
|
||||
|
||||
|
||||
class Top1AccMetric(Metric):
|
||||
"""The metric for the top-1 accuracy."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(Top1AccMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._accuracy = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||
corrects = torch.eq(max_prob_indexes, targets)
|
||||
accuracy = corrects.float().mean().float()
|
||||
if self._ignore_batch:
|
||||
self._accuracy.update(accuracy, 1)
|
||||
else: # [TODO] for 3-d tensor
|
||||
self._accuracy.update(accuracy, predictions.shape[0])
|
||||
return accuracy
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
|
||||
|
||||
|
||||
class SaveMetric(Metric):
|
||||
|
Reference in New Issue
Block a user