102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
@@ -17,7 +17,7 @@ from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_102_api import NASBench102API as API
|
||||
from nas_201_api import NASBench201API as API
|
||||
from models import CellStructure, get_search_spaces
|
||||
from R_EA import train_and_eval
|
||||
|
||||
@@ -128,6 +128,7 @@ def main(xargs, nas_bench):
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
policy = Policy(xargs.max_nodes, search_space)
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
|
||||
#optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
|
||||
eps = np.finfo(np.float32).eps.item()
|
||||
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
|
||||
logger.log('policy : {:}'.format(policy))
|
||||
@@ -141,13 +142,14 @@ def main(xargs, nas_bench):
|
||||
# attempts = 0
|
||||
x_start_time = time.time()
|
||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||
total_steps, total_costs = 0, 0
|
||||
total_steps, total_costs, trace = 0, 0, []
|
||||
#for istep in range(xargs.RL_steps):
|
||||
while total_costs < xargs.time_budget:
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action( policy )
|
||||
arch = policy.generate_arch( action )
|
||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
trace.append( (reward, arch) )
|
||||
# accumulate time
|
||||
if total_costs + cost_time < xargs.time_budget:
|
||||
total_costs += cost_time
|
||||
@@ -166,7 +168,8 @@ def main(xargs, nas_bench):
|
||||
#logger.log('----> {:}'.format(policy.arch_parameters))
|
||||
#logger.log('')
|
||||
|
||||
best_arch = policy.genotype()
|
||||
# best_arch = policy.genotype() # first version
|
||||
best_arch = max(trace, key=lambda x: x[0])[1]
|
||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
||||
info = nas_bench.query_by_arch( best_arch )
|
||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||
|
Reference in New Issue
Block a user