Add visualize codes for Q
This commit is contained in:
@@ -2,9 +2,11 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from utils import obtain_accuracy
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(
|
||||
|
14
lib/procedures/eval_funcs.py
Normal file
14
lib/procedures/eval_funcs.py
Normal file
@@ -0,0 +1,14 @@
|
||||
def obtain_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
@@ -3,12 +3,14 @@
|
||||
#####################################################
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
# modules in AutoDL
|
||||
import datasets
|
||||
from config_utils import load_config
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
from utils import get_model_infos
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
|
||||
|
@@ -3,9 +3,10 @@
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean(expected_flop)
|
||||
|
@@ -2,9 +2,11 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
|
@@ -4,9 +4,9 @@
|
||||
import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# our modules
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def simple_KD_train(
|
||||
|
Reference in New Issue
Block a user