upload
16
zero-cost-nas/foresight/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .version import *
121
zero-cost-nas/foresight/dataset.py
Normal file
@@ -0,0 +1,121 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os  # os.path.join is used below; previously only available via the star import

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

from .imagenet16 import *


def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):

    if 'ImageNet16' in dataset:
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
        size, pad = 16, 2
    elif 'cifar' in dataset:
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        size, pad = 32, 4
    elif 'svhn' in dataset:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        size, pad = 32, 0
    elif dataset == 'ImageNet1k':
        from .h5py_dataset import H5Dataset
        size, pad = 224, 2
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        #resize = 256

    if resize is None:
        resize = size

    train_transform = transforms.Compose([
        transforms.RandomCrop(size, padding=pad),
        transforms.Resize(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if dataset == 'cifar10':
        train_dataset = CIFAR10(datadir, True, train_transform, download=True)
        test_dataset = CIFAR10(datadir, False, test_transform, download=True)
    elif dataset == 'cifar100':
        train_dataset = CIFAR100(datadir, True, train_transform, download=True)
        test_dataset = CIFAR100(datadir, False, test_transform, download=True)
    elif dataset == 'svhn':
        train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
        test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
    elif dataset == 'ImageNet16-120':
        train_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), True, train_transform, 120)
        test_dataset = ImageNet16(os.path.join(datadir, 'ImageNet16'), False, test_transform, 120)
    elif dataset == 'ImageNet1k':
        train_dataset = H5Dataset(os.path.join(datadir, 'imagenet-train-256.h5'), transform=train_transform)
        test_dataset = H5Dataset(os.path.join(datadir, 'imagenet-val-256.h5'), transform=test_transform)
    else:
        raise ValueError('There are no more cifars or imagenets.')

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader


def get_mnist_dataloaders(train_batch_size, val_batch_size, num_workers):

    data_transform = Compose([transforms.ToTensor()])

    # Normalise? transforms.Normalize((0.1307,), (0.3081,))

    train_dataset = MNIST("_dataset", True, data_transform, download=True)
    test_dataset = MNIST("_dataset", False, data_transform, download=True)

    train_loader = DataLoader(
        train_dataset,
        train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    test_loader = DataLoader(
        test_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    return train_loader, test_loader
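For reference, a minimal usage sketch of the loader factory above, assuming the package is importable as foresight; the batch sizes, worker count and data directory are illustrative, not taken from the repository:

    from foresight.dataset import get_cifar_dataloaders

    # Hypothetical call: downloads CIFAR-10 into '_dataset' on first use.
    train_loader, test_loader = get_cifar_dataloaders(
        train_batch_size=64,
        test_batch_size=256,
        dataset='cifar10',
        num_workers=2,
        datadir='_dataset')

    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # roughly: torch.Size([64, 3, 32, 32]) torch.Size([64])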
55
zero-cost-nas/foresight/h5py_dataset.py
Normal file
@@ -0,0 +1,55 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import h5py
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

class H5Dataset(Dataset):
    def __init__(self, h5_path, transform=None):
        self.h5_path = h5_path
        self.h5_file = None
        self.length = len(h5py.File(h5_path, 'r'))
        self.transform = transform

    def __getitem__(self, index):

        # Opening the file lazily in __getitem__ allows us to use multiple processes for
        # data loading, because hdf5 file handles aren't picklable and so can't be
        # transferred across processes.
        # https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
        # https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        # TODO: possibly look at __getstate__ and __setstate__ as a more elegant solution
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')

        record = self.h5_file[str(index)]

        if self.transform:
            x = Image.fromarray(record['data'][()])
            x = self.transform(x)
        else:
            x = torch.from_numpy(record['data'][()])

        y = record['target'][()]
        y = torch.from_numpy(np.asarray(y))

        return (x, y)

    def __len__(self):
        return self.length
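A short sketch of how H5Dataset is typically wrapped in a DataLoader; the file path is a placeholder (the naming convention matches get_cifar_dataloaders above), and the lazy open in __getitem__ is what makes num_workers > 0 safe:

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from foresight.h5py_dataset import H5Dataset

    # Hypothetical path; each record is expected to hold 'data' and 'target' entries.
    ds = H5Dataset('_dataset/imagenet-val-256.h5', transform=transforms.ToTensor())
    loader = DataLoader(ds, batch_size=32, num_workers=4, pin_memory=True)

    for x, y in loader:
        # each worker opens its own h5py.File handle on its first __getitem__ call
        break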
129
zero-cost-nas/foresight/imagenet16.py
Normal file
@@ -0,0 +1,129 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath): return False
    if md5 is None: return True
    else: return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    train_list = [
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
        ['train_data_batch_10', '8f03f34ac4b42271a294f91bf480f29b'],
    ]
    valid_list = [
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')

        if self.train: downloaded_list = self.train_list
        else: downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets
        # self.mean.append(entry['mean'])
        #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
        #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
        #print ('Mean : {:}'.format(self.mean))
        #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
        #std_data = np.std(temp, axis=0)
        #std_data = np.mean(np.mean(std_data, axis=0), axis=0)
        #print ('Std : {:}'.format(std_data))

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index] - 1

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.valid_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

#
if __name__ == '__main__':
    train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True, None)
    valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)

    print(len(train))
    print(len(valid))
    image, label = train[111]
    trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True, None, 200)
    validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None, 200)
    print(len(trainX))
    print(len(validX))
    #import pdb; pdb.set_trace()
19
zero-cost-nas/foresight/models/__init__.py
Normal file
@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from os.path import dirname, basename, isfile, join
import glob
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
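With the files added in this commit, the glob above would populate __all__ roughly as sketched below (a hedged illustration; ordering depends on the filesystem, and it assumes the package imports cleanly):

    from foresight import models

    # Expected to contain the model builder modules added here, e.g.:
    # ['nasbench1', 'nasbench1_ops', 'nasbench1_spec', 'nasbench2', 'nasbench2_ops']
    print(sorted(models.__all__))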
251
zero-cost-nas/foresight/models/nasbench1.py
Normal file
@@ -0,0 +1,251 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Builds the Pytorch computational graph.

Tensors flowing into a single vertex are added together for all vertices
except the output, which is concatenated instead. Tensors flowing out of input
are always added.

If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing). Tensors flowing out of the input are always
projected instead.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import math

from .nasbench1_ops import *

import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, spec, stem_out, num_stacks, num_mods, num_classes, bn=True):
        super(Network, self).__init__()

        self.spec = spec
        self.stem_out = stem_out
        self.num_stacks = num_stacks
        self.num_mods = num_mods
        self.num_classes = num_classes

        self.layers = nn.ModuleList([])

        in_channels = 3
        out_channels = stem_out

        # initial stem convolution
        stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)
        self.layers.append(stem_conv)

        in_channels = out_channels
        for stack_num in range(num_stacks):
            if stack_num > 0:
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                self.layers.append(downsample)

                out_channels *= 2

            for _ in range(num_mods):
                cell = Cell(spec, in_channels, out_channels, bn=bn)
                self.layers.append(cell)
                in_channels = out_channels

        self.classifier = nn.Linear(out_channels, num_classes)

        self._initialize_weights()

    def forward(self, x):
        for _, layer in enumerate(self.layers):
            x = layer(x)
        out = torch.mean(x, (2, 3))
        out = self.classifier(out)

        return out

    def get_prunable_copy(self, bn=False):

        model_new = Network(self.spec, self.stem_out, self.num_stacks, self.num_mods, self.num_classes, bn=bn)

        # TODO: this is quite brittle and doesn't work with nn.Sequential when bn is different
        # it is only required to maintain initialization -- maybe init after get_prunable_copy?
        model_new.load_state_dict(self.state_dict(), strict=False)
        model_new.train()

        return model_new

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

class Cell(nn.Module):
    """
    Builds the model using the adjacency matrix and op labels specified. Channels
    controls the module output channel count but the interior channels are
    determined via equally splitting the channel count whenever there is a
    concatenation of Tensors.
    """
    def __init__(self, spec, in_channels, out_channels, bn=True):
        super(Cell, self).__init__()

        self.spec = spec
        self.num_vertices = np.shape(self.spec.matrix)[0]

        # vertex_channels[i] = number of output channels of vertex i
        self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
        #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)

        # operation for each node
        self.vertex_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices-1):
            op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t], bn=bn)
            self.vertex_op.append(op)

        # operation for input on each vertex
        self.input_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices):
            if self.spec.matrix[0, t]:
                self.input_op.append(Projection(in_channels, self.vertex_channels[t], bn=bn))
            else:
                self.input_op.append(None)

    def forward(self, x):
        tensors = [x]

        out_concat = []
        for t in range(1, self.num_vertices-1):
            fan_in = [Truncate(tensors[src], self.vertex_channels[t]) for src in range(1, t) if self.spec.matrix[src, t]]

            if self.spec.matrix[0, t]:
                fan_in.append(self.input_op[t](x))

            # perform operation on node
            #vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
            vertex_input = sum(fan_in)
            #vertex_input = sum(fan_in) / len(fan_in)
            vertex_output = self.vertex_op[t](vertex_input)

            tensors.append(vertex_output)
            if self.spec.matrix[t, self.num_vertices-1]:
                out_concat.append(tensors[t])

        if not out_concat:
            assert self.spec.matrix[0, self.num_vertices-1]
            outputs = self.input_op[self.num_vertices-1](tensors[0])
        else:
            if len(out_concat) == 1:
                outputs = out_concat[0]
            else:
                outputs = torch.cat(out_concat, 1)

            if self.spec.matrix[0, self.num_vertices-1]:
                outputs += self.input_op[self.num_vertices-1](tensors[0])

            #if self.spec.matrix[0, self.num_vertices-1]:
            #    out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
            #outputs = sum(out_concat) / len(out_concat)

        return outputs

def Projection(in_channels, out_channels, bn=True):
    """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
    return ConvBnRelu(in_channels, out_channels, 1, bn=bn)

def Truncate(inputs, channels):
    """Slice the inputs to channels if necessary."""
    input_channels = inputs.size()[1]
    if input_channels < channels:
        raise ValueError('input channel < output channels for truncate')
    elif input_channels == channels:
        return inputs   # No truncation necessary
    else:
        # Truncation should only be necessary when channel division leads to
        # vertices with +1 channels. The input vertex should always be projected to
        # the minimum channel count.
        assert input_channels - channels == 1
        return inputs[:, :channels, :, :]

def ComputeVertexChannels(in_channels, out_channels, matrix):
    """Computes the number of channels at every vertex.

    Given the input channels and output channels, this calculates the number of
    channels at each interior vertex. Interior vertices have the same number of
    channels as the max of the channels of the vertices it feeds into. The output
    channels are divided amongst the vertices that are directly connected to it.
    When the division is not even, some vertices may receive an extra channel to
    compensate.

    Returns:
      list of channel counts, in order of the vertices.
    """
    num_vertices = np.shape(matrix)[0]

    vertex_channels = [0] * num_vertices
    vertex_channels[0] = in_channels
    vertex_channels[num_vertices - 1] = out_channels

    if num_vertices == 2:
        # Edge case where module only has input and output vertices
        return vertex_channels

    # Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
    # the dst vertex. Summing over 0 gives the in-degree count of each vertex.
    in_degree = np.sum(matrix[1:], axis=0)
    interior_channels = out_channels // in_degree[num_vertices - 1]
    correction = out_channels % in_degree[num_vertices - 1]  # Remainder to add

    # Set channels of vertices that flow directly to output
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            vertex_channels[v] = interior_channels
            if correction:
                vertex_channels[v] += 1
                correction -= 1

    # Set channels for all other vertices to the max of the out edges, going
    # backwards. (num_vertices - 2) index skipped because it only connects to
    # output.
    for v in range(num_vertices - 3, 0, -1):
        if not matrix[v, num_vertices - 1]:
            for dst in range(v + 1, num_vertices - 1):
                if matrix[v, dst]:
                    vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
        assert vertex_channels[v] > 0

    # Sanity check, verify that channels never increase and final channels add up.
    final_fan_in = 0
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            final_fan_in += vertex_channels[v]
        for dst in range(v + 1, num_vertices - 1):
            if matrix[v, dst]:
                assert vertex_channels[v] >= vertex_channels[dst]
    assert final_fan_in == out_channels or num_vertices == 2
    # num_vertices == 2 means only input/output nodes, so 0 fan-in

    return vertex_channels
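To make the channel-splitting rule in ComputeVertexChannels concrete, here is a small hand-checkable sketch; the adjacency matrix and channel counts are made up for illustration and are not part of the repository:

    import numpy as np
    from foresight.models.nasbench1 import ComputeVertexChannels

    # 5 vertices: 0 = input, 4 = output; two interior branches feed the output.
    matrix = np.array([[0, 1, 1, 0, 0],
                       [0, 0, 0, 1, 0],
                       [0, 0, 0, 0, 1],
                       [0, 0, 0, 0, 1],
                       [0, 0, 0, 0, 0]])

    # The output's 128 channels are split between the two vertices feeding it (64 each),
    # and vertex 1 inherits the max channel count of its successor (vertex 3).
    print(ComputeVertexChannels(128, 128, matrix))  # -> 128, 64, 64, 64, 128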
83
zero-cost-nas/foresight/models/nasbench1_ops.py
Normal file
@@ -0,0 +1,83 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Base operations used by the modules in this search space."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bn=True):
        super(ConvBnRelu, self).__init__()

        if bn:
            self.conv_bn_relu = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=False)
            )
        else:
            self.conv_bn_relu = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
                nn.ReLU(inplace=False)
            )

    def forward(self, x):
        return self.conv_bn_relu(x)

class Conv3x3BnRelu(nn.Module):
    """3x3 convolution with batch norm and ReLU activation."""
    def __init__(self, in_channels, out_channels, bn=True):
        super(Conv3x3BnRelu, self).__init__()

        self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, bn=bn)

    def forward(self, x):
        x = self.conv3x3(x)
        return x

class Conv1x1BnRelu(nn.Module):
    """1x1 convolution with batch norm and ReLU activation."""
    def __init__(self, in_channels, out_channels, bn=True):
        super(Conv1x1BnRelu, self).__init__()

        self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, bn=bn)

    def forward(self, x):
        x = self.conv1x1(x)
        return x

class MaxPool3x3(nn.Module):
    """3x3 max pool with no subsampling."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bn=None):
        super(MaxPool3x3, self).__init__()

        self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x):
        x = self.maxpool(x)
        return x

# Commas should not be used in op names
OP_MAP = {
    'conv3x3-bn-relu': Conv3x3BnRelu,
    'conv1x1-bn-relu': Conv1x1BnRelu,
    'maxpool3x3': MaxPool3x3
}
294
zero-cost-nas/foresight/models/nasbench1_spec.py
Normal file
@@ -0,0 +1,294 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Model specification for module connectivity individuals.

This module handles pruning the unused parts of the computation graph but should
avoid creating any TensorFlow models (this is done inside model_builder.py).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import hashlib
import itertools
import numpy as np


# Graphviz is optional and only required for visualization.
try:
    import graphviz   # pylint: disable=g-import-not-at-top
except ImportError:
    pass

def _ToModelSpec(mat, ops):
    return ModelSpec(mat, ops)

def gen_is_edge_fn(bits):
    """Generate a boolean function for the edge connectivity.

    Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
      [[0, A, B, D],
       [0, 0, C, E],
       [0, 0, 0, F],
       [0, 0, 0, 0]]

    Note that this function is agnostic to the actual matrix dimension due to
    order in which elements are filled out (column-major, starting from least
    significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
    matrix is
      [[0, A, B, D, 0],
       [0, 0, C, E, 0],
       [0, 0, 0, F, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]]

    Args:
      bits: integer which will be interpreted as a bit mask.

    Returns:
      vectorized function that returns True when an edge is present.
    """
    def is_edge(x, y):
        """Is there an edge from x to y (0-indexed)?"""
        if x >= y:
            return 0
        # Map x, y to index into bit string
        index = x + (y * (y - 1) // 2)
        return (bits >> index) % 2 == 1

    return np.vectorize(is_edge)


def is_full_dag(matrix):
    """Full DAG == all vertices on a path from vert 0 to (V-1).

    i.e. no disconnected or "hanging" vertices.

    It is sufficient to check for:
      1) no rows of 0 except for row V-1 (only output vertex has no out-edges)
      2) no cols of 0 except for col 0 (only input vertex has no in-edges)

    Args:
      matrix: V x V upper-triangular adjacency matrix

    Returns:
      True if there are no dangling vertices.
    """
    shape = np.shape(matrix)

    rows = matrix[:shape[0]-1, :] == 0
    rows = np.all(rows, axis=1)     # Any row with all 0 will be True
    rows_bad = np.any(rows)

    cols = matrix[:, 1:] == 0
    cols = np.all(cols, axis=0)     # Any col with all 0 will be True
    cols_bad = np.any(cols)

    return (not rows_bad) and (not cols_bad)


def num_edges(matrix):
    """Computes number of edges in adjacency matrix."""
    return np.sum(matrix)


def hash_module(matrix, labeling):
    """Computes a graph-invariance MD5 hash of the matrix and label pair.

    Args:
      matrix: np.ndarray square upper-triangular adjacency matrix.
      labeling: list of int labels of length equal to both dimensions of
        matrix.

    Returns:
      MD5 hash of the matrix and labeling.
    """
    vertices = np.shape(matrix)[0]
    in_edges = np.sum(matrix, axis=0).tolist()
    out_edges = np.sum(matrix, axis=1).tolist()

    assert len(in_edges) == len(out_edges) == len(labeling)
    hashes = list(zip(out_edges, in_edges, labeling))
    hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
    # Computing this up to the diameter is probably sufficient but since the
    # operation is fast, it is okay to repeat more times.
    for _ in range(vertices):
        new_hashes = []
        for v in range(vertices):
            in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
            out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
            new_hashes.append(hashlib.md5(
                (''.join(sorted(in_neighbors)) + '|' +
                 ''.join(sorted(out_neighbors)) + '|' +
                 hashes[v]).encode('utf-8')).hexdigest())
        hashes = new_hashes
    fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()

    return fingerprint


def permute_graph(graph, label, permutation):
    """Permutes the graph and labels based on permutation.

    Args:
      graph: np.ndarray adjacency matrix.
      label: list of labels of same length as graph dimensions.
      permutation: a permutation list of ints of same length as graph dimensions.

    Returns:
      np.ndarray where vertex permutation[v] is vertex v from the original graph
    """
    # vertex permutation[v] in new graph is vertex v in the old graph
    forward_perm = zip(permutation, list(range(len(permutation))))
    inverse_perm = [x[1] for x in sorted(forward_perm)]
    edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
    new_matrix = np.fromfunction(np.vectorize(edge_fn),
                                 (len(label), len(label)),
                                 dtype=np.int8)
    new_label = [label[inverse_perm[i]] for i in range(len(label))]
    return new_matrix, new_label


def is_isomorphic(graph1, graph2):
    """Exhaustively checks if 2 graphs are isomorphic."""
    matrix1, label1 = np.array(graph1[0]), graph1[1]
    matrix2, label2 = np.array(graph2[0]), graph2[1]
    assert np.shape(matrix1) == np.shape(matrix2)
    assert len(label1) == len(label2)

    vertices = np.shape(matrix1)[0]
    # Note: input and output in our constrained graphs always map to themselves
    # but this script does not enforce that.
    for perm in itertools.permutations(range(0, vertices)):
        pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
        if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
            return True

    return False

class ModelSpec(object):
    """Model specification given adjacency matrix and labeling."""

    def __init__(self, matrix, ops, data_format='channels_last'):
        """Initialize the module spec.

        Args:
          matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
          ops: V-length list of labels for the base ops used. The first and last
            elements are ignored because they are the input and output vertices
            which have no operations. The elements are retained to keep consistent
            indexing.
          data_format: channels_last or channels_first.

        Raises:
          ValueError: invalid matrix or ops
        """
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)
        shape = np.shape(matrix)
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError('matrix must be square')
        if shape[0] != len(ops):
            raise ValueError('length of ops must match matrix dimensions')
        if not is_upper_triangular(matrix):
            raise ValueError('matrix must be upper triangular')

        # Both the original and pruned matrices are deep copies of the matrix and
        # ops so any changes to those after initialization are not recognized by the
        # spec.
        self.original_matrix = copy.deepcopy(matrix)
        self.original_ops = copy.deepcopy(ops)

        self.matrix = copy.deepcopy(matrix)
        self.ops = copy.deepcopy(ops)
        self.valid_spec = True
        self._prune()

        self.data_format = data_format

    def _prune(self):
        """Prune the extraneous parts of the graph.

        General procedure:
          1) Remove parts of graph not connected to input.
          2) Remove parts of graph not connected to output.
          3) Reorder the vertices so that they are consecutive after steps 1 and 2.

        These 3 steps can be combined by deleting the rows and columns of the
        vertices that are not reachable from both the input and output (in reverse).
        """
        num_vertices = np.shape(self.original_matrix)[0]

        # DFS forward from input
        visited_from_input = set([0])
        frontier = [0]
        while frontier:
            top = frontier.pop()
            for v in range(top + 1, num_vertices):
                if self.original_matrix[top, v] and v not in visited_from_input:
                    visited_from_input.add(v)
                    frontier.append(v)

        # DFS backward from output
        visited_from_output = set([num_vertices - 1])
        frontier = [num_vertices - 1]
        while frontier:
            top = frontier.pop()
            for v in range(0, top):
                if self.original_matrix[v, top] and v not in visited_from_output:
                    visited_from_output.add(v)
                    frontier.append(v)

        # Any vertex that isn't connected to both input and output is extraneous to
        # the computation graph.
        extraneous = set(range(num_vertices)).difference(
            visited_from_input.intersection(visited_from_output))

        # If the non-extraneous graph is less than 2 vertices, the input is not
        # connected to the output and the spec is invalid.
        if len(extraneous) > num_vertices - 2:
            self.matrix = None
            self.ops = None
            self.valid_spec = False
            return

        self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
        self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
        for index in sorted(extraneous, reverse=True):
            del self.ops[index]

    def hash_spec(self, canonical_ops):
        """Computes the isomorphism-invariant graph hash of this spec.

        Args:
          canonical_ops: list of operations in the canonical ordering which they
            were assigned (i.e. the order provided in the config['available_ops']).

        Returns:
          MD5 hash of this spec which can be used to query the dataset.
        """
        # Invert the operations back to integer label indices used in graph gen.
        labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
        # Uses the module-level hash_module defined above (the original referenced
        # graph_util.hash_module, which is not imported in this file).
        return hash_module(self.matrix, labeling)

    def visualize(self):
        """Creates a dot graph. Can be visualized in colab directly."""
        num_vertices = np.shape(self.matrix)[0]
        g = graphviz.Digraph()
        g.node(str(0), 'input')
        for v in range(1, num_vertices - 1):
            g.node(str(v), self.ops[v])
        g.node(str(num_vertices - 1), 'output')

        for src in range(num_vertices - 1):
            for dst in range(src + 1, num_vertices):
                if self.matrix[src, dst]:
                    g.edge(str(src), str(dst))

        return g


def is_upper_triangular(matrix):
    """True if matrix is 0 on diagonal and below."""
    for src in range(np.shape(matrix)[0]):
        for dst in range(0, src + 1):
            if matrix[src, dst] != 0:
                return False

    return True
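A small sketch of ModelSpec's pruning behaviour: vertex 2 below is reachable from the input but never reaches the output, so _prune drops it. The matrix and op labels are illustrative only, and the import path assumes the package is importable as foresight:

    from foresight.models.nasbench1_spec import ModelSpec

    matrix = [[0, 1, 1, 0, 0],
              [0, 0, 0, 1, 0],
              [0, 0, 0, 0, 0],   # vertex 2 has no outgoing edge -> hanging
              [0, 0, 0, 0, 1],
              [0, 0, 0, 0, 0]]
    ops = ['input', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3', 'output']

    spec = ModelSpec(matrix, ops)
    print(spec.valid_spec)     # True
    print(spec.matrix.shape)   # (4, 4): the hanging conv1x1 vertex was pruned
    print(spec.ops)            # ['input', 'conv3x3-bn-relu', 'maxpool3x3', 'output']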
121
zero-cost-nas/foresight/models/nasbench2.py
Normal file
@@ -0,0 +1,121 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nasbench2_ops import *


def gen_searchcell_mask_from_arch_str(arch_str):
    nodes = arch_str.split('+')
    nodes = [node[1:-1].split('|') for node in nodes]
    nodes = [[op_and_input.split('~') for op_and_input in node] for node in nodes]

    keep_mask = []
    for curr_node_idx in range(len(nodes)):
        for prev_node_idx in range(curr_node_idx+1):
            _op = [edge[0] for edge in nodes[curr_node_idx] if int(edge[1]) == prev_node_idx]
            assert len(_op) == 1, 'The arch string does not follow the assumption of 1 connection between two nodes.'
            for _op_name in OPS.keys():
                keep_mask.append(_op[0] == _op_name)
    return keep_mask


def get_model_from_arch_str(arch_str, num_classes, use_bn=True, init_channels=16):
    keep_mask = gen_searchcell_mask_from_arch_str(arch_str)
    net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn, keep_mask=keep_mask, stem_ch=init_channels)
    return net


def get_super_model(num_classes, use_bn=True):
    net = NAS201Model(arch_str=arch_str, num_classes=num_classes, use_bn=use_bn)
    return net


class NAS201Model(nn.Module):

    def __init__(self, arch_str, num_classes, use_bn=True, keep_mask=None, stem_ch=16):
        super(NAS201Model, self).__init__()
        self.arch_str = arch_str
        self.num_classes = num_classes
        self.use_bn = use_bn

        self.stem = stem(out_channels=stem_ch, use_bn=use_bn)
        self.stack_cell1 = nn.Sequential(*[SearchCell(in_channels=stem_ch, out_channels=stem_ch, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.reduction1 = reduction(in_channels=stem_ch, out_channels=stem_ch*2)
        self.stack_cell2 = nn.Sequential(*[SearchCell(in_channels=stem_ch*2, out_channels=stem_ch*2, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.reduction2 = reduction(in_channels=stem_ch*2, out_channels=stem_ch*4)
        self.stack_cell3 = nn.Sequential(*[SearchCell(in_channels=stem_ch*4, out_channels=stem_ch*4, stride=1, affine=False, track_running_stats=False, use_bn=use_bn, keep_mask=keep_mask) for i in range(5)])
        self.top = top(in_dims=stem_ch*4, num_classes=num_classes, use_bn=use_bn)

    def forward(self, x):
        x = self.stem(x)

        x = self.stack_cell1(x)
        x = self.reduction1(x)

        x = self.stack_cell2(x)
        x = self.reduction2(x)

        x = self.stack_cell3(x)

        x = self.top(x)
        return x

    def get_prunable_copy(self, bn=False):
        model_new = get_model_from_arch_str(self.arch_str, self.num_classes, use_bn=bn)

        # TODO: this is quite brittle and doesn't work with nn.Sequential when bn is different
        # it is only required to maintain initialization -- maybe init after get_prunable_copy?
        model_new.load_state_dict(self.state_dict(), strict=False)
        model_new.train()

        return model_new


def get_arch_str_from_model(net):
    search_cell = net.stack_cell1[0].options
    keep_mask = net.stack_cell1[0].keep_mask
    num_nodes = net.stack_cell1[0].num_nodes

    nodes = []
    idx = 0
    for curr_node in range(num_nodes - 1):
        edges = []
        for prev_node in range(curr_node+1):   # n-1 prev nodes
            for _op_name in OPS.keys():
                if keep_mask[idx]:
                    edges.append(f'{_op_name}~{prev_node}')
                idx += 1
        node_str = '|'.join(edges)
        node_str = f'|{node_str}|'
        nodes.append(node_str)
    arch_str = '+'.join(nodes)
    return arch_str


if __name__ == "__main__":
    arch_str = '|nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'

    n = get_model_from_arch_str(arch_str=arch_str, num_classes=10)
    print(n.stack_cell1[0])

    arch_str2 = get_arch_str_from_model(n)
    print(arch_str)
    print(arch_str2)
    print(f'Are the two arch strings same? {arch_str == arch_str2}')
164
zero-cost-nas/foresight/models/nasbench2_ops.py
Normal file
@@ -0,0 +1,164 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUConvBN(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, affine, track_running_stats=True, use_bn=True, name='ReLUConvBN'):
        super(ReLUConvBN, self).__init__()
        self.name = name
        if use_bn:
            self.op = nn.Sequential(
                nn.ReLU(inplace=False),
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
                nn.BatchNorm2d(out_channels, affine=affine, track_running_stats=track_running_stats)
            )
        else:
            self.op = nn.Sequential(
                nn.ReLU(inplace=False),
                nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine)
            )

    def forward(self, x):
        return self.op(x)

class Identity(nn.Module):
    def __init__(self, name='Identity'):
        self.name = name
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class Zero(nn.Module):

    def __init__(self, stride, name='Zero'):
        self.name = name
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)

class POOLING(nn.Module):
    def __init__(self, kernel_size, stride, padding, name='POOLING'):
        super(POOLING, self).__init__()
        self.name = name
        self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=1, padding=1, count_include_pad=False)

    def forward(self, x):
        return self.avgpool(x)


class reduction(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(reduction, self).__init__()
        self.residual = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False))

        self.conv_a = ReLUConvBN(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, dilation=1, affine=True, track_running_stats=True)
        self.conv_b = ReLUConvBN(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1, affine=True, track_running_stats=True)

    def forward(self, x):
        basicblock = self.conv_a(x)
        basicblock = self.conv_b(basicblock)
        residual = self.residual(x)
        return residual + basicblock

class stem(nn.Module):
    def __init__(self, out_channels, use_bn=True):
        super(stem, self).__init__()
        if use_bn:
            self.net = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels))
        else:
            self.net = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=out_channels, kernel_size=3, padding=1, bias=False)
            )

    def forward(self, x):
        return self.net(x)

class top(nn.Module):
    def __init__(self, in_dims, num_classes, use_bn=True):
        super(top, self).__init__()
        if use_bn:
            self.lastact = nn.Sequential(nn.BatchNorm2d(in_dims), nn.ReLU(inplace=True))
        else:
            self.lastact = nn.ReLU(inplace=True)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(in_dims, num_classes)

    def forward(self, x):
        x = self.lastact(x)
        x = self.global_pooling(x)
        x = x.view(x.size(0), -1)
        logits = self.classifier(x)
        return logits


class SearchCell(nn.Module):

    def __init__(self, in_channels, out_channels, stride, affine, track_running_stats, use_bn=True, num_nodes=4, keep_mask=None):
        super(SearchCell, self).__init__()
        self.num_nodes = num_nodes
        self.options = nn.ModuleList()
        for curr_node in range(self.num_nodes-1):
            for prev_node in range(curr_node+1):
                for _op_name in OPS.keys():
                    op = OPS[_op_name](in_channels, out_channels, stride, affine, track_running_stats, use_bn)
                    self.options.append(op)

        if keep_mask is not None:
            self.keep_mask = keep_mask
        else:
            self.keep_mask = [True] * len(self.options)

    def forward(self, x):
        outs = [x]

        idx = 0
        for curr_node in range(self.num_nodes-1):
            edges_in = []
            for prev_node in range(curr_node+1):   # n-1 prev nodes
                for op_idx in range(len(OPS.keys())):
                    if self.keep_mask[idx]:
                        edges_in.append(self.options[idx](outs[prev_node]))
                    idx += 1
            node_output = sum(edges_in)
            outs.append(node_output)

        return outs[-1]


OPS = {
    'none':         lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Zero(stride, name='none'),
    'avg_pool_3x3': lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: POOLING(3, 1, 1, name='avg_pool_3x3'),
    'nor_conv_3x3': lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 3, 1, 1, 1, affine, track_running_stats, use_bn, name='nor_conv_3x3'),
    'nor_conv_1x1': lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: ReLUConvBN(in_channels, out_channels, 1, 1, 0, 1, affine, track_running_stats, use_bn, name='nor_conv_1x1'),
    'skip_connect': lambda in_channels, out_channels, stride, affine, track_running_stats, use_bn: Identity(name='skip_connect'),
}
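For reference, SearchCell enumerates one module per (edge, op) pair: with the default num_nodes=4 and the five entries in OPS, that is 6 edges x 5 ops = 30 candidate modules, which is also the expected length of keep_mask produced by gen_searchcell_mask_from_arch_str. A hedged sketch, with illustrative channel and input sizes:

    import torch
    from foresight.models.nasbench2_ops import SearchCell, OPS

    cell = SearchCell(in_channels=16, out_channels=16, stride=1,
                      affine=False, track_running_stats=False)
    print(len(cell.options))    # 30 == 6 edges * len(OPS)
    print(len(cell.keep_mask))  # 30, all True -> the full supernet cell

    x = torch.randn(2, 16, 32, 32)
    print(cell(x).shape)        # torch.Size([2, 16, 32, 32]); all ops preserve spatial size at stride 1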
19
zero-cost-nas/foresight/pruners/__init__.py
Normal file
@@ -0,0 +1,19 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from os.path import dirname, basename, isfile, join
import glob
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
66
zero-cost-nas/foresight/pruners/measures/__init__.py
Normal file
@@ -0,0 +1,66 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================


available_measures = []
_measure_impls = {}


def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
    def make_impl(func):
        def measure_impl(net_orig, device, *args, **kwargs):
            if copy_net:
                net = net_orig.get_prunable_copy(bn=bn).to(device)
            else:
                net = net_orig
            ret = func(net, *args, **kwargs, **impl_args)
            if copy_net and force_clean:
                import gc
                import torch
                del net
                torch.cuda.empty_cache()
                gc.collect()
            return ret

        global _measure_impls
        if name in _measure_impls:
            raise KeyError(f'Duplicated measure! {name}')
        available_measures.append(name)
        _measure_impls[name] = measure_impl
        return func
    return make_impl


def calc_measure(name, net, device, *args, **kwargs):
    return _measure_impls[name](net, device, *args, **kwargs)


def load_all():
    from . import grad_norm
    from . import snip
    from . import grasp
    from . import fisher
    from . import jacob_cov
    from . import plain
    from . import synflow
    from . import var
    from . import cor
    from . import norm
    from . import meco
    from . import zico


# TODO: should we do that by default?
load_all()
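A sketch of how the registry above is meant to be used. The toy 'param_count' measure, and the net/device/inputs objects in the commented call, are placeholders and not part of the package; the decorator wraps the function so that calc_measure copies the network via get_prunable_copy, runs the measure, then frees the copy:

    from foresight.pruners import measures

    # Hypothetical measure: score a network by its parameter count.
    @measures.measure('param_count', bn=True)
    def compute_param_count(net, inputs, targets, split_data=1, loss_fn=None):
        return sum(p.numel() for p in net.parameters())

    # Typical invocation (net must provide get_prunable_copy, as the models above do):
    # score = measures.calc_measure('param_count', net, device, inputs, targets)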
53
zero-cost-nas/foresight/pruners/measures/cor.py
Normal file
@@ -0,0 +1,53 @@
# Copyright 2021 Samsung Electronics Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import time

import numpy as np
import torch

from . import measure


def get_score(net, x, target, device, split_data):
    result_list = []
    def forward_hook(module, data_input, data_output):
        corr = np.mean(np.corrcoef(data_input[0].detach().cpu().numpy()))
        result_list.append(corr)
    net.classifier.register_forward_hook(forward_hook)

    N = x.shape[0]
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data
        y = net(x[st:en])
        cor = result_list[0].item()
        result_list.clear()
    return cor


@measure('cor', bn=True)
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
    device = inputs.device
    # Compute gradients (but don't apply them)
    net.zero_grad()

    try:
        cor = get_score(net, inputs, targets, device, split_data=split_data)
    except Exception as e:
        print(e)
        cor = np.nan

    return cor
107
zero-cost-nas/foresight/pruners/measures/fisher.py
Normal file
107
zero-cost-nas/foresight/pruners/measures/fisher.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import types
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array, reshape_elements
|
||||
|
||||
|
||||
def fisher_forward_conv2d(self, x):
|
||||
x = F.conv2d(x, self.weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
#intercept and store the activations after passing through 'hooked' identity op
|
||||
self.act = self.dummy(x)
|
||||
return self.act
|
||||
|
||||
def fisher_forward_linear(self, x):
|
||||
x = F.linear(x, self.weight, self.bias)
|
||||
self.act = self.dummy(x)
|
||||
return self.act
|
||||
|
||||
@measure('fisher', bn=True, mode='channel')
|
||||
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):
|
||||
|
||||
device = inputs.device
|
||||
|
||||
if mode == 'param':
|
||||
raise ValueError('Fisher pruning does not support parameter pruning.')
|
||||
|
||||
net.train()
|
||||
all_hooks = []
|
||||
for layer in net.modules():
|
||||
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
|
||||
#variables/op needed for fisher computation
|
||||
layer.fisher = None
|
||||
layer.act = 0.
|
||||
layer.dummy = nn.Identity()
|
||||
|
||||
#replace forward method of conv/linear
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
layer.forward = types.MethodType(fisher_forward_conv2d, layer)
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.forward = types.MethodType(fisher_forward_linear, layer)
|
||||
|
||||
#function to call during backward pass (hooked on identity op at output of layer)
|
||||
def hook_factory(layer):
|
||||
def hook(module, grad_input, grad_output):
|
||||
act = layer.act.detach()
|
||||
grad = grad_output[0].detach()
|
||||
if len(act.shape) > 2:
|
||||
g_nk = torch.sum((act * grad), list(range(2,len(act.shape))))
|
||||
else:
|
||||
g_nk = act * grad
|
||||
del_k = g_nk.pow(2).mean(0).mul(0.5)
|
||||
if layer.fisher is None:
|
||||
layer.fisher = del_k
|
||||
else:
|
||||
layer.fisher += del_k
|
||||
del layer.act #without deleting this, a nasty memory leak occurs! related: https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
|
||||
return hook
|
||||
|
||||
#register backward hook on identity fcn to compute fisher info
|
||||
layer.dummy.register_backward_hook(hook_factory(layer))
|
||||
|
||||
N = inputs.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
net.zero_grad()
|
||||
outputs = net(inputs[st:en])
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
# retrieve fisher info
|
||||
def fisher(layer):
|
||||
if layer.fisher is not None:
|
||||
return torch.abs(layer.fisher.detach())
|
||||
else:
|
||||
return torch.zeros(layer.weight.shape[0]) #size=ch
|
||||
|
||||
grads_abs_ch = get_layer_metric_array(net, fisher, mode)
|
||||
|
||||
#broadcast channel value here to all parameters in that channel
|
||||
#to be compatible with stuff downstream (which expects per-parameter metrics)
|
||||
#TODO cleanup on the selectors/apply_prune_mask side (?)
|
||||
shapes = get_layer_metric_array(net, lambda l : l.weight.shape[1:], mode)
|
||||
|
||||
grads_abs = reshape_elements(grads_abs_ch, shapes, device)
|
||||
|
||||
return grads_abs
|
38
zero-cost-nas/foresight/pruners/measures/grad_norm.py
Normal file
38
zero-cost-nas/foresight/pruners/measures/grad_norm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import copy
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
@measure('grad_norm', bn=True)
|
||||
def get_grad_norm_arr(net, inputs, targets, loss_fn, split_data=1, skip_grad=False):
|
||||
net.zero_grad()
|
||||
N = inputs.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
grad_norm_arr = get_layer_metric_array(net, lambda l: l.weight.grad.norm() if l.weight.grad is not None else torch.zeros_like(l.weight), mode='param')
|
||||
|
||||
return grad_norm_arr
|
87
zero-cost-nas/foresight/pruners/measures/grasp.py
Normal file
87
zero-cost-nas/foresight/pruners/measures/grasp.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.autograd as autograd
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
@measure('grasp', bn=True, mode='param')
|
||||
def compute_grasp_per_weight(net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1):
|
||||
|
||||
# get all applicable weights
|
||||
weights = []
|
||||
for layer in net.modules():
|
||||
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
|
||||
weights.append(layer.weight)
|
||||
layer.weight.requires_grad_(True) # TODO isn't this already true?
|
||||
|
||||
# NOTE original code had some input/target splitting into 2
|
||||
# I am guessing this was because of GPU mem limit
|
||||
net.zero_grad()
|
||||
N = inputs.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
#forward/grad pass #1
|
||||
grad_w = None
|
||||
for _ in range(num_iters):
|
||||
#TODO get new data, otherwise num_iters is useless!
|
||||
outputs = net.forward(inputs[st:en])/T
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
grad_w_p = autograd.grad(loss, weights, allow_unused=True)
|
||||
if grad_w is None:
|
||||
grad_w = list(grad_w_p)
|
||||
else:
|
||||
for idx in range(len(grad_w)):
|
||||
grad_w[idx] += grad_w_p[idx]
|
||||
|
||||
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
# forward/grad pass #2
|
||||
outputs = net.forward(inputs[st:en])/T
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)
|
||||
|
||||
# accumulate gradients computed in previous step and call backwards
|
||||
z, count = 0,0
|
||||
for layer in net.modules():
|
||||
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
|
||||
if grad_w[count] is not None:
|
||||
z += (grad_w[count].data * grad_f[count]).sum()
|
||||
count += 1
|
||||
z.backward()
|
||||
|
||||
# compute final sensitivity metric and put in grads
|
||||
def grasp(layer):
|
||||
if layer.weight.grad is not None:
|
||||
return -layer.weight.data * layer.weight.grad # -theta_q Hg
|
||||
#NOTE in the grasp code they take the *bottom* (1-p)% of values
|
||||
#but we take the *top* (1-p)%, therefore we remove the -ve sign
|
||||
#EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here!
|
||||
else:
|
||||
return torch.zeros_like(layer.weight)
|
||||
|
||||
grads = get_layer_metric_array(net, grasp, mode)
|
||||
|
||||
return grads
|
57
zero-cost-nas/foresight/pruners/measures/jacob_cov.py
Normal file
57
zero-cost-nas/foresight/pruners/measures/jacob_cov.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from . import measure
|
||||
|
||||
|
||||
def get_batch_jacobian(net, x, target, device, split_data):
|
||||
x.requires_grad_(True)
|
||||
|
||||
N = x.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
y = net(x[st:en])
|
||||
y.backward(torch.ones_like(y))
|
||||
|
||||
jacob = x.grad.detach()
|
||||
x.requires_grad_(False)
|
||||
return jacob, target.detach()
|
||||
|
||||
def eval_score(jacob, labels=None):
|
||||
corrs = np.corrcoef(jacob)
|
||||
v, _ = np.linalg.eig(corrs)
|
||||
k = 1e-5
|
||||
return -np.sum(np.log(v + k) + 1./(v + k))
|
||||
|
||||
@measure('jacob_cov', bn=True)
|
||||
def compute_jacob_cov(net, inputs, targets, split_data=1, loss_fn=None):
|
||||
device = inputs.device
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
|
||||
jacobs, labels = get_batch_jacobian(net, inputs, targets, device, split_data=split_data)
|
||||
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
|
||||
|
||||
try:
|
||||
jc = eval_score(jacobs, labels)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
jc = np.nan
|
||||
|
||||
return jc
|
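For intuition, the score above is a log-determinant-style penalty on the correlation matrix of per-sample input gradients: decorrelated rows give eigenvalues near 1 and a higher (less negative) score. A quick sanity check on a random "Jacobian" (illustrative only; assumes this commit's package layout):

import numpy as np
from foresight.pruners.measures.jacob_cov import eval_score

rng = np.random.default_rng(0)
jacob = rng.standard_normal((8, 3 * 32 * 32))  # 8 per-sample input-gradient rows, flattened
print(eval_score(jacob))  # near -8 for almost uncorrelated rows; strongly negative when rows are correlated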
22
zero-cost-nas/foresight/pruners/measures/l2_norm.py
Normal file
22
zero-cost-nas/foresight/pruners/measures/l2_norm.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
@measure('l2_norm', copy_net=False, mode='param')
|
||||
def get_l2_norm_array(net, inputs, targets, mode, split_data=1):
|
||||
return get_layer_metric_array(net, lambda l: l.weight.norm(), mode=mode)
|
69
zero-cost-nas/foresight/pruners/measures/meco.py
Normal file
69
zero-cost-nas/foresight/pruners/measures/meco.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from . import measure
|
||||
|
||||
|
||||
def get_score(net, x, target, device, split_data):
|
||||
result_list = []
|
||||
|
||||
def forward_hook(module, data_input, data_output):
|
||||
|
||||
fea = data_output[0].detach()
|
||||
fea = fea.reshape(fea.shape[0], -1)
|
||||
corr = torch.corrcoef(fea)
|
||||
corr[torch.isnan(corr)] = 0
|
||||
corr[torch.isinf(corr)] = 0
|
||||
values = torch.linalg.eig(corr)[0]
|
||||
# result = np.real(np.min(values)) / np.real(np.max(values))
|
||||
result = torch.min(torch.real(values))
|
||||
result_list.append(result)
|
||||
|
||||
for name, modules in net.named_modules():
|
||||
modules.register_forward_hook(forward_hook)
|
||||
|
||||
|
||||
|
||||
N = x.shape[0]
|
||||
for sp in range(split_data):
|
||||
st = sp * N // split_data
|
||||
en = (sp + 1) * N // split_data
|
||||
y = net(x[st:en])
|
||||
results = torch.tensor(result_list)
|
||||
results = results[torch.logical_not(torch.isnan(results))]
|
||||
v = torch.sum(results)
|
||||
result_list.clear()
|
||||
return v.item()
|
||||
|
||||
|
||||
|
||||
@measure('meco', bn=True)
|
||||
def compute_meco(net, inputs, targets, split_data=1, loss_fn=None):
|
||||
device = inputs.device
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
|
||||
try:
|
||||
meco = get_score(net, inputs, targets, device, split_data=split_data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
        meco = np.nan
|
||||
return meco
|
55
zero-cost-nas/foresight/pruners/measures/norm.py
Normal file
55
zero-cost-nas/foresight/pruners/measures/norm.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import measure
|
||||
|
||||
|
||||
def get_score(net, x, target, device, split_data):
|
||||
result_list = []
|
||||
def forward_hook(module, data_input, data_output):
|
||||
norm = torch.norm(data_input[0])
|
||||
result_list.append(norm)
|
||||
net.classifier.register_forward_hook(forward_hook)
|
||||
|
||||
N = x.shape[0]
|
||||
for sp in range(split_data):
|
||||
st = sp * N // split_data
|
||||
en = (sp + 1) * N // split_data
|
||||
y = net(x[st:en])
|
||||
n = result_list[0].item()
|
||||
result_list.clear()
|
||||
return n
|
||||
|
||||
|
||||
|
||||
@measure('norm', bn=True)
|
||||
def compute_norm(net, inputs, targets, split_data=1, loss_fn=None):
|
||||
device = inputs.device
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
|
||||
# print('var:', feature.shape)
|
||||
try:
|
||||
        norm = get_score(net, inputs, targets, device, split_data=split_data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
        norm = np.nan
|
||||
# print(jc)
|
||||
# print(f'norm time: {t} s')
|
||||
    return norm
|
16
zero-cost-nas/foresight/pruners/measures/param_count.py
Normal file
16
zero-cost-nas/foresight/pruners/measures/param_count.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import time
|
||||
import torch
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
|
||||
@measure('param_count', copy_net=False, mode='param')
|
||||
def get_param_count_array(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||
s = time.time()
|
||||
count = get_layer_metric_array(net, lambda l: torch.tensor(sum(p.numel() for p in l.parameters() if p.requires_grad)), mode=mode)
|
||||
e = time.time()
|
||||
t = e - s
|
||||
# print(f'param_count time: {t} s')
|
||||
return count, t
|
44
zero-cost-nas/foresight/pruners/measures/plain.py
Normal file
44
zero-cost-nas/foresight/pruners/measures/plain.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
@measure('plain', bn=True, mode='param')
|
||||
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||
|
||||
net.zero_grad()
|
||||
N = inputs.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
# select the gradients that we want to use for search/prune
|
||||
def plain(layer):
|
||||
if layer.weight.grad is not None:
|
||||
return layer.weight.grad * layer.weight
|
||||
else:
|
||||
return torch.zeros_like(layer.weight)
|
||||
|
||||
grads_abs = get_layer_metric_array(net, plain, mode)
|
||||
return grads_abs
|
69
zero-cost-nas/foresight/pruners/measures/snip.py
Normal file
69
zero-cost-nas/foresight/pruners/measures/snip.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import copy
|
||||
import types
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
def snip_forward_conv2d(self, x):
|
||||
return F.conv2d(x, self.weight * self.weight_mask, self.bias,
|
||||
self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def snip_forward_linear(self, x):
|
||||
return F.linear(x, self.weight * self.weight_mask, self.bias)
|
||||
|
||||
@measure('snip', bn=True, mode='param')
|
||||
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
|
||||
for layer in net.modules():
|
||||
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
|
||||
layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
|
||||
layer.weight.requires_grad = False
|
||||
|
||||
# Override the forward methods:
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
layer.forward = types.MethodType(snip_forward_conv2d, layer)
|
||||
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.forward = types.MethodType(snip_forward_linear, layer)
|
||||
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
N = inputs.shape[0]
|
||||
for sp in range(split_data):
|
||||
st=sp*N//split_data
|
||||
en=(sp+1)*N//split_data
|
||||
|
||||
outputs = net.forward(inputs[st:en])
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
|
||||
# select the gradients that we want to use for search/prune
|
||||
def snip(layer):
|
||||
if layer.weight_mask.grad is not None:
|
||||
return torch.abs(layer.weight_mask.grad)
|
||||
else:
|
||||
return torch.zeros_like(layer.weight)
|
||||
|
||||
grads_abs = get_layer_metric_array(net, snip, mode)
|
||||
|
||||
return grads_abs
|
69
zero-cost-nas/foresight/pruners/measures/synflow.py
Normal file
69
zero-cost-nas/foresight/pruners/measures/synflow.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
|
||||
from . import measure
|
||||
from ..p_utils import get_layer_metric_array
|
||||
|
||||
|
||||
@measure('synflow', bn=False, mode='param')
|
||||
@measure('synflow_bn', bn=True, mode='param')
|
||||
def compute_synflow_per_weight(net, inputs, targets, mode, split_data=1, loss_fn=None):
|
||||
|
||||
device = inputs.device
|
||||
|
||||
#convert params to their abs. Keep sign for converting it back.
|
||||
@torch.no_grad()
|
||||
def linearize(net):
|
||||
signs = {}
|
||||
for name, param in net.state_dict().items():
|
||||
signs[name] = torch.sign(param)
|
||||
param.abs_()
|
||||
return signs
|
||||
|
||||
#convert to orig values
|
||||
@torch.no_grad()
|
||||
def nonlinearize(net, signs):
|
||||
for name, param in net.state_dict().items():
|
||||
if 'weight_mask' not in name:
|
||||
param.mul_(signs[name])
|
||||
|
||||
# keep signs of all params
|
||||
signs = linearize(net)
|
||||
|
||||
# Compute gradients with input of 1s
|
||||
net.zero_grad()
|
||||
net.double()
|
||||
input_dim = list(inputs[0,:].shape)
|
||||
inputs = torch.ones([1] + input_dim).double().to(device)
|
||||
output = net.forward(inputs)
|
||||
torch.sum(output).backward()
|
||||
|
||||
# select the gradients that we want to use for search/prune
|
||||
def synflow(layer):
|
||||
if layer.weight.grad is not None:
|
||||
return torch.abs(layer.weight * layer.weight.grad)
|
||||
else:
|
||||
return torch.zeros_like(layer.weight)
|
||||
|
||||
grads_abs = get_layer_metric_array(net, synflow, mode)
|
||||
|
||||
# apply signs of all params
|
||||
nonlinearize(net, signs)
|
||||
|
||||
return grads_abs
|
||||
|
||||
|
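The linearize/nonlinearize pair above is what lets synflow run its all-ones forward pass on a strictly positive version of the parameters and then restore them exactly. A tiny stand-alone illustration (not from the commit):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
before = net.weight.detach().clone()

# same trick as linearize()/nonlinearize(): state_dict tensors share storage with the parameters
signs = {name: torch.sign(p) for name, p in net.state_dict().items()}
with torch.no_grad():
    for p in net.state_dict().values():
        p.abs_()
    # ... synflow's ones-input forward/backward pass would run here ...
    for name, p in net.state_dict().items():
        p.mul_(signs[name])

print(torch.allclose(before, net.weight))  # True: the saved signs restore the original weights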
55
zero-cost-nas/foresight/pruners/measures/var.py
Normal file
55
zero-cost-nas/foresight/pruners/measures/var.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import measure
|
||||
|
||||
|
||||
def get_score(net, x, target, device, split_data):
|
||||
result_list = []
|
||||
def forward_hook(module, data_input, data_output):
|
||||
var = torch.var(data_input[0])
|
||||
result_list.append(var)
|
||||
net.classifier.register_forward_hook(forward_hook)
|
||||
|
||||
N = x.shape[0]
|
||||
for sp in range(split_data):
|
||||
st = sp * N // split_data
|
||||
en = (sp + 1) * N // split_data
|
||||
y = net(x[st:en])
|
||||
v = result_list[0].item()
|
||||
result_list.clear()
|
||||
return v
|
||||
|
||||
|
||||
|
||||
@measure('var', bn=True)
|
||||
def compute_var(net, inputs, targets, split_data=1, loss_fn=None):
|
||||
device = inputs.device
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
|
||||
# print('var:', feature.shape)
|
||||
try:
|
||||
var= get_score(net, inputs, targets, device, split_data=split_data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
var= np.nan
|
||||
# print(jc)
|
||||
# print(f'var time: {t} s')
|
||||
return var
|
106
zero-cost-nas/foresight/pruners/measures/zico.py
Normal file
106
zero-cost-nas/foresight/pruners/measures/zico.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import measure
|
||||
from torch import nn
|
||||
|
||||
from ...dataset import get_cifar_dataloaders
|
||||
|
||||
|
||||
def getgrad(model: torch.nn.Module, grad_dict: dict, step_iter=0):
|
||||
if step_iter == 0:
|
||||
for name, mod in model.named_modules():
|
||||
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
|
||||
# print(mod.weight.grad.data.size())
|
||||
# print(mod.weight.data.size())
|
||||
try:
|
||||
grad_dict[name] = [mod.weight.grad.data.cpu().reshape(-1).numpy()]
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
for name, mod in model.named_modules():
|
||||
if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
|
||||
try:
|
||||
grad_dict[name].append(mod.weight.grad.data.cpu().reshape(-1).numpy())
|
||||
except:
|
||||
continue
|
||||
return grad_dict
|
||||
|
||||
|
||||
def calculate_zico(grad_dict):
|
||||
allgrad_array = None
|
||||
for i, modname in enumerate(grad_dict.keys()):
|
||||
grad_dict[modname] = np.array(grad_dict[modname])
|
||||
nsr_mean_sum = 0
|
||||
nsr_mean_sum_abs = 0
|
||||
nsr_mean_avg = 0
|
||||
nsr_mean_avg_abs = 0
|
||||
for j, modname in enumerate(grad_dict.keys()):
|
||||
nsr_std = np.std(grad_dict[modname], axis=0)
|
||||
# print(grad_dict[modname].shape)
|
||||
# print(grad_dict[modname].shape, nsr_std.shape)
|
||||
nonzero_idx = np.nonzero(nsr_std)[0]
|
||||
nsr_mean_abs = np.mean(np.abs(grad_dict[modname]), axis=0)
|
||||
tmpsum = np.sum(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx])
|
||||
if tmpsum == 0:
|
||||
pass
|
||||
else:
|
||||
nsr_mean_sum_abs += np.log(tmpsum)
|
||||
nsr_mean_avg_abs += np.log(np.mean(nsr_mean_abs[nonzero_idx] / nsr_std[nonzero_idx]))
|
||||
return nsr_mean_sum_abs
|
||||
|
||||
|
||||
def getzico(network, inputs, targets, loss_fn, split_data=2):
|
||||
grad_dict = {}
|
||||
network.train()
|
||||
device = inputs.device
|
||||
network.to(device)
|
||||
N = inputs.shape[0]
|
||||
|
||||
|
||||
for sp in range(split_data):
|
||||
st = sp * N // split_data
|
||||
en = (sp + 1) * N // split_data
|
||||
outputs = network.forward(inputs[st:en])
|
||||
loss = loss_fn(outputs, targets[st:en])
|
||||
loss.backward()
|
||||
grad_dict = getgrad(network, grad_dict, sp)
|
||||
# print(grad_dict)
|
||||
    res = calculate_zico(grad_dict)
|
||||
return res
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@measure('zico', bn=True)
|
||||
def compute_zico(net, inputs, targets, split_data=2, loss_fn=None):
|
||||
|
||||
# Compute gradients (but don't apply them)
|
||||
net.zero_grad()
|
||||
|
||||
# print('var:', feature.shape)
|
||||
try:
|
||||
zico = getzico(net, inputs, targets, loss_fn, split_data=split_data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
zico= np.nan
|
||||
# print(jc)
|
||||
# print(f'var time: {t} s')
|
||||
return zico
|
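Concretely, the statistic above is a per-layer mean(|g|)/std(g) ratio summed over weights and log-accumulated across layers, using gradients collected over the split_data mini-batch halves. A synthetic check (illustrative only; assumes this commit's package layout and made-up layer names):

import numpy as np
from foresight.pruners.measures.zico import calculate_zico

rng = np.random.default_rng(0)
grad_dict = {
    'conv1': [rng.standard_normal(27), rng.standard_normal(27)],  # two gradient snapshots per layer
    'fc':    [rng.standard_normal(640), rng.standard_normal(640)],
}
print(calculate_zico(grad_dict))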
83
zero-cost-nas/foresight/pruners/p_utils.py
Normal file
83
zero-cost-nas/foresight/pruners/p_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..models import *
|
||||
|
||||
def get_some_data(train_dataloader, num_batches, device):
|
||||
traindata = []
|
||||
dataloader_iter = iter(train_dataloader)
|
||||
for _ in range(num_batches):
|
||||
traindata.append(next(dataloader_iter))
|
||||
inputs = torch.cat([a for a,_ in traindata])
|
||||
targets = torch.cat([b for _,b in traindata])
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
return inputs, targets
|
||||
|
||||
def get_some_data_grasp(train_dataloader, num_classes, samples_per_class, device):
|
||||
datas = [[] for _ in range(num_classes)]
|
||||
labels = [[] for _ in range(num_classes)]
|
||||
mark = dict()
|
||||
dataloader_iter = iter(train_dataloader)
|
||||
while True:
|
||||
inputs, targets = next(dataloader_iter)
|
||||
for idx in range(inputs.shape[0]):
|
||||
x, y = inputs[idx:idx+1], targets[idx:idx+1]
|
||||
category = y.item()
|
||||
if len(datas[category]) == samples_per_class:
|
||||
mark[category] = True
|
||||
continue
|
||||
datas[category].append(x)
|
||||
labels[category].append(y)
|
||||
if len(mark) == num_classes:
|
||||
break
|
||||
|
||||
x = torch.cat([torch.cat(_, 0) for _ in datas]).to(device)
|
||||
y = torch.cat([torch.cat(_) for _ in labels]).view(-1).to(device)
|
||||
return x, y
|
||||
|
||||
def get_layer_metric_array(net, metric, mode):
|
||||
metric_array = []
|
||||
|
||||
for layer in net.modules():
|
||||
if mode=='channel' and hasattr(layer,'dont_ch_prune'):
|
||||
continue
|
||||
if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
|
||||
metric_array.append(metric(layer))
|
||||
|
||||
return metric_array
|
||||
|
||||
def reshape_elements(elements, shapes, device):
|
||||
def broadcast_val(elements, shapes):
|
||||
ret_grads = []
|
||||
for e,sh in zip(elements, shapes):
|
||||
ret_grads.append(torch.stack([torch.Tensor(sh).fill_(v) for v in e], dim=0).to(device))
|
||||
return ret_grads
|
||||
if type(elements[0]) == list:
|
||||
outer = []
|
||||
for e,sh in zip(elements, shapes):
|
||||
outer.append(broadcast_val(e,sh))
|
||||
return outer
|
||||
else:
|
||||
return broadcast_val(elements, shapes)
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
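A short usage sketch for the two helpers that the channel-wise measures rely on: get_layer_metric_array collects one value per Conv2d/Linear layer, and reshape_elements broadcasts per-channel values back to per-parameter shape, as fisher.py does. Illustrative only; the toy network is made up and this commit's package layout is assumed:

import torch
import torch.nn as nn
from foresight.pruners.p_utils import get_layer_metric_array, reshape_elements

net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.Flatten(), nn.Linear(4 * 30 * 30, 10))

# one scalar per output channel of every Conv2d/Linear layer
per_channel = get_layer_metric_array(
    net, lambda l: l.weight.detach().abs().mean(dim=tuple(range(1, l.weight.dim()))), mode='channel')
# broadcast each channel value over the rest of that layer's weight shape
shapes = get_layer_metric_array(net, lambda l: l.weight.shape[1:], mode='channel')
per_param = reshape_elements(per_channel, shapes, torch.device('cpu'))
print([t.shape for t in per_param])  # matches each weight shape: (4, 3, 3, 3) and (10, 3600)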
116
zero-cost-nas/foresight/pruners/predictive.py
Normal file
116
zero-cost-nas/foresight/pruners/predictive.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .p_utils import *
|
||||
from . import measures
|
||||
|
||||
import types
|
||||
import copy
|
||||
|
||||
|
||||
def no_op(self,x):
|
||||
return x
|
||||
|
||||
def copynet(self, bn):
|
||||
net = copy.deepcopy(self)
|
||||
if bn==False:
|
||||
for l in net.modules():
|
||||
if isinstance(l,nn.BatchNorm2d) or isinstance(l,nn.BatchNorm1d) :
|
||||
l.forward = types.MethodType(no_op, l)
|
||||
return net
|
||||
|
||||
def find_measures_arrays(net_orig, trainloader, dataload_info, device, measure_names=None, loss_fn=F.cross_entropy):
|
||||
if measure_names is None:
|
||||
measure_names = measures.available_measures
|
||||
|
||||
dataload, num_imgs_or_batches, num_classes = dataload_info
|
||||
|
||||
if not hasattr(net_orig,'get_prunable_copy'):
|
||||
net_orig.get_prunable_copy = types.MethodType(copynet, net_orig)
|
||||
|
||||
#move to cpu to free up mem
|
||||
torch.cuda.empty_cache()
|
||||
net_orig = net_orig.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
#given 1 minibatch of data
|
||||
if dataload == 'random':
|
||||
inputs, targets = get_some_data(trainloader, num_batches=num_imgs_or_batches, device=device)
|
||||
elif dataload == 'grasp':
|
||||
inputs, targets = get_some_data_grasp(trainloader, num_classes, samples_per_class=num_imgs_or_batches, device=device)
|
||||
else:
|
||||
raise NotImplementedError(f'dataload {dataload} is not supported')
|
||||
|
||||
done, ds = False, 1
|
||||
measure_values = {}
|
||||
|
||||
while not done:
|
||||
try:
|
||||
for measure_name in measure_names:
|
||||
if measure_name not in measure_values:
|
||||
val = measures.calc_measure(measure_name, net_orig, device, inputs, targets, loss_fn=loss_fn, split_data=ds)
|
||||
measure_values[measure_name] = val
|
||||
|
||||
done = True
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
done=False
|
||||
if ds == inputs.shape[0]//2:
|
||||
raise ValueError(f'Can\'t split data anymore, but still unable to run. Something is wrong')
|
||||
ds += 1
|
||||
while inputs.shape[0] % ds != 0:
|
||||
ds += 1
|
||||
torch.cuda.empty_cache()
|
||||
print(f'Caught CUDA OOM, retrying with data split into {ds} parts')
|
||||
else:
|
||||
raise e
|
||||
|
||||
net_orig = net_orig.to(device).train()
|
||||
return measure_values
|
||||
|
||||
def find_measures(net_orig, # neural network
|
||||
dataloader, # a data loader (typically for training data)
|
||||
dataload_info, # a tuple with (dataload_type = {random, grasp}, number_of_batches_for_random_or_images_per_class_for_grasp, number of classes)
|
||||
device, # GPU/CPU device used
|
||||
loss_fn=F.cross_entropy, # loss function to use within the zero-cost metrics
|
||||
measure_names=None, # an array of measure names to compute, if left blank, all measures are computed by default
|
||||
measures_arr=None): # [not used] if the measures are already computed but need to be summarized, pass them here
|
||||
|
||||
#Given a neural net
|
||||
#and some information about the input data (dataloader)
|
||||
#and loss function (loss_fn)
|
||||
#this function returns an array of zero-cost proxy metrics.
|
||||
|
||||
    def sum_arr(arr):
        total = 0.
        for i in range(len(arr)):
            total += torch.sum(arr[i])
        return total.item()
|
||||
|
||||
if measures_arr is None:
|
||||
measures_arr = find_measures_arrays(net_orig, dataloader, dataload_info, device, loss_fn=loss_fn, measure_names=measure_names)
|
||||
|
||||
measures = {}
|
||||
for k,v in measures_arr.items():
|
||||
if k in ['jacob_cov', 'var', 'cor', 'norm', 'meco', 'zico']:
|
||||
measures[k] = v
|
||||
else:
|
||||
measures[k] = sum_arr(v)
|
||||
|
||||
return measures
|
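Putting it together, a typical call into this module looks like the sketch below (illustrative only: the toy model is made up, the CIFAR-10 loaders from foresight.dataset will download data to ./_dataset, and get_cifar_dataloaders is assumed to return (train_loader, test_loader)):

import torch
import torch.nn as nn
from foresight.dataset import get_cifar_dataloaders
from foresight.pruners.predictive import find_measures

train_loader, _ = get_cifar_dataloaders(64, 64, 'cifar10', num_workers=2)
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataload_info = (sampling scheme, batches for 'random' / images per class for 'grasp', num classes)
scores = find_measures(net, train_loader, ('random', 1, 10), device,
                       measure_names=['synflow', 'snip'])
print(scores)  # one aggregated scalar per requested measure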
51
zero-cost-nas/foresight/version.py
Normal file
51
zero-cost-nas/foresight/version.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
version = '1.0.0'
|
||||
repo = 'unknown'
|
||||
commit = 'unknown'
|
||||
has_repo = False
|
||||
|
||||
try:
|
||||
import git
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
r = git.Repo(Path(__file__).parents[1])
|
||||
has_repo = True
|
||||
|
||||
if not r.remotes:
|
||||
repo = 'local'
|
||||
else:
|
||||
repo = r.remotes.origin.url
|
||||
|
||||
commit = r.head.commit.hexsha
|
||||
if r.is_dirty():
|
||||
commit += ' (dirty)'
|
||||
except git.InvalidGitRepositoryError:
|
||||
raise ImportError()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from . import _dist_info as info
|
||||
assert not has_repo, '_dist_info should not exist when repo is in place'
|
||||
assert version == info.version
|
||||
repo = info.repo
|
||||
commit = info.commit
|
||||
except (ImportError, SystemError):
|
||||
pass
|
||||
|
||||
__all__ = ['version', 'repo', 'commit', 'has_repo']
|
68
zero-cost-nas/foresight/weight_initializers.py
Normal file
68
zero-cost-nas/foresight/weight_initializers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2021 Samsung Electronics Co., Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
def init_net(net, w_type, b_type):
|
||||
if w_type == 'none':
|
||||
pass
|
||||
elif w_type == 'xavier':
|
||||
net.apply(init_weights_vs)
|
||||
elif w_type == 'kaiming':
|
||||
net.apply(init_weights_he)
|
||||
elif w_type == 'zero':
|
||||
net.apply(init_weights_zero)
|
||||
else:
|
||||
raise NotImplementedError(f'init_type={w_type} is not supported.')
|
||||
|
||||
if b_type == 'none':
|
||||
pass
|
||||
elif b_type == 'xavier':
|
||||
net.apply(init_bias_vs)
|
||||
elif b_type == 'kaiming':
|
||||
net.apply(init_bias_he)
|
||||
elif b_type == 'zero':
|
||||
net.apply(init_bias_zero)
|
||||
else:
|
||||
raise NotImplementedError(f'init_type={b_type} is not supported.')
|
||||
|
||||
def init_weights_vs(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
|
||||
def init_bias_vs(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
if m.bias is not None:
|
||||
            # NOTE: xavier_normal_ needs a tensor with at least 2 dims; calling it on a 1-D bias raises at runtime
            nn.init.xavier_normal_(m.bias)
|
||||
|
||||
def init_weights_he(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
nn.init.kaiming_normal_(m.weight)
|
||||
|
||||
def init_bias_he(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
if m.bias is not None:
|
||||
            # NOTE: kaiming_normal_ needs a tensor with at least 2 dims; calling it on a 1-D bias raises at runtime
            nn.init.kaiming_normal_(m.bias)
|
||||
|
||||
def init_weights_zero(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
m.weight.data.fill_(.0)
|
||||
|
||||
def init_bias_zero(m):
|
||||
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(.0)
|
||||
|
||||
|
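A minimal usage sketch for init_net above (illustrative only; the toy model is made up and the package is assumed importable as foresight). The 'zero' bias option also sidesteps the 1-D bias caveat flagged in the comments above:

import torch.nn as nn
from foresight.weight_initializers import init_net

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
init_net(net, w_type='kaiming', b_type='zero')  # He-normal weights, biases set to zero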