From 33452adc3b9834add50b1ef50cf0165310ed5307 Mon Sep 17 00:00:00 2001
From: Mhrooz
Date: Sat, 31 Aug 2024 12:20:59 +0200
Subject: [PATCH] preprocess aircraft dataset to get the statistics. which can be used in swap-nas

---
 calculate_datasets_statistics.py | 94 ++++++++++++++++++++++++++++++++
 preprocess_aircraft.py           | 62 +++++++++++++++++++++
 2 files changed, 156 insertions(+)
 create mode 100644 calculate_datasets_statistics.py
 create mode 100644 preprocess_aircraft.py

diff --git a/calculate_datasets_statistics.py b/calculate_datasets_statistics.py
new file mode 100644
index 0000000..1d6322a
--- /dev/null
+++ b/calculate_datasets_statistics.py
@@ -0,0 +1,94 @@
+"""Compute per-channel mean and standard deviation of an image dataset.
+
+The dataset is read with ``torchvision.datasets.ImageFolder``, so
+``--data_path`` must point at a directory with one sub-folder per class.
+"""
+
+# Disabled alternative that reads CIFAR-10 through torchvision directly:
+#
+# import torch
+# import torchvision
+# import torchvision.transforms as transforms
+#
+# # Load the CIFAR-10 dataset
+# transform = transforms.Compose([transforms.ToTensor()])
+# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
+# trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2)
+#
+# # Load all of the data into memory
+# data = next(iter(trainloader))
+# images, _ = data
+#
+# # Compute the mean and standard deviation of each channel
+# mean = images.mean([0, 2, 3])
+# std = images.std([0, 2, 3])
+#
+# print(f'Mean: {mean}')
+# print(f'Std: {std}')
+
+import argparse
+
+import torch
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
+def parse_args():
+    """Parse the command-line options."""
+    parser = argparse.ArgumentParser(description='Calculate mean and std of dataset')
+    # NOTE: --dataset is accepted but not read anywhere in this script.
+    parser.add_argument('--dataset', type=str, default='cifar10', help='dataset name')
+    parser.add_argument('--data_path', type=str, default='./datasets/cifar-10-batches-py', help='path to dataset image folder')
+    return parser.parse_args()
+
+
+def compute_mean_std(dataloader):
+    """Return the per-channel ``(mean, std)`` averaged over all images.
+
+    Both statistics are computed per image over its pixels and then averaged
+    over the dataset, so ``std`` is the mean of the per-image standard
+    deviations rather than the standard deviation of all pixels pooled.
+    """
+    # Running sums of the per-image statistics; 3 channels are assumed.
+    mean = torch.zeros(3)
+    std = torch.zeros(3)
+    nb_samples = 0
+
+    for batch in dataloader:
+        images = batch[0]
+        batch_samples = images.size(0)
+        # Flatten the spatial dimensions: (batch, channels, pixels).
+        pixels = images.view(batch_samples, images.size(1), -1)
+        mean += pixels.mean(2).sum(0)
+        std += pixels.std(2).sum(0)
+        nb_samples += batch_samples
+
+    mean /= nb_samples
+    std /= nb_samples
+    return mean, std
+
+
+def main():
+    """Load the image folder and print its channel statistics."""
+    args = parse_args()
+
+    # Dataset transform (only ToTensor is used here).
+    transform = transforms.Compose([
+        transforms.ToTensor()
+    ])
+
+    # Load the dataset with ImageFolder. No resizing is applied, so the
+    # default batch collation needs every image to have the same size.
+    dataset = datasets.ImageFolder(root=args.data_path, transform=transform)
+    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
+
+    mean, std = compute_mean_std(dataloader)
+
+    print(f'Mean: {mean}')
+    print(f'Std: {std}')
+
+
+# DataLoader workers re-import this module under the "spawn" start method
+# (the default on Windows and macOS), so the entry point must be guarded.
+if __name__ == '__main__':
+    main()
diff --git a/preprocess_aircraft.py b/preprocess_aircraft.py
new file mode 100644
index 0000000..45ec2fd
--- /dev/null
+++ b/preprocess_aircraft.py
@@ -0,0 +1,62 @@
+"""Sort the FGVC-Aircraft images into one sub-folder per category.
+
+Each line of the label file is read as an image id followed by the category
+words. The matching ``.jpg`` file is moved from ``dataset_path`` into the
+folder ``output_path/CATEGORY``.
+"""
+
+import os
+import shutil
+
+# Dataset paths
+dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images'
+output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images'
+
+# Label file, for example 'images_variant_trainval.txt'
+labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt'
+
+
+def parse_label_line(line):
+    """Return ``(image_name, category)`` for one line of the label file.
+
+    The first space-separated token is the image id; the remaining tokens are
+    joined with ``_`` and any ``/`` is replaced so the category is usable as a
+    single folder name.
+    """
+    parts = line.strip().split(' ')
+    image_name = parts[0] + '.jpg'
+    category = '_'.join(parts[1:]).replace('/', '_')
+    return image_name, category
+
+
+def main():
+    """Move every listed image into the folder of its category."""
+    # Create the output folder
+    os.makedirs(output_path, exist_ok=True)
+
+    # Read the label file, ignoring blank lines
+    with open(labels_file, 'r') as f:
+        lines = [line for line in f if line.strip()]
+
+    for count, line in enumerate(lines, start=1):
+        print(f'Processing image {count}/{len(lines)}', end='\r')
+        image_name, category = parse_label_line(line)
+
+        # Create the category folder
+        category_path = os.path.join(output_path, category)
+        os.makedirs(category_path, exist_ok=True)
+
+        # Move the image into its category folder
+        src = os.path.join(dataset_path, image_name)
+        dst = os.path.join(category_path, image_name)
+        if os.path.exists(src):
+            shutil.move(src, dst)
+        else:
+            print(f'Image {image_name} not found!')
+
+    print("Images have been sorted into folders by category.")
+
+
+# Guarded so that importing this module does not move any files.
+if __name__ == '__main__':
+    main()