Update REA, REINFORCE, and RANDOM

This commit is contained in:
D-X-Y
2020-07-13 11:35:13 +00:00
parent 6dc494be08
commit ebad9197f7
5 changed files with 38 additions and 26 deletions

View File

@@ -145,6 +145,7 @@ def main(xargs, api):
x_start_time = time.time()
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
total_steps, total_costs, trace = 0, [], []
+current_best_index = []
while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action( policy )
@@ -162,9 +163,8 @@ def main(xargs, api):
# accumulate time
total_steps += 1
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
#logger.log('----> {:}'.format(policy.arch_parameters))
#logger.log('')
# to analyze
current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]))
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
@@ -173,7 +173,7 @@ def main(xargs, api):
logger.log('-'*100)
logger.close()
-return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs
+return logger.log_dir, current_best_index, total_costs
if __name__ == '__main__':
@@ -203,7 +203,7 @@ if __name__ == '__main__':
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
-save_dir, all_info = None, {}
+save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)