This commit is contained in:
D-X-Y
2019-11-11 00:46:02 +11:00
parent fac556c176
commit 7b354d4c74
26 changed files with 1563 additions and 43 deletions

View File

@@ -1,8 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
##################################################
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
@@ -24,6 +22,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
@@ -32,13 +31,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
data_time.update(time.time() - end)
# update the weights
network.train()
sampled_arch = network.module.dync_genotype(True)
network.module.set_cal_mode('dynamic', sampled_arch)
#network.module.set_cal_mode( 'urs' )
network.zero_grad()
_, logits = network( torch.cat((base_inputs, arch_inputs), dim=0) )
logits = logits[:base_inputs.size(0)]
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
@@ -49,7 +46,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.eval()
network.module.set_cal_mode( 'joint' )
network.zero_grad()
_, logits = network(arch_inputs)
@@ -257,6 +253,7 @@ def main(xargs):
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
# sampling
"""
with torch.no_grad():