upload
This commit is contained in:
88
pycls/core/builders.py
Normal file
88
pycls/core/builders.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Model and loss construction functions."""
|
||||
|
||||
import torch
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.anynet import AnyNet
|
||||
from pycls.models.effnet import EffNet
|
||||
from pycls.models.regnet import RegNet
|
||||
from pycls.models.resnet import ResNet
|
||||
from pycls.models.nas.nas import NAS
|
||||
from pycls.models.nas.nas_search import NAS_Search
|
||||
from pycls.models.nas_bench.model_builder import NAS_Bench
|
||||
|
||||
|
||||
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
|
||||
"""CrossEntropyLoss with label smoothing."""
|
||||
def __init__(self):
|
||||
super(LabelSmoothedCrossEntropyLoss, self).__init__()
|
||||
self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS
|
||||
self.num_classes = cfg.MODEL.NUM_CLASSES
|
||||
|
||||
def forward(self, logits, target):
|
||||
pred = logits.log_softmax(dim=-1)
|
||||
with torch.no_grad():
|
||||
target_dist = torch.ones_like(pred) * self.eps / (self.num_classes - 1)
|
||||
target_dist.scatter_(-1, target.unsqueeze(-1), 1 - self.eps)
|
||||
return (-target_dist * pred).sum(dim=-1).mean()
|
||||
|
||||
|
||||
# Supported models
|
||||
_models = {
|
||||
"anynet": AnyNet,
|
||||
"effnet": EffNet,
|
||||
"resnet": ResNet,
|
||||
"regnet": RegNet,
|
||||
"nas": NAS,
|
||||
"nas_search": NAS_Search,
|
||||
"nas_bench": NAS_Bench,
|
||||
}
|
||||
|
||||
# Supported loss functions
|
||||
_loss_funs = {
|
||||
"cross_entropy": torch.nn.CrossEntropyLoss,
|
||||
"label_smoothed_cross_entropy": LabelSmoothedCrossEntropyLoss,
|
||||
}
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Gets the model class specified in the config."""
|
||||
err_str = "Model type '{}' not supported"
|
||||
assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE)
|
||||
return _models[cfg.MODEL.TYPE]
|
||||
|
||||
|
||||
def get_loss_fun():
|
||||
"""Gets the loss function class specified in the config."""
|
||||
err_str = "Loss function type '{}' not supported"
|
||||
assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.TRAIN.LOSS)
|
||||
return _loss_funs[cfg.MODEL.LOSS_FUN]
|
||||
|
||||
|
||||
def build_model():
|
||||
"""Builds the model."""
|
||||
return get_model()()
|
||||
|
||||
|
||||
def build_loss_fun():
|
||||
"""Build the loss function."""
|
||||
if cfg.TASK == "seg":
|
||||
return get_loss_fun()(ignore_index=255)
|
||||
else:
|
||||
return get_loss_fun()()
|
||||
|
||||
|
||||
def register_model(name, ctor):
|
||||
"""Registers a model dynamically."""
|
||||
_models[name] = ctor
|
||||
|
||||
|
||||
def register_loss_fun(name, ctor):
|
||||
"""Registers a loss function dynamically."""
|
||||
_loss_funs[name] = ctor
|
Reference in New Issue
Block a user