first commit
This commit is contained in:
130
NAS-Bench-201/main_exp/transfer_nag/loader.py
Normal file
130
NAS-Bench-201/main_exp/transfer_nag/loader.py
Normal file
@@ -0,0 +1,130 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
|
||||
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
|
||||
print(f'==> The number of tasks for meta-training: {len(dataset)}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn)
|
||||
return loader
|
||||
|
||||
|
||||
def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False):
|
||||
dataset = MetaTestDataset(data_path, data_name, num_class)
|
||||
print(f'==> Meta-Test dataset {data_name}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=100,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
return loader
|
||||
|
||||
|
||||
class MetaTrainDatabase(Dataset):
|
||||
def __init__(self, data_path, num_sample, is_pred=True):
|
||||
self.mode = 'train'
|
||||
self.acc_norm = True
|
||||
self.num_sample = num_sample
|
||||
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
|
||||
|
||||
mtr_data_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor.pt')
|
||||
idx_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor_idx.pt')
|
||||
data = torch.load(mtr_data_path)
|
||||
self.acc = data['acc']
|
||||
self.task = data['task']
|
||||
self.graph = data['g']
|
||||
|
||||
random_idx_lst = torch.load(idx_path)
|
||||
self.idx_lst = {}
|
||||
self.idx_lst['valid'] = random_idx_lst[:400]
|
||||
self.idx_lst['train'] = random_idx_lst[400:]
|
||||
self.acc = torch.tensor(self.acc)
|
||||
self.mean = torch.mean(self.acc[self.idx_lst['train']]).item()
|
||||
self.std = torch.std(self.acc[self.idx_lst['train']]).item()
|
||||
self.task_lst = torch.load(os.path.join(
|
||||
data_path, 'meta_train_task_lst.pt'))
|
||||
|
||||
|
||||
def set_mode(self, mode):
|
||||
self.mode = mode
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idx_lst[self.mode])
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
ridx = self.idx_lst[self.mode]
|
||||
tidx = self.task[ridx[index]]
|
||||
classes = self.task_lst[tidx]
|
||||
graph = self.graph[ridx[index]]
|
||||
acc = self.acc[ridx[index]]
|
||||
for cls in classes:
|
||||
cx = self.x[cls-1][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
if self.acc_norm:
|
||||
acc = ((acc - self.mean) / self.std) / 100.0
|
||||
else:
|
||||
acc = acc / 100.0
|
||||
return x, graph, acc
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'mnist': 10,
|
||||
'svhn': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37
|
||||
}
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
x = torch.stack([item[0] for item in batch])
|
||||
graph = [item[1] for item in batch]
|
||||
acc = torch.stack([item[2] for item in batch])
|
||||
return [x, graph, acc]
|
Reference in New Issue
Block a user