first commit
This commit is contained in:
@@ -0,0 +1,149 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
|
||||
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
|
||||
print(f'==> The number of tasks for meta-training: {len(dataset)}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=1,
|
||||
collate_fn=collate_fn)
|
||||
return loader
|
||||
|
||||
|
||||
def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False):
|
||||
dataset = MetaTestDataset(data_path, data_name, num_class)
|
||||
print(f'==> Meta-Test dataset {data_name}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=100,
|
||||
shuffle=False,
|
||||
num_workers=1)
|
||||
return loader
|
||||
|
||||
|
||||
class MetaTrainDatabase(Dataset):
|
||||
def __init__(self, data_path, num_sample, is_pred=False):
|
||||
self.mode = 'train'
|
||||
self.acc_norm = True
|
||||
self.num_sample = num_sample
|
||||
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
|
||||
|
||||
self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
|
||||
self.dname = f'database_219152_14.0K'
|
||||
|
||||
if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
|
||||
raise ValueError('')
|
||||
database = torch.load(self.dpath + f'{self.dname}.pt')
|
||||
|
||||
rand_idx = torch.randperm(len(database))
|
||||
test_len = int(len(database) * 0.15)
|
||||
idxlst = {'test': rand_idx[:test_len],
|
||||
'valid': rand_idx[test_len:2 * test_len],
|
||||
'train': rand_idx[2 * test_len:]}
|
||||
|
||||
for m in ['train', 'valid', 'test']:
|
||||
acc, graph, cls, net, flops = [], [], [], [], []
|
||||
for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
|
||||
acc.append(database[idx]['top1'])
|
||||
net.append(database[idx]['net'])
|
||||
cls.append(database[idx]['class'])
|
||||
flops.append(database[idx]['flops'])
|
||||
if m == 'train':
|
||||
mean = torch.mean(torch.tensor(acc)).item()
|
||||
std = torch.std(torch.tensor(acc)).item()
|
||||
torch.save({'acc': acc,
|
||||
'class': cls,
|
||||
'net': net,
|
||||
'flops': flops,
|
||||
'mean': mean,
|
||||
'std': std},
|
||||
self.dpath + f'{self.dname}_{m}.pt')
|
||||
|
||||
self.set_mode(self.mode)
|
||||
|
||||
def set_mode(self, mode):
|
||||
self.mode = mode
|
||||
data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
|
||||
self.acc = data['acc']
|
||||
self.cls = data['class']
|
||||
self.net = data['net']
|
||||
self.flops = data['flops']
|
||||
self.mean = data['mean']
|
||||
self.std = data['std']
|
||||
|
||||
def __len__(self):
|
||||
return len(self.acc)
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = self.cls[index]
|
||||
acc = self.acc[index]
|
||||
graph = self.net[index]
|
||||
|
||||
for i, cls in enumerate(classes):
|
||||
cx = self.x[cls.item()][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
if self.acc_norm:
|
||||
acc = ((acc - self.mean) / self.std) / 100.0
|
||||
else:
|
||||
acc = acc / 100.0
|
||||
return x, graph, torch.tensor(acc).view(1, 1)
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
if data_name == 'aircraft':
|
||||
data_name = 'aircraft100'
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'mnist': 10,
|
||||
'aircraft100': 30,
|
||||
'svhn': 10,
|
||||
'pets': 37
|
||||
}
|
||||
# 'aircraft30': 30,
|
||||
# 'aircraft100': 100,
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
# x = torch.stack([item[0] for item in batch])
|
||||
# graph = [item[1] for item in batch]
|
||||
# acc = torch.stack([item[2] for item in batch])
|
||||
return batch
|
Reference in New Issue
Block a user