first commit
NAS-Bench-201/models/utils.py (new file, 289 lines added)
@@ -0,0 +1,289 @@
import torch
import torch.nn.functional as F
import sde_lib

_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(
                f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_model(name):
    return _MODELS[name]


def create_model(config):
    """Create the model."""
    model_name = config.model.name
    model = get_model(model_name)(config)
    model = model.to(config.device)
    return model
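
# Usage sketch (illustrative, not part of the original file): how the registry
# above is typically exercised. `MyScoreNet` is a hypothetical model class;
# `config.model.name` and `config.device` are the fields `create_model` reads.
#
#   @register_model(name='my_score_net')
#   class MyScoreNet(torch.nn.Module):
#       def __init__(self, config):
#           super().__init__()
#           ...
#       def forward(self, x, labels, *args, **kwargs):
#           ...
#
#   model = create_model(config)  # looks up 'my_score_net' in _MODELS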


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function.
    """

    def model_fn(x, labels, *args, **kwargs):
        """Compute the output of the score-based model.

        Args:
            x: A mini-batch of input data (adjacency matrices).
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.
            mask: Mask for adjacency matrices (passed through `*args`/`**kwargs`).

        Returns:
            The model output.
        """
        if not train:
            model.eval()
            return model(x, labels, *args, **kwargs)
        else:
            model.train()
            return model(x, labels, *args, **kwargs)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A score function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def score_fn(x, t, *args, **kwargs):
            # Scale the neural network output by the standard deviation and flip the sign.
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
                    labels.long()]

            score = -score / std[:, None, None]
            return score

    elif isinstance(sde, sde_lib.VESDE):
        def score_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            score = model_fn(x, labels, *args, **kwargs)
            return score

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
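
# Usage sketch (illustrative, not part of the original file): obtaining a score
# function for sampling. `vpsde`, `score_model`, `x`, `t`, and `mask` are assumed
# to be constructed elsewhere (e.g. `vpsde = sde_lib.VPSDE(...)`).
#
#   score_fn = get_score_fn(vpsde, score_model, train=False, continuous=True)
#   score = score_fn(x, t, mask)  # approximates the score grad_x log p_t(x)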


def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
                           regress=True, labels='max'):
    logit_fn = get_logit_fn(sde, classifier, train, continuous)

    def classifier_grad_fn(x, t, *args, **kwargs):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            if regress:
                assert labels in ['max', 'min']
                logit = logit_fn(x_in, t, *args, **kwargs)
                if labels == 'max':
                    prob = logit.sum()
                elif labels == 'min':
                    prob = -logit.sum()
            else:
                logit = logit_fn(x_in, t, *args, **kwargs)
                log_prob = F.log_softmax(logit, dim=-1)
                prob = log_prob[range(len(logit)), labels.view(-1)].sum()
            classifier_grad = torch.autograd.grad(prob, x_in)[0]
        return classifier_grad

    return classifier_grad_fn
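
# Usage sketch (illustrative, not part of the original file): classifier guidance.
# In regression mode the gradient pushes samples toward higher (labels='max') or
# lower (labels='min') predicted values; in classification mode `labels` is
# assumed to be a LongTensor of target classes. `sde`, `classifier`, `score_fn`,
# `guidance_scale`, `x`, and `t` are placeholders defined elsewhere.
#
#   grad_fn = get_classifier_grad_fn(sde, classifier, regress=True, labels='max')
#   guided_score = score_fn(x, t) + guidance_scale * grad_fn(x, t)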


def get_logit_fn(sde, classifier, train=False, continuous=False):
    classifier_fn = get_model_fn(classifier, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                logit = classifier_fn(x, labels, *args, **kwargs)
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    elif isinstance(sde, sde_lib.VESDE):
        def logit_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            logit = classifier_fn(x, labels, *args, **kwargs)
            return logit

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return logit_fn


def get_predictor_fn(sde, model, train=False, continuous=False):
    """Wraps the predictor model so that it is queried with the appropriate time-step conditioning.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        model: A predictor model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A predictor function.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous or isinstance(sde, sde_lib.subVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                pred = model_fn(x, labels, *args, **kwargs)
                std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
                    labels.long()]

            # Note: unlike `score_fn`, the raw prediction is returned; `std` is not used here.
            return pred

    elif isinstance(sde, sde_lib.VESDE):
        def predictor_fn(x, t, *args, **kwargs):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            pred = model_fn(x, labels, *args, **kwargs)
            return pred

    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")

    return predictor_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))


@torch.no_grad()
def mask_adj2node(adj_mask):
    """Convert batched adjacency mask matrices to batched node mask matrices.

    Args:
        adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edges.

    Output:
        node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
    """

    batch_size, max_num_nodes, _ = adj_mask.shape

    node_mask = adj_mask[:, 0, :].clone()
    node_mask[:, 0] = 1

    return node_mask
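
# Minimal example (illustrative, not part of the original file): the node mask is
# read off the first row of the adjacency mask, and node 0 is always kept valid.
#
#   adj_mask = torch.ones(2, 4, 4)       # [B=2, N=4, N=4], all edges valid
#   node_mask = mask_adj2node(adj_mask)  # -> shape [2, 4], all ones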


@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
    """Compute k_step random-walk features for a given dense adjacency matrix."""

    rw_list = []
    deg = dense_adj.sum(-1, keepdims=True)
    AD = dense_adj / (deg + 1e-8)
    rw_list.append(AD)

    for _ in range(k_step):
        rw = torch.bmm(rw_list[-1], AD)
        rw_list.append(rw)
    rw_map = torch.stack(rw_list[1:], dim=1)  # [B, k_step, N, N]

    rw_landing = torch.diagonal(
        rw_map, offset=0, dim1=2, dim2=3)  # [B, k_step, N]
    rw_landing = rw_landing.permute(0, 2, 1)  # [B, N, rw_depth]

    # Get the shortest-path-distance indices.
    tmp_rw = rw_map.sort(dim=1)[0]
    spd_ind = (tmp_rw <= 0).sum(dim=1)  # [B, N, N]

    spd_onehot = torch.nn.functional.one_hot(
        spd_ind, num_classes=k_step + 1).to(torch.float)
    spd_onehot = spd_onehot.permute(0, 3, 1, 2)  # [B, k_step, N, N]

    return rw_landing, spd_onehot
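
# Usage sketch (illustrative, not part of the original file): random-walk features
# for a batch of dense adjacency matrices with hypothetical sizes B=8, N=7.
#
#   dense_adj = torch.randint(0, 2, (8, 7, 7)).float()  # random 0/1 adjacency, [B, N, N]
#   rw_landing, spd_onehot = get_rw_feat(k_step=3, dense_adj=dense_adj)
#   # rw_landing: [8, 7, 3] diagonal (return) probabilities after 2..4 steps
#   # spd_onehot: [8, 4, 7, 7] one-hot truncated shortest-path-distance encoding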