add naswot
This commit is contained in:
129
graph_dit/naswot/pycls/core/net.py
Normal file
129
graph_dit/naswot/pycls/core/net.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Functions for manipulating networks."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pycls.core.config import cfg
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
"""Performs ResNet-style weight initialization."""
|
||||
if isinstance(m, nn.Conv2d):
|
||||
# Note that there is no bias due to BN
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
|
||||
zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
|
||||
m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.weight.data.normal_(mean=0.0, std=0.01)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_precise_bn_stats(model, loader):
|
||||
"""Computes precise BN stats on training data."""
|
||||
# Compute the number of minibatches to use
|
||||
num_iter = min(cfg.BN.NUM_SAMPLES_PRECISE // loader.batch_size, len(loader))
|
||||
# Retrieve the BN layers
|
||||
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
||||
# Initialize stats storage
|
||||
mus = [torch.zeros_like(bn.running_mean) for bn in bns]
|
||||
sqs = [torch.zeros_like(bn.running_var) for bn in bns]
|
||||
# Remember momentum values
|
||||
moms = [bn.momentum for bn in bns]
|
||||
# Disable momentum
|
||||
for bn in bns:
|
||||
bn.momentum = 1.0
|
||||
# Accumulate the stats across the data samples
|
||||
for inputs, _labels in itertools.islice(loader, num_iter):
|
||||
model(inputs.cuda())
|
||||
# Accumulate the stats for each BN layer
|
||||
for i, bn in enumerate(bns):
|
||||
m, v = bn.running_mean, bn.running_var
|
||||
sqs[i] += (v + m * m) / num_iter
|
||||
mus[i] += m / num_iter
|
||||
# Set the stats and restore momentum values
|
||||
for i, bn in enumerate(bns):
|
||||
bn.running_var = sqs[i] - mus[i] * mus[i]
|
||||
bn.running_mean = mus[i]
|
||||
bn.momentum = moms[i]
|
||||
|
||||
|
||||
def reset_bn_stats(model):
|
||||
"""Resets running BN stats."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, torch.nn.BatchNorm2d):
|
||||
m.reset_running_stats()
|
||||
|
||||
|
||||
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
|
||||
"""Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
h = (h + 2 * padding - k) // stride + 1
|
||||
w = (w + 2 * padding - k) // stride + 1
|
||||
flops += k * k * w_in * w_out * h * w // groups
|
||||
params += k * k * w_in * w_out // groups
|
||||
flops += w_out if bias else 0
|
||||
params += w_out if bias else 0
|
||||
acts += w_out * h * w
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity_batchnorm2d(cx, w_in):
|
||||
"""Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
params += 2 * w_in
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity_maxpool2d(cx, k, stride, padding):
|
||||
"""Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
|
||||
h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
|
||||
h = (h + 2 * padding - k) // stride + 1
|
||||
w = (w + 2 * padding - k) // stride + 1
|
||||
return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||||
|
||||
|
||||
def complexity(model):
|
||||
"""Compute model complexity (model can be model instance or model class)."""
|
||||
size = cfg.TRAIN.IM_SIZE
|
||||
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
|
||||
cx = model.complexity(cx)
|
||||
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}
|
||||
|
||||
|
||||
def drop_connect(x, drop_ratio):
|
||||
"""Drop connect (adapted from DARTS)."""
|
||||
keep_ratio = 1.0 - drop_ratio
|
||||
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
|
||||
mask.bernoulli_(keep_ratio)
|
||||
x.div_(keep_ratio)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
def get_flat_weights(model):
|
||||
"""Gets all model weights as a single flat vector."""
|
||||
return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0)
|
||||
|
||||
|
||||
def set_flat_weights(model, flat_weights):
|
||||
"""Sets all model weights from a single flat vector."""
|
||||
k = 0
|
||||
for p in model.parameters():
|
||||
n = p.data.numel()
|
||||
p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
|
||||
k += n
|
||||
assert k == flat_weights.numel()
|
Reference in New Issue
Block a user