Update metrics
This commit is contained in:
@@ -1,3 +1,8 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||
#####################################################
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
@@ -20,3 +25,133 @@ class AverageMeter:
|
||||
return "{name}(val={val}, avg={avg}, count={count})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
|
||||
class Metric(abc.ABC):
|
||||
"""The default meta metric class."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def perf_str(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({inner})".format(
|
||||
name=self.__class__.__name__, inner=self.inner_repr()
|
||||
)
|
||||
|
||||
def inner_repr(self):
|
||||
return ""
|
||||
|
||||
|
||||
class ComposeMetric(Metric):
|
||||
"""The composed metric class."""
|
||||
|
||||
def __init__(self, *metric_list):
|
||||
self.reset()
|
||||
for metric in metric_list:
|
||||
self.append(metric)
|
||||
|
||||
def reset(self):
|
||||
self._metric_list = []
|
||||
|
||||
def append(self, metric):
|
||||
if not isinstance(metric, Metric):
|
||||
raise ValueError(
|
||||
"The input metric is not correct: {:}".format(type(metric))
|
||||
)
|
||||
self._metric_list.append(metric)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._metric_list)
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
results = list()
|
||||
for metric in self._metric_list:
|
||||
results.append(metric(predictions, targets))
|
||||
return results
|
||||
|
||||
def get_info(self):
|
||||
results = dict()
|
||||
for metric in self._metric_list:
|
||||
for key, value in metric.get_info().items():
|
||||
results[key] = value
|
||||
return results
|
||||
|
||||
def inner_repr(self):
|
||||
xlist = []
|
||||
for metric in self._metric_list:
|
||||
xlist.append(str(metric))
|
||||
return ",".join(xlist)
|
||||
|
||||
|
||||
class CrossEntropyMetric(Metric):
|
||||
"""The metric for the cross entropy metric."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(CrossEntropyMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._loss = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
batch, _ = predictions.shape() # only support 2-D tensor
|
||||
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||
if self._ignore_batch:
|
||||
loss = F.cross_entropy(predictions, targets, reduction="sum")
|
||||
self._loss.update(loss.item(), 1)
|
||||
else:
|
||||
loss = F.cross_entropy(predictions, targets, reduction="mean")
|
||||
self._loss.update(loss.item(), batch)
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"loss": self._loss.avg, "score": self._loss.avg * 100}
|
||||
|
||||
def perf_str(self):
|
||||
return "ce-loss={:.5f}".format(self._loss.avg)
|
||||
|
||||
|
||||
class Top1AccMetric(Metric):
|
||||
"""The metric for the top-1 accuracy."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(Top1AccMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._accuracy = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
batch, _ = predictions.shape() # only support 2-D tensor
|
||||
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||
corrects = torch.eq(max_prob_indexes, targets)
|
||||
accuracy = corrects.float().mean().float()
|
||||
if self._ignore_batch:
|
||||
self._accuracy.update(accuracy, 1)
|
||||
else:
|
||||
self._accuracy.update(accuracy, batch)
|
||||
return accuracy
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
|
||||
|
||||
def perf_str(self):
|
||||
return "accuracy={:.3f}%".format(self._accuracy.avg * 100)
|
||||
|
Reference in New Issue
Block a user