Fix bugs
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
# To be finished.
|
||||
#
|
||||
import os, sys, time, torch
|
||||
from typing import import Optional, Text, Callable
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
@@ -60,9 +62,10 @@ def procedure(
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
eval_metric,
|
||||
mode: Text,
|
||||
print_freq: int = 100,
|
||||
logger_fn: Callable = None
|
||||
logger_fn: Callable = None,
|
||||
):
|
||||
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode.lower() == "train":
|
||||
@@ -90,7 +93,7 @@ def procedure(
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
metrics =
|
||||
metrics = eval_metric(logits.data, targets.data)
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
import abc
|
||||
|
||||
|
||||
def obtain_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
@@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
||||
|
||||
|
||||
class EvaluationMetric(abc.ABC):
|
||||
|
||||
def __init__(self):
|
||||
self._total_metrics = 0
|
||||
|
||||
|
Reference in New Issue
Block a user