Update xmisc with yaml
@@ -3,63 +3,69 @@ import torch.nn as nn
 class ImageNetHEAD(nn.Sequential):
-  def __init__(self, C, stride=2):
-    super(ImageNetHEAD, self).__init__()
-    self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
-    self.add_module('bn1' , nn.BatchNorm2d(C // 2))
-    self.add_module('relu1', nn.ReLU(inplace=True))
-    self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
-    self.add_module('bn2' , nn.BatchNorm2d(C))
+    def __init__(self, C, stride=2):
+        super(ImageNetHEAD, self).__init__()
+        self.add_module(
+            "conv1",
+            nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
+        )
+        self.add_module("bn1", nn.BatchNorm2d(C // 2))
+        self.add_module("relu1", nn.ReLU(inplace=True))
+        self.add_module(
+            "conv2",
+            nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False),
+        )
+        self.add_module("bn2", nn.BatchNorm2d(C))


 class CifarHEAD(nn.Sequential):
-  def __init__(self, C):
-    super(CifarHEAD, self).__init__()
-    self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
-    self.add_module('bn', nn.BatchNorm2d(C))
+    def __init__(self, C):
+        super(CifarHEAD, self).__init__()
+        self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
+        self.add_module("bn", nn.BatchNorm2d(C))


 class AuxiliaryHeadCIFAR(nn.Module):
-  def __init__(self, C, num_classes):
-    """assuming input size 8x8"""
-    super(AuxiliaryHeadCIFAR, self).__init__()
-    self.features = nn.Sequential(
-      nn.ReLU(inplace=True),
-      nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
-      nn.Conv2d(C, 128, 1, bias=False),
-      nn.BatchNorm2d(128),
-      nn.ReLU(inplace=True),
-      nn.Conv2d(128, 768, 2, bias=False),
-      nn.BatchNorm2d(768),
-      nn.ReLU(inplace=True)
-    )
-    self.classifier = nn.Linear(768, num_classes)
+    def __init__(self, C, num_classes):
+        """assuming input size 8x8"""
+        super(AuxiliaryHeadCIFAR, self).__init__()
+        self.features = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d(
+                5, stride=3, padding=0, count_include_pad=False
+            ),  # image size = 2 x 2
+            nn.Conv2d(C, 128, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 768, 2, bias=False),
+            nn.BatchNorm2d(768),
+            nn.ReLU(inplace=True),
+        )
+        self.classifier = nn.Linear(768, num_classes)

-  def forward(self, x):
-    x = self.features(x)
-    x = self.classifier(x.view(x.size(0),-1))
-    return x
+    def forward(self, x):
+        x = self.features(x)
+        x = self.classifier(x.view(x.size(0), -1))
+        return x


 class AuxiliaryHeadImageNet(nn.Module):
-  def __init__(self, C, num_classes):
-    """assuming input size 14x14"""
-    super(AuxiliaryHeadImageNet, self).__init__()
-    self.features = nn.Sequential(
-      nn.ReLU(inplace=True),
-      nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
-      nn.Conv2d(C, 128, 1, bias=False),
-      nn.BatchNorm2d(128),
-      nn.ReLU(inplace=True),
-      nn.Conv2d(128, 768, 2, bias=False),
-      nn.BatchNorm2d(768),
-      nn.ReLU(inplace=True)
-    )
-    self.classifier = nn.Linear(768, num_classes)
+    def __init__(self, C, num_classes):
+        """assuming input size 14x14"""
+        super(AuxiliaryHeadImageNet, self).__init__()
+        self.features = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
+            nn.Conv2d(C, 128, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 768, 2, bias=False),
+            nn.BatchNorm2d(768),
+            nn.ReLU(inplace=True),
+        )
+        self.classifier = nn.Linear(768, num_classes)

-  def forward(self, x):
-    x = self.features(x)
-    x = self.classifier(x.view(x.size(0),-1))
-    return x
+    def forward(self, x):
+        x = self.features(x)
+        x = self.classifier(x.view(x.size(0), -1))
+        return x
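The hunk above is pure re-formatting (double quotes, 4-space indents, split long calls); behavior is unchanged. As a quick sanity check of the shape comments, a small sketch (C=64, num_classes=10, and the batch size are arbitrary, assumed values):

import torch

head = AuxiliaryHeadCIFAR(C=64, num_classes=10)
x = torch.randn(2, 64, 8, 8)  # the docstring assumes an 8x8 input
# AvgPool2d(5, stride=3): (8 - 5) // 3 + 1 = 2  -> 2x2 feature map, as the comment says
# Conv2d(128, 768, 2) on 2x2 -> 1x1, so the flattened feature size is 768
print(head(x).shape)  # torch.Size([2, 10])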
@@ -1,6 +1,8 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
+######################################################################
+# This folder is deprecated, which is re-organized in "xalgorithms". #
+######################################################################
 from .starts import prepare_seed
 from .starts import prepare_logger
 from .starts import get_machine_info
@@ -47,7 +47,7 @@ class SuperSelfAttention(SuperModule):
         self.v_fc = SuperLinear(input_dim, input_dim, bias=qkv_bias)

         self.attn_drop = SuperDrop(attn_drop or 0.0, [-1, -1, -1, -1], recover=True)
-        if proj_dim is None:
+        if proj_dim is not None:
             self.proj = SuperLinear(input_dim, proj_dim)
             self.proj_drop = SuperDropout(proj_drop or 0.0)
         else:
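The one-line change flips an inverted guard: the projection head should be built only when proj_dim is given. A minimal sketch of the intended pattern, with plain torch.nn stand-ins for the Super* layers (illustrative, not the actual SuperSelfAttention code):

import torch.nn as nn

class ProjectionExample(nn.Module):
    # hypothetical class demonstrating the corrected conditional
    def __init__(self, input_dim, proj_dim=None, proj_drop=0.0):
        super().__init__()
        if proj_dim is not None:  # the fixed condition
            self.proj = nn.Linear(input_dim, proj_dim)
            self.proj_drop = nn.Dropout(proj_drop)
        else:  # no projection requested: fall back to identity
            self.proj = nn.Identity()
            self.proj_drop = nn.Identity()

    def forward(self, x):
        return self.proj_drop(self.proj(x))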
xautodl/xmisc/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+from .module_utils import call_by_dict
+from .module_utils import call_by_yaml
+from .module_utils import nested_call_by_dict
+from .module_utils import nested_call_by_yaml
+from .yaml_utils import load_yaml
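These re-exports let callers import the helpers from the package root, e.g. (hypothetical usage):

from xautodl.xmisc import call_by_yaml, load_yaml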
xautodl/xmisc/module_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
+#####################################################
+from typing import Union, Dict, Text, Any
+import importlib
+
+from .yaml_utils import load_yaml
+
+CLS_FUNC_KEY = "class_or_func"
+KEYS = (CLS_FUNC_KEY, "module_path", "args", "kwargs")
+
+
+def has_key_words(xdict):
+    if not isinstance(xdict, dict):
+        return False
+    key_set = set(KEYS)
+    cur_set = set(xdict.keys())
+    return key_set.intersection(cur_set) == key_set
+
+
+def get_module_by_module_path(module_path):
+    """Load the module from the path."""
+
+    if module_path.endswith(".py"):
+        module_spec = importlib.util.spec_from_file_location("", module_path)
+        module = importlib.util.module_from_spec(module_spec)
+        module_spec.loader.exec_module(module)
+    else:
+        module = importlib.import_module(module_path)
+
+    return module
+
+
+def call_by_dict(config: Dict[Text, Any], *args, **kwargs) -> object:
+    """
+    Get an initialized instance from a config dictionary.
+    Parameters
+    ----------
+    config : a dictionary, such as:
+        {
+            'class_or_func': 'ClassName',
+            'args': list,
+            'kwargs': dict,
+            'module_path': a string indicating the path,
+        }
+    Returns
+    -------
+    object:
+        An initialized object based on the config info.
+    """
+    module = get_module_by_module_path(config["module_path"])
+    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
+    args = tuple(list(config["args"]) + list(args))
+    kwargs = {**config["kwargs"], **kwargs}
+    return cls_or_func(*args, **kwargs)
+
+
+def call_by_yaml(path, *args, **kwargs) -> object:
+    config = load_yaml(path)
+    return call_by_dict(config, *args, **kwargs)
+
+
+def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
+    """Similar to `call_by_dict`, except that the args/kwargs may contain nested config dicts that themselves need to be called."""
+    if not has_key_words(config):
+        return config
+    module = get_module_by_module_path(config["module_path"])
+    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
+    args = tuple(list(config["args"]) + list(args))
+    kwargs = {**config["kwargs"], **kwargs}
+    # check whether there are nested special dicts
+    new_args = [nested_call_by_dict(x) for x in args]
+    new_kwargs = {}
+    for key, x in kwargs.items():
+        new_kwargs[key] = nested_call_by_dict(x)
+    return cls_or_func(*new_args, **new_kwargs)
+
+
+def nested_call_by_yaml(path, *args, **kwargs) -> object:
+    config = load_yaml(path)
+    return nested_call_by_dict(config, *args, **kwargs)
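For reference, a minimal, hypothetical use of the new call_by_dict helper (the target class and all values are illustrative, not part of this commit):

from xautodl.xmisc.module_utils import call_by_dict

config = {
    "class_or_func": "Linear",
    "module_path": "torch.nn",
    "args": [],
    "kwargs": {"in_features": 16, "out_features": 8},
}
layer = call_by_dict(config)  # same as torch.nn.Linear(in_features=16, out_features=8)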
xautodl/xmisc/yaml_utils.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+import os
+import yaml
+
+
+def load_yaml(path):
+    if not os.path.isfile(path):
+        raise ValueError("{:} is not a file.".format(path))
+    with open(path, "r") as stream:
+        data = yaml.safe_load(stream)
+    return data
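Putting the two new modules together, an end-to-end sketch of nested_call_by_yaml (the file name and all values are hypothetical):

import yaml
from xautodl.xmisc import nested_call_by_yaml

# Each entry of "args" is itself a call_by_dict-style config, so
# nested_call_by_dict instantiates them recursively before the outer call.
config = {
    "class_or_func": "Sequential",
    "module_path": "torch.nn",
    "args": [
        {"class_or_func": "Linear", "module_path": "torch.nn", "args": [16, 8], "kwargs": {}},
        {"class_or_func": "ReLU", "module_path": "torch.nn", "args": [], "kwargs": {}},
    ],
    "kwargs": {},
}
with open("net.yaml", "w") as stream:
    yaml.safe_dump(config, stream)

net = nested_call_by_yaml("net.yaml")  # -> nn.Sequential(nn.Linear(16, 8), nn.ReLU())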