add autodl
This commit is contained in:
74
AutoDL-Projects/xautodl/models/clone_weights.py
Normal file
74
AutoDL-Projects/xautodl/models/clone_weights.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def copy_conv(module, init):
|
||||
assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
|
||||
assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
|
||||
new_i, new_o = module.in_channels, module.out_channels
|
||||
module.weight.copy_(init.weight.detach()[:new_o, :new_i])
|
||||
if module.bias is not None:
|
||||
module.bias.copy_(init.bias.detach()[:new_o])
|
||||
|
||||
|
||||
def copy_bn(module, init):
|
||||
assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
|
||||
assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
|
||||
num_features = module.num_features
|
||||
if module.weight is not None:
|
||||
module.weight.copy_(init.weight.detach()[:num_features])
|
||||
if module.bias is not None:
|
||||
module.bias.copy_(init.bias.detach()[:num_features])
|
||||
if module.running_mean is not None:
|
||||
module.running_mean.copy_(init.running_mean.detach()[:num_features])
|
||||
if module.running_var is not None:
|
||||
module.running_var.copy_(init.running_var.detach()[:num_features])
|
||||
|
||||
|
||||
def copy_fc(module, init):
|
||||
assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
|
||||
assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
|
||||
new_i, new_o = module.in_features, module.out_features
|
||||
module.weight.copy_(init.weight.detach()[:new_o, :new_i])
|
||||
if module.bias is not None:
|
||||
module.bias.copy_(init.bias.detach()[:new_o])
|
||||
|
||||
|
||||
def copy_base(module, init):
|
||||
assert type(module).__name__ in [
|
||||
"ConvBNReLU",
|
||||
"Downsample",
|
||||
], "invalid module : {:}".format(module)
|
||||
assert type(init).__name__ in [
|
||||
"ConvBNReLU",
|
||||
"Downsample",
|
||||
], "invalid module : {:}".format(init)
|
||||
if module.conv is not None:
|
||||
copy_conv(module.conv, init.conv)
|
||||
if module.bn is not None:
|
||||
copy_bn(module.bn, init.bn)
|
||||
|
||||
|
||||
def copy_basic(module, init):
|
||||
copy_base(module.conv_a, init.conv_a)
|
||||
copy_base(module.conv_b, init.conv_b)
|
||||
if module.downsample is not None:
|
||||
if init.downsample is not None:
|
||||
copy_base(module.downsample, init.downsample)
|
||||
# else:
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
|
||||
def init_from_model(network, init_model):
|
||||
with torch.no_grad():
|
||||
copy_fc(network.classifier, init_model.classifier)
|
||||
for base, target in zip(init_model.layers, network.layers):
|
||||
assert (
|
||||
type(base).__name__ == type(target).__name__
|
||||
), "invalid type : {:} vs {:}".format(base, target)
|
||||
if type(base).__name__ == "ConvBNReLU":
|
||||
copy_base(target, base)
|
||||
elif type(base).__name__ == "ResNetBasicblock":
|
||||
copy_basic(target, base)
|
||||
else:
|
||||
raise ValueError("unknown type name : {:}".format(type(base).__name__))
|
Reference in New Issue
Block a user