updates for beta
This commit is contained in:
@@ -62,7 +62,7 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
|
||||
# config. (containing some necessary arg)
|
||||
# baseline: The baseline score (i.e. average val_acc) from the previous epoch
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
GradnormMeter, LossMeter, ValAccMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
|
||||
GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
|
||||
|
||||
shared_cnn.eval()
|
||||
controller.train()
|
||||
@@ -96,8 +96,9 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
|
||||
# account
|
||||
RewardMeter.update(reward.item())
|
||||
BaselineMeter.update(baseline.item())
|
||||
ValAccMeter.update(val_top1.item())
|
||||
ValAccMeter.update(val_top1.item()*100)
|
||||
LossMeter.update(loss.item())
|
||||
EntropyMeter.update(entropy.item())
|
||||
|
||||
# Average gradient over controller_num_aggregate samples
|
||||
loss = loss / config.ctl_num_aggre
|
||||
@@ -116,7 +117,8 @@ def train_controller(xloader, shared_cnn, controller, criterion, optimizer, conf
|
||||
Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
|
||||
Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
|
||||
|
||||
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()
|
||||
|
||||
@@ -250,7 +252,7 @@ def main(xargs):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))
|
||||
|
||||
cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
|
||||
@@ -264,7 +266,7 @@ def main(xargs):
|
||||
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline))
|
||||
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
|
||||
shared_cnn.module.update_arch(best_arch)
|
||||
best_valid_acc = valid_func(valid_loader, shared_cnn, criterion)
|
||||
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
|
||||
|
||||
genotypes[epoch] = best_arch
|
||||
# check the best accuracy
|
||||
@@ -301,6 +303,14 @@ def main(xargs):
|
||||
start_time = time.time()
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
|
||||
logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
|
||||
logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
|
||||
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
|
||||
shared_cnn.module.update_arch(final_arch)
|
||||
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
|
||||
logger.log('The Selected Final Architecture : {:}'.format(final_arch))
|
||||
logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
|
||||
# check the performance from the architecture dataset
|
||||
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||
|
@@ -23,7 +23,6 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.train()
|
||||
end = time.time()
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||
scheduler.update(None, 1.0 * step / len(xloader))
|
||||
@@ -33,9 +32,13 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
network.module.set_cal_mode( 'urs' )
|
||||
w_optimizer.zero_grad()
|
||||
_, logits = network(base_inputs)
|
||||
network.train()
|
||||
sampled_arch = network.module.dync_genotype(True)
|
||||
network.module.set_cal_mode('dynamic', sampled_arch)
|
||||
#network.module.set_cal_mode( 'urs' )
|
||||
network.zero_grad()
|
||||
_, logits = network( torch.cat((base_inputs, arch_inputs), dim=0) )
|
||||
logits = logits[:base_inputs.size(0)]
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
w_optimizer.step()
|
||||
@@ -46,8 +49,9 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture-weight
|
||||
network.eval()
|
||||
network.module.set_cal_mode( 'joint' )
|
||||
a_optimizer.zero_grad()
|
||||
network.zero_grad()
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
arch_loss.backward()
|
||||
@@ -68,15 +72,42 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
||||
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
|
||||
return base_losses.avg, base_top1.avg, base_top5.avg
|
||||
#print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
|
||||
#print (network.module.arch_parameters)
|
||||
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def get_best_arch(xloader, network, n_samples):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
archs, valid_accs = [], []
|
||||
loader_iter = iter(xloader)
|
||||
for i in range(n_samples):
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
|
||||
sampled_arch = network.module.dync_genotype(False)
|
||||
network.module.set_cal_mode('dynamic', sampled_arch)
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
|
||||
archs.append( sampled_arch )
|
||||
valid_accs.append( val_top1.item() )
|
||||
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
return best_arch, best_valid_acc
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
network.train()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
@@ -117,8 +148,8 @@ def main(xargs):
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/SETN.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
#config_path = 'configs/nas-benchmark/algos/SETN.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
@@ -126,7 +157,7 @@ def main(xargs):
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@@ -134,6 +165,7 @@ def main(xargs):
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space}, None)
|
||||
logger.log('search space : {:}'.format(search_space))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||
@@ -173,17 +205,24 @@ def main(xargs):
|
||||
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||
|
||||
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
search_model.set_cal_mode('urs')
|
||||
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
|
||||
|
||||
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
|
||||
network.module.set_cal_mode('dynamic', genotype)
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
search_model.set_cal_mode('joint')
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
search_model.set_cal_mode('select')
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
|
||||
#search_model.set_cal_mode('urs')
|
||||
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
#logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
#search_model.set_cal_mode('joint')
|
||||
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
#logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
#search_model.set_cal_mode('select')
|
||||
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
#logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
if valid_a_top1 > valid_accuracies['best']:
|
||||
@@ -192,7 +231,7 @@ def main(xargs):
|
||||
find_best = True
|
||||
else: find_best = False
|
||||
|
||||
genotypes[epoch] = search_model.genotype()
|
||||
genotypes[epoch] = genotype
|
||||
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||
# save checkpoint
|
||||
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||
@@ -219,6 +258,7 @@ def main(xargs):
|
||||
start_time = time.time()
|
||||
|
||||
# sampling
|
||||
"""
|
||||
with torch.no_grad():
|
||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||
selected_archs = set()
|
||||
@@ -238,6 +278,7 @@ def main(xargs):
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||
"""
|
||||
|
||||
logger.log('\n' + '-'*100)
|
||||
# check the performance from the architecture dataset
|
||||
@@ -267,6 +308,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||
parser.add_argument('--config_path', type=str, help='.')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
|
Reference in New Issue
Block a user