This commit is contained in:
D-X-Y
2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

View File

@@ -1,8 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
# To be finished.
#
import os, sys, time, torch
from typing import import Optional, Text, Callable
from typing import Optional, Text, Callable
# modules in AutoDL
from log_utils import AverageMeter
@@ -60,9 +62,10 @@ def procedure(
network,
criterion,
optimizer,
eval_metric,
mode: Text,
print_freq: int = 100,
logger_fn: Callable = None
logger_fn: Callable = None,
):
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
if mode.lower() == "train":
@@ -90,7 +93,7 @@ def procedure(
optimizer.step()
# record
metrics =
metrics = eval_metric(logits.data, targets.data)
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))

View File

@@ -3,6 +3,7 @@
#####################################################
import abc
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
@@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
class EvaluationMetric(abc.ABC):
def __init__(self):
self._total_metrics = 0