add resize to resize the images; cancel the acc; update the folder path

This commit is contained in:
2024-08-31 15:49:42 +02:00
parent 33452adc3b
commit 968157b657
3 changed files with 56 additions and 31 deletions

View File

@@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate')
args = parser.parse_args()
if __name__ == "__main__":
device = torch.device(args.device)
@@ -58,18 +58,21 @@ if __name__ == "__main__":
# nasbench_len = 15625
nasbench_len = 15625
filename = f'output/swap_results_{args.datasets}.csv'
if args.datasets == 'aircraft':
api_datasets = 'cifar10'
# for index, i in arch_info.iterrows():
for ind in range(nasbench_len):
for ind in range(args.start_index,nasbench_len):
# print(f'Evaluating network: {index}')
print(f'Evaluating network: {ind}')
config = api.get_net_config(ind, args.datasets)
config = api.get_net_config(ind, api_datasets)
network = get_cell_based_tiny_net(config)
# nas_results = api.query_by_index(i, 'cifar10')
# acc = nas_results[111].get_eval('ori-test')
nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
acc = nas_results['test-accuracy']
# nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False)
# acc = nas_results['test-accuracy']
acc = 99
# print(type(network))
start_time = time.time()
@@ -98,6 +101,8 @@ if __name__ == "__main__":
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
results.append([np.mean(swap_score), acc, ind])
with open(filename, 'a') as f:
f.write(f'{np.mean(swap_score)},{acc},{ind}\n')
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
results.to_csv('output/swap_results.csv', float_format='%.4f', index=False)