first commit
@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .db_ofa import DatabaseOFA
@@ -0,0 +1,57 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

import numpy as np
import torch

__all__ = ['DataProvider']


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """ Return name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return shape as python list of one data entry """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return `int` of num classes """
        raise NotImplementedError

    @property
    def save_path(self):
        """ local path to save the data """
        raise NotImplementedError

    @property
    def data_url(self):
        """ link to download the data """
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        g.manual_seed(DataProvider.VALID_SEED)  # set random seed before sampling validation set
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
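A minimal standalone sketch of the two static helpers above; the import path is an assumption and only NumPy/PyTorch are required:

import numpy as np

from base_provider import DataProvider  # hypothetical import path for this module

# Deterministic 90/10 split of 100 training indices (seeded by VALID_SEED).
train_idx, valid_idx = DataProvider.random_sample_valid_set(train_size=100, valid_size=10)
assert len(train_idx) == 90 and len(valid_idx) == 10

# One-hot encoding: row i gets a 1.0 at column labels[i].
labels = np.array([0, 2, 1])
one_hot = DataProvider.labels_to_one_hot(n_classes=3, labels=labels)
assert one_hot.shape == (3, 3) and one_hot[1, 2] == 1.0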
@@ -0,0 +1,107 @@
import os
import torch
import time
import copy
import glob
from .imagenet import ImagenetDataProvider
from .imagenet_loader import ImagenetRunConfig
from .run_manager import RunManager
from ofa.model_zoo import ofa_net


class DatabaseOFA:
    def __init__(self, args, predictor=None):
        self.path = f'{args.data_path}/{args.model_name}'
        self.model_name = args.model_name
        self.index = args.index
        self.args = args
        self.predictor = predictor
        ImagenetDataProvider.DEFAULT_PATH = args.imgnet

        if not os.path.exists(self.path):
            os.makedirs(self.path)

    def make_db(self):
        self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
        self.run_config = ImagenetRunConfig(test_batch_size=self.args.batch_size,
                                            n_worker=20)
        database = []
        st_time = time.time()
        f = open(f'{self.path}/txt_{self.index}.txt', 'w')
        for dn in range(10000):
            best_pp = -1
            best_info = None
            dls = None
            with torch.no_grad():
                if self.model_name == 'generator':
                    # evaluate 10 random subnets on the same episode and keep the one
                    # with the highest predicted accuracy
                    for i in range(10):
                        net_setting = self.ofa_network.sample_active_subnet()
                        subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                        if i == 0:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config, init=False, pp=self.predictor)
                            self.run_config.data_provider.assign_active_img_size(224)
                            dls = {j: copy.deepcopy(run_manager.data_loader) for j in range(1, 10)}
                        else:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config,
                                                     init=False, data_loader=dls[i], pp=self.predictor)
                        run_manager.reset_running_statistics(net=subnet)

                        loss, (top1, top5), pred_acc \
                            = run_manager.validate(net=subnet, net_setting=net_setting)

                        if best_pp < pred_acc:
                            best_pp = pred_acc
                            print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                                dn, len(run_manager.cls_lst), loss, top1, top5))
                            info_dict = {'loss': loss,
                                         'top1': top1,
                                         'top5': top5,
                                         'net': net_setting,
                                         'class': run_manager.cls_lst,
                                         'params': run_manager.net_info['params'],
                                         'flops': run_manager.net_info['flops'],
                                         'test_transform': run_manager.test_transform
                                         }
                            best_info = info_dict
                elif self.model_name == 'predictor':
                    # a single random subnet per episode is enough for the predictor's training set
                    net_setting = self.ofa_network.sample_active_subnet()
                    subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                    run_manager = RunManager('.tmp/eval_subnet', self.args, subnet, self.run_config, init=False)
                    self.run_config.data_provider.assign_active_img_size(224)
                    run_manager.reset_running_statistics(net=subnet)

                    loss, (top1, top5), _ = run_manager.validate(net=subnet)
                    print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                        dn, len(run_manager.cls_lst), loss, top1, top5))
                    best_info = {'loss': loss,
                                 'top1': top1,
                                 'top5': top5,
                                 'net': net_setting,
                                 'class': run_manager.cls_lst,
                                 'params': run_manager.net_info['params'],
                                 'flops': run_manager.net_info['flops'],
                                 'test_transform': run_manager.test_transform
                                 }
            database.append(best_info)
            if len(database) % 10 == 0:
                msg = f'{(time.time() - st_time) / 60.0:0.2f}(min) save {len(database)} database, {self.index} id'
                print(msg)
                f.write(msg + '\n')
                f.flush()
                torch.save(database, f'{self.path}/database_{self.index}.pt')

    def collect_db(self):
        if not os.path.exists(self.path + '/processed'):
            os.makedirs(self.path + '/processed')

        database = []
        dlst = glob.glob(self.path + '/*.pt')
        for filepath in dlst:
            database += torch.load(filepath)

        assert len(database) != 0

        print(f'The number of database: {len(database)}')
        torch.save(database, self.path + '/processed/collected_database.pt')
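For reference, each entry appended to `database` above is a plain dict; a hedged sketch of reading back the merged file written by collect_db (the directory prefix is an example, not part of this commit):

import torch

db = torch.load('data/generator/processed/collected_database.pt')  # example path
print(len(db))
print(sorted(db[0].keys()))
# -> ['class', 'flops', 'loss', 'net', 'params', 'test_transform', 'top1', 'top5']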
@@ -0,0 +1,240 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import warnings
import os
import torch
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa_local.imagenet_classification.data_providers.base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from .metaloader import MetaImageNetDataset, EpisodeSampler, MetaDataLoader


__all__ = ['ImagenetDataProvider']


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = '/dataset/imagenet'

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = 'None' if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from ofa.utils.my_dataloader import MyDataLoader
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        ########################## modification ########################
        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        # test_dataset = self.test_dataset(valid_transforms)
        test_dataset = self.meta_test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            # self.test = torch.utils.data.DataLoader(
            #     test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            # )
            sampler = EpisodeSampler(
                max_way=1000, query=10, ylst=test_dataset.ylst)
            self.test = MetaDataLoader(dataset=test_dataset,
                                       sampler=sampler,
                                       batch_size=test_batch_size,
                                       shuffle=False,
                                       num_workers=4)

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser('~/dataset/imagenet')
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    def meta_test_dataset(self, _transforms):
        return MetaImageNetDataset('val', max_way=1000, query=10,
                                   dpath='/w14/dataset/ILSVRC2012', transform=_transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting BN running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
@@ -0,0 +1,40 @@
from .imagenet import ImagenetDataProvider
from ofa_local.imagenet_classification.run_manager import RunConfig


__all__ = ['ImagenetRunConfig']


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
@@ -0,0 +1,210 @@
from torch.utils.data.sampler import Sampler
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import glob


class RandCycleIter:
    '''
    Return data_list per class
    Shuffle the returning order after one epoch
    '''
    def __init__(self, data, shuffle=True):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.shuffle = shuffle

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1

        if self.i == self.length:
            self.i = 0
            if self.shuffle:
                random.shuffle(self.data_list)

        return self.data_list[self.i]


class EpisodeSampler(Sampler):
    def __init__(self, max_way, query, ylst):
        self.max_way = max_way
        self.query = query
        self.ylst = ylst
        # self.n_epi = n_epi

        clswise_xidx = defaultdict(list)
        for i, y in enumerate(ylst):
            clswise_xidx[y].append(i)
        self.cws_xidx_iter = [RandCycleIter(cxidx, shuffle=True)
                              for cxidx in clswise_xidx.values()]
        self.n_cls = len(clswise_xidx)

        self.create_episode()

    def __iter__(self):
        return self.get_index()

    def __len__(self):
        return self.get_len()

    def create_episode(self):
        # way is a random multiple of 10 in [10, max_way); cls_lst holds the sampled class ids
        self.way = torch.randperm(int(self.max_way / 10.0) - 1)[0] * 10 + 10
        cls_lst = torch.sort(torch.randperm(self.max_way)[:self.way])[0]
        self.cls_itr = iter(cls_lst)
        self.cls_lst = cls_lst

    def get_len(self):
        return self.way * self.query

    def get_index(self):
        x_itr = self.cws_xidx_iter

        # yield `query` sample indices for each of the `way` sampled classes
        i, j = 0, 0
        while i < self.query * self.way:
            if j >= self.query:
                j = 0
            if j == 0:
                cls_idx = next(self.cls_itr).item()
                bb = [x_itr[cls_idx]] * self.query
                didx = next(zip(*bb))
            yield didx[j]
            # yield (didx[j], self.way)

            i += 1; j += 1


class MetaImageNetDataset(Dataset):
    def __init__(self, mode='val',
                 max_way=1000, query=10,
                 dpath='/w14/dataset/ILSVRC2012', transform=None):
        self.dpath = dpath
        self.transform = transform
        self.mode = mode

        self.max_way = max_way
        self.query = query
        classes, class_to_idx = self._find_classes(dpath + '/' + mode)
        self.classes, self.class_to_idx = classes, class_to_idx
        # self.class_folder_lst = \
        #     glob.glob(dpath+'/'+mode+'/*')
        # ## sorting alphabetically
        # self.class_folder_lst = sorted(self.class_folder_lst)
        self.file_path_lst, self.ylst = [], []
        for cls in classes:
            xlst = glob.glob(dpath + '/' + mode + '/' + cls + '/*')
            self.file_path_lst += xlst[:self.query]
            y = class_to_idx[cls]
            self.ylst += [y] * len(xlst[:self.query])

        # for y, cls in enumerate(self.class_folder_lst):
        #     xlst = glob.glob(cls+'/*')
        #     self.file_path_lst += xlst[:self.query]
        #     self.ylst += [y] * len(xlst[:self.query])
        #     # self.file_path_lst += [xlst[_] for _ in
        #     #     torch.randperm(len(xlst))[:self.query]]
        #     # self.ylst += [cls.split('/')[-1]] * len(xlst)

        self.way_idx = 0
        self.x_idx = 0
        self.way = 2
        self.cls_lst = None

    def __len__(self):
        return self.way * self.query

    def __getitem__(self, index):
        # if self.way != index[1]:
        #     self.way = index[1]
        #     index = index[0]

        x = Image.open(
            self.file_path_lst[index]).convert('RGB')

        if self.transform is not None:
            x = self.transform(x)
        cls_name = self.ylst[index]
        y = self.cls_lst.index(cls_name)  # episode-local label
        # y = self.way_idx
        # self.x_idx += 1
        # if self.x_idx == self.query:
        #     self.way_idx += 1
        #     self.x_idx = 0
        # if self.way_idx == self.way:
        #     self.way_idx = 0
        #     self.x_idx = 0
        return x, y  # , cls_name # y # cls_name #y

    def _find_classes(self, dir: str):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx


class MetaDataLoader(DataLoader):
    def __init__(self,
                 dataset, sampler, batch_size, shuffle, num_workers):
        super(MetaDataLoader, self).__init__(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers)

    def create_episode(self):
        self.sampler.create_episode()
        self.dataset.way = self.sampler.way
        self.dataset.cls_lst = self.sampler.cls_lst.tolist()

    def get_cls_idx(self):
        return self.sampler.cls_lst

def get_loader(mode='val', way=10, query=10,
               n_epi=100, dpath='/w14/dataset/ILSVRC2012',
               transform=None):
    # standalone helper for quick testing; uses the caller-supplied transform
    dataset = MetaImageNetDataset(mode, way, query, dpath, transform)
    sampler = EpisodeSampler(
        way, query, dataset.ylst)
    dataset.way = sampler.way
    dataset.cls_lst = sampler.cls_lst.tolist()
    loader = MetaDataLoader(dataset=dataset,
                            sampler=sampler,
                            batch_size=10,
                            shuffle=False,
                            num_workers=4)
    return loader

# trloader = get_loader()

# trloader.create_episode()
# print(len(trloader))
# print(trloader.dataset.way)
# print(trloader.sampler.way)
# for i, episode in enumerate(trloader, start=1):
#     print(episode[2])
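A hedged sketch of how the classes above combine into one evaluation episode; the dataset path and the transform are placeholders, not part of this commit:

import torchvision.transforms as transforms

transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
dataset = MetaImageNetDataset(mode='val', max_way=1000, query=10,
                              dpath='/path/to/ILSVRC2012', transform=transform)
sampler = EpisodeSampler(max_way=1000, query=10, ylst=dataset.ylst)
loader = MetaDataLoader(dataset=dataset, sampler=sampler,
                        batch_size=10, shuffle=False, num_workers=0)
loader.create_episode()        # draws a random way (a multiple of 10) and its class list
for images, labels in loader:  # labels are episode-local indices into sampler.cls_lst
    break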
@@ -0,0 +1,302 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import os
import json
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from utils import decode_ofa_mbv3_to_igraph
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models

__all__ = ['RunManager']
import torchvision.models as models


class RunManager:

    def __init__(self, path, args, net, run_config, init=True, measure_latency=None,
                 no_gpu=False, data_loader=None, pp=None):
        self.path = path
        self.mode = args.model_name
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)
        # dataloader
        if data_loader is not None:
            self.data_loader = data_loader
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst
        else:
            self.data_loader = self.run_config.valid_loader
            self.data_loader.create_episode()
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst

        # keep only the classifier rows of the sampled classes
        state_dict = self.net.classifier.state_dict()
        new_state_dict = {'weight': state_dict['linear.weight'][cls_lst],
                          'bias': state_dict['linear.bias'][cls_lst]}

        self.net.classifier = nn.Linear(1280, len(cls_lst), bias=True)
        self.net.classifier.load_state_dict(new_state_dict)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # net info
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, False)
        self.net_info = net_info
        self.test_transform = self.run_config.data_provider.test.dataset.transform

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

        if self.mode == 'generator':
            # PP
            save_dir = f'{args.save_path}/predictor/model/ckpt_max_corr.pt'

            self.acc_predictor = pp.to('cuda')
            self.acc_predictor.load_state_dict(torch.load(save_dir))
            self.acc_predictor = torch.nn.DataParallel(self.acc_predictor)
            model = models.resnet18(pretrained=True).eval()
            feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
            self.feature_extractor = torch.nn.DataParallel(feature_extractor)

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            print('fail to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """ dump run_config and net_config to the model_folder """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dump to %s' % run_save_path)

        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dump to %s' % net_save_path)
        except Exception:
            print('%s do not support net config' % type(self.network))

    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': AverageMeter(),
            'top5': AverageMeter(),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0].item(), output.size(0))
        metric_dict['top5'].update(acc5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg for key in metric_dict
            }
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'

    """ train and test """

    def validate(self, epoch=0, is_test=False, run_str='', net=None,
                 data_loader=None, no_logs=False, train_mode=False, net_setting=None):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)

        if data_loader is not None:
            self.data_loader = data_loader

        if train_mode:
            net.train()
        else:
            net.eval()

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        features_stack = []
        with torch.no_grad():
            with tqdm(total=len(self.data_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
                for i, (images, labels) in enumerate(self.data_loader):
                    images, labels = images.to(self.device), labels.to(self.device)
                    if self.mode == 'generator':
                        features = self.feature_extractor(images).squeeze()
                        features_stack.append(features)
                    # compute output
                    output = net(images)
                    loss = self.test_criterion(output, labels)
                    # measure accuracy and record loss
                    self.update_metric(metric_dict, output, labels)

                    losses.update(loss.item(), images.size(0))
                    t.set_postfix({
                        'loss': losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)

        if self.mode == 'generator':
            features_stack = torch.cat(features_stack)
            igraph_g = decode_ofa_mbv3_to_igraph(net_setting)[0]
            D_mu = self.acc_predictor.module.set_encode(features_stack.unsqueeze(0).to('cuda'))
            G_mu = self.acc_predictor.module.graph_encode(igraph_g)
            pred_acc = self.acc_predictor.module.predict(D_mu.unsqueeze(0), G_mu).item()

        return losses.avg, self.get_metric_vals(metric_dict), \
            pred_acc if self.mode == 'generator' else None

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
            return img_size_list, loss_list, top1_list, top5_list
        else:
            loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
            return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)