fix bugs in RANDOM-NAS and BOHB
This commit is contained in:
@@ -82,14 +82,29 @@ def valid_func(xloader, network, criterion):
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def search_find_best(valid_loader, network, criterion, select_num):
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(select_num):
|
||||
arch = network.module.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
return best_arch
|
||||
def search_find_best(xloader, network, n_samples):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
archs, valid_accs = [], []
|
||||
#print ('obtain the top-{:} architectures'.format(n_samples))
|
||||
loader_iter = iter(xloader)
|
||||
for i in range(n_samples):
|
||||
arch = network.module.random_genotype( True )
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
|
||||
archs.append( arch )
|
||||
valid_accs.append( val_top1.item() )
|
||||
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
return best_arch, best_valid_acc
|
||||
|
||||
|
||||
def main(xargs):
|
||||
@@ -127,7 +142,7 @@ def main(xargs):
|
||||
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.test_batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
@@ -177,7 +192,8 @@ def main(xargs):
|
||||
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
|
||||
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||
cur_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
|
||||
cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||
logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
|
||||
genotypes[epoch] = cur_arch
|
||||
# check the best accuracy
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
@@ -211,13 +227,7 @@ def main(xargs):
|
||||
logger.log('\n' + '-'*200)
|
||||
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
|
||||
start_time = time.time()
|
||||
best_arch, best_acc = None, -1
|
||||
for iarch in range(xargs.select_num):
|
||||
arch = search_model.random_genotype( True )
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||
logger.log('final evaluation [{:02d}/{:02d}] : {:} : accuracy={:.2f}%, loss={:.3f}'.format(iarch, xargs.select_num, arch, valid_a_top1, valid_a_loss))
|
||||
if best_arch is None or best_acc < valid_a_top1:
|
||||
best_arch, best_acc = arch, valid_a_top1
|
||||
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
||||
|
Reference in New Issue
Block a user