From 968157b65703443924209f181bcd011ebb54154e Mon Sep 17 00:00:00 2001 From: Mhrooz Date: Sat, 31 Aug 2024 15:49:42 +0200 Subject: [PATCH] add resize to resize the images; cancel the acc; update the folder path --- calculate_datasets_statistics.py | 10 +++++- correlation.py | 17 +++++---- preprocess_aircraft.py | 60 +++++++++++++++++++------------- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/calculate_datasets_statistics.py b/calculate_datasets_statistics.py index 1d6322a..7f7a2ef 100644 --- a/calculate_datasets_statistics.py +++ b/calculate_datasets_statistics.py @@ -4,7 +4,7 @@ # # 加载CIFAR-10数据集 # transform = transforms.Compose([transforms.ToTensor()]) -# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) +# trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform) # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) # # 将所有数据加载到内存中 @@ -18,6 +18,10 @@ # print(f'Mean: {mean}') # print(f'Std: {std}') +# results: +# Mean: tensor([0.4935, 0.4834, 0.4472]) +# Std: tensor([0.2476, 0.2446, 0.2626]) + import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader @@ -35,6 +39,7 @@ dataset_name = args.dataset # 设置数据集的transform(这里只使用了ToTensor) transform = transforms.Compose([ + transforms.Resize((224, 224)), transforms.ToTensor() ]) @@ -47,7 +52,10 @@ mean = torch.zeros(3) std = torch.zeros(3) nb_samples = 0 +count = 0 for data in dataloader: + count += 1 + print(f'Processing batch {count}/{len(dataloader)}', end='\r') batch_samples = data[0].size(0) data = data[0].view(batch_samples, data[0].size(1), -1) mean += data.mean(2).sum(0) diff --git a/correlation.py b/correlation.py index 49adb47..94c1304 100644 --- a/correlation.py +++ b/correlation.py @@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric') parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric') parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets') +parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate') args = parser.parse_args() - if __name__ == "__main__": device = torch.device(args.device) @@ -58,18 +58,21 @@ if __name__ == "__main__": # nasbench_len = 15625 nasbench_len = 15625 + filename = f'output/swap_results_{args.datasets}.csv' + if args.datasets == 'aircraft': + api_datasets = 'cifar10' # for index, i in arch_info.iterrows(): - for ind in range(nasbench_len): + for ind in range(args.start_index,nasbench_len): # print(f'Evaluating network: {index}') print(f'Evaluating network: {ind}') - - config = api.get_net_config(ind, args.datasets) + config = api.get_net_config(ind, api_datasets) network = get_cell_based_tiny_net(config) # nas_results = api.query_by_index(i, 'cifar10') # acc = nas_results[111].get_eval('ori-test') - nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False) - acc = nas_results['test-accuracy'] + # nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False) + # acc = nas_results['test-accuracy'] + acc = 99 # print(type(network)) start_time = time.time() @@ -98,6 +101,8 @@ if __name__ == "__main__": print(f'Elapsed time: {end_time - start_time:.2f} seconds') results.append([np.mean(swap_score), acc, ind]) + with open(filename, 'a') as f: + f.write(f'{np.mean(swap_score)},{acc},{ind}\n') results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) diff --git a/preprocess_aircraft.py b/preprocess_aircraft.py index 45ec2fd..6bbc94e 100644 --- a/preprocess_aircraft.py +++ b/preprocess_aircraft.py @@ -3,39 +3,51 @@ import shutil # 数据集路径 dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' -output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images' +test_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/test_sorted_images' +train_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/train_sorted_images' # 类别文件,例如 'images_variant_trainval.txt' -labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' +# 有两个文件,一个是训练集和验证集,一个是测试集 +test_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' +train_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_train.txt' # 创建输出文件夹 -if not os.path.exists(output_path): - os.makedirs(output_path) +if not os.path.exists(test_output_path): + os.makedirs(test_output_path) +if not os.path.exists(train_output_path): + os.makedirs(train_output_path) # 读取类别文件 -with open(labels_file, 'r') as f: - lines = f.readlines() +with open(test_labels_file, 'r') as f: + test_lines = f.readlines() +with open(train_labels_file, 'r') as f: + train_lines = f.readlines() -count = 0 +def sort_images(lines, output_path): + count = 0 + for line in lines: + count += 1 + print(f'Processing image {count}/{len(lines)}', end='\r') + parts = line.strip().split(' ') + image_name = parts[0] + '.jpg' + category = '_'.join(parts[1:]).replace('/', '_') -for line in lines: - count += 1 - print(f'Processing image {count}/{len(lines)}', end='\r') - parts = line.strip().split(' ') - image_name = parts[0] + '.jpg' - category = '_'.join(parts[1:]).replace('/', '_') + # 创建类别文件夹 + category_path = os.path.join(output_path, category) + if not os.path.exists(category_path): + os.makedirs(category_path) - # 创建类别文件夹 - category_path = os.path.join(output_path, category) - if not os.path.exists(category_path): - os.makedirs(category_path) + # 移动图像到对应类别文件夹 + src = os.path.join(dataset_path, image_name) + dst = os.path.join(category_path, image_name) + if os.path.exists(src): + shutil.move(src, dst) + else: + print(f'Image {image_name} not found!') - # 移动图像到对应类别文件夹 - src = os.path.join(dataset_path, image_name) - dst = os.path.join(category_path, image_name) - if os.path.exists(src): - shutil.move(src, dst) - else: - print(f'Image {image_name} not found!') +print("Sorting test images into folders by category...") +sort_images(test_lines, test_output_path) +print("Sorting train images into folders by category...") +sort_images(train_lines, train_output_path) print("Images have been sorted into folders by category.")