update NAS-Bench-102 baselines
This commit is contained in:
@@ -17,7 +17,7 @@ from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_102_api import NASBench102API
|
||||
from nas_102_api import NASBench102API as API
|
||||
from models import CellStructure, get_search_spaces
|
||||
from R_EA import train_and_eval
|
||||
|
||||
@@ -132,10 +132,18 @@ def main(xargs, nas_bench):
|
||||
|
||||
# REINFORCE
|
||||
# attempts = 0
|
||||
for istep in range(xargs.RL_steps):
|
||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||
total_steps, total_costs = 0, 0
|
||||
#for istep in range(xargs.RL_steps):
|
||||
while total_costs < xargs.time_budget:
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action( policy )
|
||||
arch = policy.generate_arch( action )
|
||||
reward = train_and_eval(arch, nas_bench, extra_info)
|
||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
# accumulate time
|
||||
if total_costs + cost_time < xargs.time_budget:
|
||||
total_costs += cost_time
|
||||
else: break
|
||||
|
||||
baseline.update(reward)
|
||||
# calculate loss
|
||||
@@ -143,13 +151,15 @@ def main(xargs, nas_bench):
|
||||
optimizer.zero_grad()
|
||||
policy_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
logger.log('step [{:3d}/{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(istep, xargs.RL_steps, baseline.value(), policy_loss.item(), policy.genotype()))
|
||||
# accumulate time
|
||||
total_costs += time.time() - start_time
|
||||
total_steps += 1
|
||||
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
|
||||
#logger.log('----> {:}'.format(policy.arch_parameters))
|
||||
logger.log('')
|
||||
#logger.log('')
|
||||
|
||||
best_arch = policy.genotype()
|
||||
|
||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s.'.format(total_steps, total_costs))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
else : logger.log('{:}'.format(info))
|
||||
@@ -169,8 +179,9 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||
parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
|
||||
parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
|
||||
#parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
|
||||
parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.')
|
||||
parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||
@@ -183,7 +194,7 @@ if __name__ == '__main__':
|
||||
nas_bench = None
|
||||
else:
|
||||
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
|
||||
nas_bench = AANASBenchAPI(args.arch_nas_dataset)
|
||||
nas_bench = API(args.arch_nas_dataset)
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_indexes, num = None, [], 500
|
||||
for i in range(num):
|
||||
|
Reference in New Issue
Block a user