Add visualize codes for Q

This commit is contained in:
D-X-Y
2021-04-11 21:45:20 +08:00
parent e777f38233
commit 0e2dd13762
16 changed files with 570 additions and 125 deletions

View File

@@ -2,9 +2,11 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from utils import obtain_accuracy
from .eval_funcs import obtain_accuracy
def basic_train(

View File

@@ -0,0 +1,14 @@
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

View File

@@ -3,12 +3,14 @@
#####################################################
import os, time, copy, torch, pathlib
# modules in AutoDL
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
from utils import get_model_infos
from .eval_funcs import obtain_accuracy
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]

View File

@@ -3,9 +3,10 @@
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
from .eval_funcs import obtain_accuracy
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean(expected_flop)

View File

@@ -2,9 +2,11 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key
from .eval_funcs import obtain_accuracy
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):

View File

@@ -4,9 +4,9 @@
import os, sys, time, torch
import torch.nn.functional as F
# our modules
# modules in AutoDL
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from .eval_funcs import obtain_accuracy
def simple_KD_train(