update SETN

This commit is contained in:
D-X-Y
2019-11-12 22:35:57 +11:00
parent 7b354d4c74
commit 5c73aeb50b
5 changed files with 42 additions and 21 deletions

View File

@@ -76,22 +76,22 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
def get_best_arch(xloader, network, n_samples):
with torch.no_grad():
network.eval()
archs, valid_accs = [], []
archs, valid_accs = network.module.return_topK(n_samples), []
#print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
for i in range(n_samples):
for i, sampled_arch in enumerate(archs):
network.module.set_cal_mode('dynamic', sampled_arch)
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
sampled_arch = network.module.dync_genotype(False)
network.module.set_cal_mode('dynamic', sampled_arch)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
archs.append( sampled_arch )
valid_accs.append( val_top1.item() )
#print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
@@ -221,11 +221,6 @@ def main(xargs):
#logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
valid_accuracies['best'] = valid_a_top1
genotypes['best'] = search_model.genotype()
find_best = True
else: find_best = False
genotypes[epoch] = genotype
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
@@ -244,16 +239,17 @@ def main(xargs):
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
#logger.log('During searching, the best gentotype is : {:} , with the validation accuracy of {:.3f}%.'.format(genotypes['best'], valid_accuracies['best']))
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
network.module.set_cal_mode('dynamic', genotype)
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
# sampling
"""
with torch.no_grad():