102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################################
|
||||
# Regularized Evolution for Image Classifier Architecture Search #
|
||||
##################################################################
|
||||
@@ -16,7 +16,7 @@ from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_102_api import NASBench102API as API
|
||||
from nas_201_api import NASBench201API as API
|
||||
from models import CellStructure, get_search_spaces
|
||||
|
||||
|
||||
@@ -31,30 +31,8 @@ class Model(object):
|
||||
return '{:}'.format(self.arch)
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.train()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
# This function is to mimic the training and evaluatinig procedure for a single architecture `arch`.
|
||||
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
|
||||
def train_and_eval(arch, nas_bench, extra_info):
|
||||
if nas_bench is not None:
|
||||
arch_index = nas_bench.query_index_by_arch( arch )
|
||||
|
Reference in New Issue
Block a user