updates for beta

This commit is contained in:
D-X-Y
2019-11-09 16:50:13 +11:00
parent 34ba8053de
commit 975fe4c385
9 changed files with 415 additions and 38 deletions

View File

@@ -62,7 +62,7 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
# config: configuration object holding the required arguments (e.g. ctl_train_steps, ctl_num_aggre)
# baseline: the baseline score (i.e., the average val_acc) from the previous epoch
data_time, batch_time = AverageMeter(), AverageMeter()
GradnormMeter, LossMeter, ValAccMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
shared_cnn.eval()
controller.train()
@@ -96,8 +96,9 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
# account
RewardMeter.update(reward.item())
BaselineMeter.update(baseline.item())
ValAccMeter.update(val_top1.item())
ValAccMeter.update(val_top1.item()*100)
LossMeter.update(loss.item())
EntropyMeter.update(entropy.item())
# Average the gradient over config.ctl_num_aggre samples
loss = loss / config.ctl_num_aggre
@@ -116,7 +117,8 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()
@@ -250,7 +252,7 @@ def main(xargs):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))
cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
@@ -264,7 +266,7 @@ def main(xargs):
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
shared_cnn.module.update_arch(best_arch)
best_valid_acc = valid_func(valid_loader, shared_cnn, criterion)
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
genotypes[epoch] = best_arch
# check the best accuracy
@@ -301,6 +303,14 @@ def main(xargs):
start_time = time.time()
logger.log('\n' + '-'*100)
logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
shared_cnn.module.update_arch(final_arch)
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
logger.log('The Selected Final Architecture : {:}'.format(final_arch))
logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
# check the performance from the architecture dataset
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))