updates
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||
##################################################
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
@@ -24,6 +22,7 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
network.train()
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||
scheduler.update(None, 1.0 * step / len(xloader))
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
@@ -32,13 +31,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
network.train()
|
||||
sampled_arch = network.module.dync_genotype(True)
|
||||
network.module.set_cal_mode('dynamic', sampled_arch)
|
||||
#network.module.set_cal_mode( 'urs' )
|
||||
network.zero_grad()
|
||||
_, logits = network( torch.cat((base_inputs, arch_inputs), dim=0) )
|
||||
logits = logits[:base_inputs.size(0)]
|
||||
_, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
w_optimizer.step()
|
||||
@@ -49,7 +46,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture-weight
|
||||
network.eval()
|
||||
network.module.set_cal_mode( 'joint' )
|
||||
network.zero_grad()
|
||||
_, logits = network(arch_inputs)
|
||||
@@ -257,6 +253,7 @@ def main(xargs):
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
|
||||
# sampling
|
||||
"""
|
||||
with torch.no_grad():
|
||||
|
Reference in New Issue
Block a user