v2
This commit is contained in:
89
pycls/models/regnet.py
Normal file
89
pycls/models/regnet.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""RegNet models."""
|
||||
|
||||
import numpy as np
|
||||
from pycls.core.config import cfg
|
||||
from pycls.models.anynet import AnyNet
|
||||
|
||||
|
||||
def quantize_float(f, q):
|
||||
"""Converts a float to closest non-zero int divisible by q."""
|
||||
return int(round(f / q) * q)
|
||||
|
||||
|
||||
def adjust_ws_gs_comp(ws, bms, gs):
|
||||
"""Adjusts the compatibility of widths and groups."""
|
||||
ws_bot = [int(w * b) for w, b in zip(ws, bms)]
|
||||
gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
|
||||
ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
|
||||
ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
|
||||
return ws, gs
|
||||
|
||||
|
||||
def get_stages_from_blocks(ws, rs):
|
||||
"""Gets ws/ds of network at each stage from per block values."""
|
||||
ts_temp = zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
|
||||
ts = [w != wp or r != rp for w, wp, r, rp in ts_temp]
|
||||
s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
|
||||
s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
|
||||
return s_ws, s_ds
|
||||
|
||||
|
||||
def generate_regnet(w_a, w_0, w_m, d, q=8):
|
||||
"""Generates per block ws from RegNet parameters."""
|
||||
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
|
||||
ws_cont = np.arange(d) * w_a + w_0
|
||||
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
|
||||
ws = w_0 * np.power(w_m, ks)
|
||||
ws = np.round(np.divide(ws, q)) * q
|
||||
num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
|
||||
ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
|
||||
return ws, num_stages, max_stage, ws_cont
|
||||
|
||||
|
||||
class RegNet(AnyNet):
|
||||
"""RegNet model."""
|
||||
|
||||
@staticmethod
|
||||
def get_args():
|
||||
"""Convert RegNet to AnyNet parameter format."""
|
||||
# Generate RegNet ws per block
|
||||
w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
|
||||
ws, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
|
||||
# Convert to per stage format
|
||||
s_ws, s_ds = get_stages_from_blocks(ws, ws)
|
||||
# Use the same gw, bm and ss for each stage
|
||||
s_gs = [cfg.REGNET.GROUP_W for _ in range(num_stages)]
|
||||
s_bs = [cfg.REGNET.BOT_MUL for _ in range(num_stages)]
|
||||
s_ss = [cfg.REGNET.STRIDE for _ in range(num_stages)]
|
||||
# Adjust the compatibility of ws and gws
|
||||
s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
|
||||
# Get AnyNet arguments defining the RegNet
|
||||
return {
|
||||
"stem_type": cfg.REGNET.STEM_TYPE,
|
||||
"stem_w": cfg.REGNET.STEM_W,
|
||||
"block_type": cfg.REGNET.BLOCK_TYPE,
|
||||
"ds": s_ds,
|
||||
"ws": s_ws,
|
||||
"ss": s_ss,
|
||||
"bms": s_bs,
|
||||
"gws": s_gs,
|
||||
"se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
|
||||
"nc": cfg.MODEL.NUM_CLASSES,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
kwargs = RegNet.get_args()
|
||||
super(RegNet, self).__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def complexity(cx, **kwargs):
|
||||
"""Computes model complexity. If you alter the model, make sure to update."""
|
||||
kwargs = RegNet.get_args() if not kwargs else kwargs
|
||||
return AnyNet.complexity(cx, **kwargs)
|
Reference in New Issue
Block a user