diff --git a/correlation.py b/correlation.py index 238ffeb..c45f54b 100644 --- a/correlation.py +++ b/correlation.py @@ -59,15 +59,15 @@ if __name__ == "__main__": nasbench_len = 15625 # for index, i in arch_info.iterrows(): - for i in range(nasbench_len): + for ind in range(nasbench_len): # print(f'Evaluating network: {index}') - print(f'Evaluating network: {i}') + print(f'Evaluating network: {ind}') - config = api.get_net_config(i, 'cifar10') + config = api.get_net_config(ind, 'cifar10') network = get_cell_based_tiny_net(config) # nas_results = api.query_by_index(i, 'cifar10') # acc = nas_results[111].get_eval('ori-test') - nas_results = api.get_more_info(i, 'cifar10', None, hp=200, is_random=False) + nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False) acc = nas_results['test-accuracy'] # print(type(network)) @@ -96,7 +96,7 @@ if __name__ == "__main__": print(f'Average SWAP score: {np.mean(swap_score)}') print(f'Elapsed time: {end_time - start_time:.2f} seconds') - results.append([np.mean(swap_score), acc, i]) + results.append([np.mean(swap_score), acc, ind]) results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)