Update REA, REINFORCE, and RANDOM
This commit is contained in:
@@ -145,6 +145,7 @@ def main(xargs, api):
|
||||
x_start_time = time.time()
|
||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||
total_steps, total_costs, trace = 0, [], []
|
||||
current_best_index = []
|
||||
while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action( policy )
|
||||
@@ -162,9 +163,8 @@ def main(xargs, api):
|
||||
# accumulate time
|
||||
total_steps += 1
|
||||
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
|
||||
#logger.log('----> {:}'.format(policy.arch_parameters))
|
||||
#logger.log('')
|
||||
|
||||
# to analyze
|
||||
current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]))
|
||||
# best_arch = policy.genotype() # first version
|
||||
best_arch = max(trace, key=lambda x: x[0])[1]
|
||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
|
||||
@@ -173,7 +173,7 @@ def main(xargs, api):
|
||||
logger.log('-'*100)
|
||||
logger.close()
|
||||
|
||||
return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs
|
||||
return logger.log_dir, current_best_index, total_costs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -203,7 +203,7 @@ if __name__ == '__main__':
|
||||
print('save-dir : {:}'.format(args.save_dir))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, {}
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
|
Reference in New Issue
Block a user