Jack Turner
2021-02-26 16:12:51 +00:00
parent c895924c99
commit b74255e1f3
74 changed files with 11326 additions and 537 deletions

nas_101_api/__init__.py (new file, 1 line)

@@ -0,0 +1 @@

nas_101_api/base_ops.py (new file, 65 lines)

@@ -0,0 +1,65 @@
"""Base operations used by the modules in this search space."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(ConvBnRelu, self).__init__()
self.conv_bn_relu = nn.Sequential(
#nn.ReLU(),
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
#nn.ReLU(inplace=True)
nn.ReLU()
)
def forward(self, x):
return self.conv_bn_relu(x)
class Conv3x3BnRelu(nn.Module):
"""3x3 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels):
super(Conv3x3BnRelu, self).__init__()
self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
def forward(self, x):
x = self.conv3x3(x)
return x
class Conv1x1BnRelu(nn.Module):
"""1x1 convolution with batch norm and ReLU activation."""
def __init__(self, in_channels, out_channels):
super(Conv1x1BnRelu, self).__init__()
self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0)
def forward(self, x):
x = self.conv1x1(x)
return x
class MaxPool3x3(nn.Module):
"""3x3 max pool with no subsampling."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(MaxPool3x3, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)
#self.maxpool = nn.AvgPool2d(kernel_size, stride, padding)
def forward(self, x):
x = self.maxpool(x)
return x
# Commas should not be used in op names
OP_MAP = {
'conv3x3-bn-relu': Conv3x3BnRelu,
'conv1x1-bn-relu': Conv1x1BnRelu,
'maxpool3x3': MaxPool3x3
}
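
# A minimal usage sketch (not part of the original commit): every OP_MAP entry
# takes (in_channels, out_channels) and, at stride 1, preserves spatial size.
if __name__ == '__main__':
    x = torch.randn(1, 16, 32, 32)
    for name, op_cls in OP_MAP.items():
        op = op_cls(16, 16)
        print(name, tuple(op(x).shape))  # each prints (1, 16, 32, 32)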

nas_101_api/graph_util.py (new file, 167 lines)

@@ -0,0 +1,167 @@
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by generate_graph.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import itertools
import numpy as np

def gen_is_edge_fn(bits):
    """Generate a boolean function for the edge connectivity.

    Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
      [[0, A, B, D],
       [0, 0, C, E],
       [0, 0, 0, F],
       [0, 0, 0, 0]]

    Note that this function is agnostic to the actual matrix dimension due to
    the order in which elements are filled out (column-major, starting from the
    least significant bit). For example, the same FEDCBA bitstring (0-padded)
    on a 5x5 matrix is
      [[0, A, B, D, 0],
       [0, 0, C, E, 0],
       [0, 0, 0, F, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]]

    Args:
      bits: integer which will be interpreted as a bit mask.

    Returns:
      vectorized function that returns True when an edge is present.
    """
    def is_edge(x, y):
        """Is there an edge from x to y (0-indexed)?"""
        if x >= y:
            return 0
        # Map x, y to an index into the bit string.
        index = x + (y * (y - 1) // 2)
        return (bits >> index) % 2 == 1

    return np.vectorize(is_edge)
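
# Example (illustrative): bits 0b11 sets bit A (edge 0->1) and bit B (edge
# 0->2) in the layout shown above.
#   is_edge = gen_is_edge_fn(0b11)
#   np.fromfunction(is_edge, (4, 4), dtype=np.int8)
#   # -> [[0, 1, 1, 0],
#   #     [0, 0, 0, 0],
#   #     [0, 0, 0, 0],
#   #     [0, 0, 0, 0]]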


def is_full_dag(matrix):
    """Full DAG == all vertices on a path from vert 0 to (V-1).

    i.e. no disconnected or "hanging" vertices.

    It is sufficient to check for:
      1) no rows of 0 except for row V-1 (only the output vertex has no out-edges)
      2) no cols of 0 except for col 0 (only the input vertex has no in-edges)

    Args:
      matrix: V x V upper-triangular adjacency matrix

    Returns:
      True if there are no dangling vertices.
    """
    shape = np.shape(matrix)

    rows = matrix[:shape[0]-1, :] == 0
    rows = np.all(rows, axis=1)  # Any row with all 0 will be True
    rows_bad = np.any(rows)

    cols = matrix[:, 1:] == 0
    cols = np.all(cols, axis=0)  # Any col with all 0 will be True
    cols_bad = np.any(cols)

    return (not rows_bad) and (not cols_bad)
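
# Example (illustrative): in the second matrix, vertex 1 has no out-edge, so
# it is a hanging vertex and the check fails.
#   is_full_dag(np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]]))  # True
#   is_full_dag(np.array([[0, 1, 1], [0, 0, 0], [0, 0, 0]]))  # False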


def num_edges(matrix):
    """Computes number of edges in adjacency matrix."""
    return np.sum(matrix)


def hash_module(matrix, labeling):
    """Computes a graph-invariant MD5 hash of the matrix and label pair.

    Args:
      matrix: np.ndarray square upper-triangular adjacency matrix.
      labeling: list of int labels of length equal to both dimensions of
        matrix.

    Returns:
      MD5 hash of the matrix and labeling.
    """
    vertices = np.shape(matrix)[0]
    in_edges = np.sum(matrix, axis=0).tolist()
    out_edges = np.sum(matrix, axis=1).tolist()

    assert len(in_edges) == len(out_edges) == len(labeling)
    hashes = list(zip(out_edges, in_edges, labeling))
    hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
    # Computing this up to the diameter is probably sufficient but since the
    # operation is fast, it is okay to repeat more times.
    for _ in range(vertices):
        new_hashes = []
        for v in range(vertices):
            in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
            out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
            new_hashes.append(hashlib.md5(
                (''.join(sorted(in_neighbors)) + '|' +
                 ''.join(sorted(out_neighbors)) + '|' +
                 hashes[v]).encode('utf-8')).hexdigest())
        hashes = new_hashes
    fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()

    return fingerprint
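
# Example (illustrative): the two chains below are the same labeled graph with
# interior vertices 1 and 2 swapped, so their fingerprints match.
#   m1 = np.array([[0, 1, 0, 0],
#                  [0, 0, 1, 0],
#                  [0, 0, 0, 1],
#                  [0, 0, 0, 0]])
#   m2 = np.array([[0, 0, 1, 0],
#                  [0, 0, 0, 1],
#                  [0, 1, 0, 0],
#                  [0, 0, 0, 0]])
#   hash_module(m1, [-1, 0, 1, -2]) == hash_module(m2, [-1, 1, 0, -2])  # True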


def permute_graph(graph, label, permutation):
    """Permutes the graph and labels based on permutation.

    Args:
      graph: np.ndarray adjacency matrix.
      label: list of labels of same length as graph dimensions.
      permutation: a permutation list of ints of same length as graph dimensions.

    Returns:
      np.ndarray where vertex permutation[v] is vertex v from the original graph
    """
    # vertex permutation[v] in the new graph is vertex v in the old graph
    forward_perm = zip(permutation, list(range(len(permutation))))
    inverse_perm = [x[1] for x in sorted(forward_perm)]
    edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1
    new_matrix = np.fromfunction(np.vectorize(edge_fn),
                                 (len(label), len(label)),
                                 dtype=np.int8)
    new_label = [label[inverse_perm[i]] for i in range(len(label))]
    return new_matrix, new_label


def is_isomorphic(graph1, graph2):
    """Exhaustively checks if 2 graphs are isomorphic."""
    matrix1, label1 = np.array(graph1[0]), graph1[1]
    matrix2, label2 = np.array(graph2[0]), graph2[1]
    assert np.shape(matrix1) == np.shape(matrix2)
    assert len(label1) == len(label2)

    vertices = np.shape(matrix1)[0]
    # Note: input and output in our constrained graphs always map to themselves
    # but this script does not enforce that.
    for perm in itertools.permutations(range(0, vertices)):
        pmatrix1, plabel1 = permute_graph(matrix1, label1, perm)
        if np.array_equal(pmatrix1, matrix2) and plabel1 == label2:
            return True

    return False
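
# Example (illustrative, reusing m1/m2 from the hash_module note above): the
# brute-force permutation check agrees with the hash.
#   is_isomorphic((m1, [-1, 0, 1, -2]), (m2, [-1, 1, 0, -2]))  # True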

nas_101_api/model.py (new file, 252 lines)

@@ -0,0 +1,252 @@
"""Builds the Pytorch computational graph.
Tensors flowing into a single vertex are added together for all vertices
except the output, which is concatenated instead. Tensors flowing out of input
are always added.
If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing). Tensors flowing out of the input as always
projected instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import math
from .base_ops import *
import torch
import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self, spec, args, searchspace=[]):
        super(Network, self).__init__()
        self.spec = spec
        self.layers = nn.ModuleList([])

        in_channels = 3
        out_channels = args.stem_out_channels

        # initial stem convolution
        stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1)
        self.layers.append(stem_conv)

        in_channels = out_channels
        for stack_num in range(args.num_stacks):
            if stack_num > 0:
                # downsample = nn.MaxPool2d(kernel_size=3, stride=2)
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                # downsample = nn.AvgPool2d(kernel_size=2, stride=2)
                # downsample = nn.Conv2d(in_channels, out_channels, kernel_size=(2, 2), stride=2)
                self.layers.append(downsample)

                out_channels *= 2

            for module_num in range(args.num_modules_per_stack):
                cell = Cell(spec, in_channels, out_channels)
                self.layers.append(cell)
                in_channels = out_channels

        self.classifier = nn.Linear(out_channels, args.num_labels)

        # for DARTS search
        num_edge = np.shape(spec.matrix)[0]
        self.arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(searchspace)))

        self._initialize_weights()

    def forward(self, x, get_ints=True):
        ints = []
        for _, layer in enumerate(self.layers):
            x = layer(x)
            ints.append(x)
        out = torch.mean(x, (2, 3))
        ints.append(out)
        out = self.classifier(out)
        if get_ints:
            return out, ints[-1]
        else:
            return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def get_weights(self):
        xlist = []
        for m in self.modules():
            xlist.append(m.parameters())
        return xlist

    def get_alphas(self):
        return [self.arch_parameters]

    def genotype(self):
        # `spec` was previously referenced here as a bare, undefined name;
        # return the spec stored in __init__ instead.
        return str(self.spec)
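
# A hypothetical usage sketch (the `args` fields are inferred from how
# __init__ reads them; the concrete values below are assumptions, not part of
# this commit):
#   from types import SimpleNamespace
#   from .model_spec import ModelSpec
#   spec = ModelSpec([[0, 1, 1], [0, 0, 1], [0, 0, 0]],
#                    ['input', 'conv3x3-bn-relu', 'output'])
#   args = SimpleNamespace(stem_out_channels=128, num_stacks=3,
#                          num_modules_per_stack=3, num_labels=10)
#   net = Network(spec, args)
#   logits, features = net(torch.randn(2, 3, 32, 32))  # get_ints defaults to True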


class Cell(nn.Module):
    """
    Builds the model using the adjacency matrix and op labels specified. Channels
    controls the module output channel count but the interior channels are
    determined via equally splitting the channel count whenever there is a
    concatenation of Tensors.
    """
    def __init__(self, spec, in_channels, out_channels):
        super(Cell, self).__init__()

        self.spec = spec
        self.num_vertices = np.shape(self.spec.matrix)[0]

        # vertex_channels[i] = number of output channels of vertex i
        self.vertex_channels = ComputeVertexChannels(in_channels, out_channels, self.spec.matrix)
        # self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)

        # operation for each node
        self.vertex_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices - 1):
            op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t])
            self.vertex_op.append(op)

        # operation for input on each vertex
        self.input_op = nn.ModuleList([None])
        for t in range(1, self.num_vertices):
            if self.spec.matrix[0, t]:
                self.input_op.append(Projection(in_channels, self.vertex_channels[t]))
            else:
                self.input_op.append(None)

    def forward(self, x):
        tensors = [x]

        out_concat = []
        for t in range(1, self.num_vertices - 1):
            fan_in = [Truncate(tensors[src], self.vertex_channels[t])
                      for src in range(1, t) if self.spec.matrix[src, t]]
            fan_in_inds = [src for src in range(1, t) if self.spec.matrix[src, t]]

            if self.spec.matrix[0, t]:
                fan_in.append(self.input_op[t](x))
                fan_in_inds = [0] + fan_in_inds

            # perform operation on node
            # vertex_input = torch.stack(fan_in, dim=0).sum(dim=0)
            vertex_input = sum(fan_in)
            # vertex_input = sum(fan_in) / len(fan_in)
            vertex_output = self.vertex_op[t](vertex_input)

            tensors.append(vertex_output)
            if self.spec.matrix[t, self.num_vertices - 1]:
                out_concat.append(tensors[t])

        if not out_concat:  # empty list
            assert self.spec.matrix[0, self.num_vertices - 1]
            outputs = self.input_op[self.num_vertices - 1](tensors[0])
        else:
            if len(out_concat) == 1:
                outputs = out_concat[0]
            else:
                outputs = torch.cat(out_concat, 1)

            if self.spec.matrix[0, self.num_vertices - 1]:
                outputs += self.input_op[self.num_vertices - 1](tensors[0])

            # if self.spec.matrix[0, self.num_vertices-1]:
            #     out_concat.append(self.input_op[self.num_vertices-1](tensors[0]))
            # outputs = sum(out_concat) / len(out_concat)

        return outputs


def Projection(in_channels, out_channels):
    """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
    return ConvBnRelu(in_channels, out_channels, 1)


def Truncate(inputs, channels):
    """Slice the inputs to channels if necessary."""
    input_channels = inputs.size()[1]
    if input_channels < channels:
        raise ValueError('input channel < output channels for truncate')
    elif input_channels == channels:
        return inputs  # No truncation necessary
    else:
        # Truncation should only be necessary when channel division leads to
        # vertices with +1 channels. The input vertex should always be projected
        # to the minimum channel count.
        assert input_channels - channels == 1
        return inputs[:, :channels, :, :]
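
# Example (illustrative): Truncate keeps only the first `channels` feature
# maps, and only ever needs to drop a single surplus channel.
#   t = torch.randn(1, 5, 8, 8)
#   Truncate(t, 4).shape  # -> torch.Size([1, 4, 8, 8])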


def ComputeVertexChannels(in_channels, out_channels, matrix):
    """Computes the number of channels at every vertex.

    Given the input channels and output channels, this calculates the number of
    channels at each interior vertex. Interior vertices have the same number of
    channels as the max of the channels of the vertices they feed into. The
    output channels are divided amongst the vertices that are directly connected
    to it. When the division is not even, some vertices may receive an extra
    channel to compensate.

    Returns:
      list of channel counts, in order of the vertices.
    """
    num_vertices = np.shape(matrix)[0]

    vertex_channels = [0] * num_vertices
    vertex_channels[0] = in_channels
    vertex_channels[num_vertices - 1] = out_channels

    if num_vertices == 2:
        # Edge case where module only has input and output vertices
        return vertex_channels

    # Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1
    # is the dst vertex. Summing over 0 gives the in-degree count of each vertex.
    in_degree = np.sum(matrix[1:], axis=0)
    interior_channels = out_channels // in_degree[num_vertices - 1]
    correction = out_channels % in_degree[num_vertices - 1]  # Remainder to add

    # Set channels of vertices that flow directly to output
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            vertex_channels[v] = interior_channels
            if correction:
                vertex_channels[v] += 1
                correction -= 1

    # Set channels for all other vertices to the max of the out edges, going
    # backwards. (num_vertices - 2) index skipped because it only connects to
    # output.
    for v in range(num_vertices - 3, 0, -1):
        if not matrix[v, num_vertices - 1]:
            for dst in range(v + 1, num_vertices - 1):
                if matrix[v, dst]:
                    vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
        assert vertex_channels[v] > 0

    # Sanity check, verify that channels never increase and final channels add up.
    final_fan_in = 0
    for v in range(1, num_vertices - 1):
        if matrix[v, num_vertices - 1]:
            final_fan_in += vertex_channels[v]
        for dst in range(v + 1, num_vertices - 1):
            if matrix[v, dst]:
                assert vertex_channels[v] >= vertex_channels[dst]
    assert final_fan_in == out_channels or num_vertices == 2
    # num_vertices == 2 means only input/output nodes, so 0 fan-in

    return vertex_channels
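
# Worked example (illustrative): with in_channels=8, out_channels=17 and a
# 4-vertex module in which vertices 1 and 2 both feed the output, the 17
# output channels split as 9 + 8 (the remainder goes to the lower-indexed
# vertex):
#   matrix = np.array([[0, 1, 1, 0],
#                      [0, 0, 0, 1],
#                      [0, 0, 0, 1],
#                      [0, 0, 0, 0]])
#   ComputeVertexChannels(8, 17, matrix)  # -> [8, 9, 8, 17]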

nas_101_api/model_spec.py (new file, 152 lines)

@@ -0,0 +1,152 @@
"""Model specification for module connectivity individuals.
This module handles pruning the unused parts of the computation graph but should
avoid creating any TensorFlow models (this is done inside model_builder.py).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from . import graph_util
# Graphviz is optional and only required for visualization.
try:
import graphviz # pylint: disable=g-import-not-at-top
except ImportError:
pass


class ModelSpec(object):
    """Model specification given adjacency matrix and labeling."""

    def __init__(self, matrix, ops, data_format='channels_last'):
        """Initialize the module spec.

        Args:
          matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
          ops: V-length list of labels for the base ops used. The first and last
            elements are ignored because they are the input and output vertices
            which have no operations. The elements are retained to keep consistent
            indexing.
          data_format: channels_last or channels_first.

        Raises:
          ValueError: invalid matrix or ops
        """
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)
        shape = np.shape(matrix)
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError('matrix must be square')
        if shape[0] != len(ops):
            raise ValueError('length of ops must match matrix dimensions')
        if not is_upper_triangular(matrix):
            raise ValueError('matrix must be upper triangular')

        # Both the original and pruned matrices are deep copies of the matrix and
        # ops so any changes to those after initialization are not recognized by
        # the spec.
        self.original_matrix = copy.deepcopy(matrix)
        self.original_ops = copy.deepcopy(ops)

        self.matrix = copy.deepcopy(matrix)
        self.ops = copy.deepcopy(ops)
        self.valid_spec = True
        self._prune()

        self.data_format = data_format

    def _prune(self):
        """Prune the extraneous parts of the graph.

        General procedure:
          1) Remove parts of graph not connected to input.
          2) Remove parts of graph not connected to output.
          3) Reorder the vertices so that they are consecutive after steps 1 and 2.

        These 3 steps can be combined by deleting the rows and columns of the
        vertices that are not reachable from both the input and output (in
        reverse).
        """
        num_vertices = np.shape(self.original_matrix)[0]

        # DFS forward from input
        visited_from_input = set([0])
        frontier = [0]
        while frontier:
            top = frontier.pop()
            for v in range(top + 1, num_vertices):
                if self.original_matrix[top, v] and v not in visited_from_input:
                    visited_from_input.add(v)
                    frontier.append(v)

        # DFS backward from output
        visited_from_output = set([num_vertices - 1])
        frontier = [num_vertices - 1]
        while frontier:
            top = frontier.pop()
            for v in range(0, top):
                if self.original_matrix[v, top] and v not in visited_from_output:
                    visited_from_output.add(v)
                    frontier.append(v)

        # Any vertex that isn't connected to both input and output is extraneous
        # to the computation graph.
        extraneous = set(range(num_vertices)).difference(
            visited_from_input.intersection(visited_from_output))

        # If the non-extraneous graph has fewer than 2 vertices, the input is not
        # connected to the output and the spec is invalid.
        if len(extraneous) > num_vertices - 2:
            self.matrix = None
            self.ops = None
            self.valid_spec = False
            return

        self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
        self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
        for index in sorted(extraneous, reverse=True):
            del self.ops[index]
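
    # Example (illustrative): vertex 1 below has an in-edge from the input but
    # no path to the output, so _prune deletes it and only input -> 2 -> output
    # remains.
    #   spec = ModelSpec([[0, 1, 1, 0],
    #                     [0, 0, 0, 0],
    #                     [0, 0, 0, 1],
    #                     [0, 0, 0, 0]],
    #                    ['input', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'output'])
    #   spec.matrix  # -> [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
    #   spec.ops     # -> ['input', 'conv1x1-bn-relu', 'output']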

    def hash_spec(self, canonical_ops):
        """Computes the isomorphism-invariant graph hash of this spec.

        Args:
          canonical_ops: list of operations in the canonical ordering which they
            were assigned (i.e. the order provided in the config['available_ops']).

        Returns:
          MD5 hash of this spec which can be used to query the dataset.
        """
        # Invert the operations back to integer label indices used in graph gen.
        labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
        return graph_util.hash_module(self.matrix, labeling)

    def visualize(self):
        """Creates a dot graph. Can be visualized in colab directly."""
        num_vertices = np.shape(self.matrix)[0]
        g = graphviz.Digraph()
        g.node(str(0), 'input')
        for v in range(1, num_vertices - 1):
            g.node(str(v), self.ops[v])
        g.node(str(num_vertices - 1), 'output')

        for src in range(num_vertices - 1):
            for dst in range(src + 1, num_vertices):
                if self.matrix[src, dst]:
                    g.edge(str(src), str(dst))

        return g


def is_upper_triangular(matrix):
    """True if matrix is 0 on diagonal and below."""
    for src in range(np.shape(matrix)[0]):
        for dst in range(0, src + 1):
            if matrix[src, dst] != 0:
                return False
    return True
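
# A minimal end-to-end sketch (illustrative; the op names match OP_MAP in
# base_ops.py). Run with `python -m nas_101_api.model_spec` so the relative
# imports above resolve.
if __name__ == '__main__':
    spec = ModelSpec([[0, 1, 0, 0],
                      [0, 0, 1, 0],
                      [0, 0, 0, 1],
                      [0, 0, 0, 0]],
                     ['input', 'conv3x3-bn-relu', 'maxpool3x3', 'output'])
    print(spec.valid_spec)  # True: every vertex lies on an input->output path
    print(spec.hash_spec(['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']))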