update NAS-Bench
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import os, sys, time, random, argparse
|
||||
import random, argparse
|
||||
from .share_args import add_shared_args
|
||||
|
||||
def obtain_attention_args():
|
||||
|
@@ -1,7 +1,7 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################
|
||||
import os, sys, time, random, argparse
|
||||
import random, argparse
|
||||
from .share_args import add_shared_args
|
||||
|
||||
def obtain_basic_args():
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import os, sys, time, random, argparse
|
||||
import random, argparse
|
||||
from .share_args import add_shared_args
|
||||
|
||||
def obtain_cls_init_args():
|
||||
|
@@ -1,4 +1,4 @@
|
||||
import os, sys, time, random, argparse
|
||||
import random, argparse
|
||||
from .share_args import add_shared_args
|
||||
|
||||
def obtain_cls_kd_args():
|
||||
|
@@ -4,7 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import os, sys, json
|
||||
import os, json
|
||||
from os import path as osp
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
|
@@ -39,6 +39,13 @@ def get_cell_based_tiny_net(config):
|
||||
genotype = CellStructure.str2structure(config.arch_str)
|
||||
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
|
||||
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||
elif config.name == 'infer.shape.tiny':
|
||||
from .shape_infers import DynamicShapeTinyNet
|
||||
if isinstance(config.channels, str):
|
||||
channels = tuple([int(x) for x in config.channels.split(':')])
|
||||
else: channels = config.channels
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return DynamicShapeTinyNet(channels, genotype, config.num_classes)
|
||||
elif config.name == 'infer.nasnet-cifar':
|
||||
from .cell_infers import NASNetonCIFAR
|
||||
raise NotImplementedError
|
||||
|
@@ -1,7 +1,6 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
@@ -172,14 +172,19 @@ class FactorizedReduce(nn.Module):
|
||||
for i in range(2):
|
||||
self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 1:
|
||||
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
|
||||
else:
|
||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
if self.stride == 2:
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
else:
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
|
@@ -14,11 +14,11 @@ from .search_model_darts_nasnet import NASNetworkDARTS
|
||||
|
||||
|
||||
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||
'DARTS-V2': TinyNetworkDarts,
|
||||
'GDAS' : TinyNetworkGDAS,
|
||||
'SETN' : TinyNetworkSETN,
|
||||
'ENAS' : TinyNetworkENAS,
|
||||
'RANDOM' : TinyNetworkRANDOM}
|
||||
"DARTS-V2": TinyNetworkDarts,
|
||||
"GDAS": TinyNetworkGDAS,
|
||||
"SETN": TinyNetworkSETN,
|
||||
"ENAS": TinyNetworkENAS,
|
||||
"RANDOM": TinyNetworkRANDOM}
|
||||
|
||||
nasnet_super_nets = {'GDAS' : NASNetworkGDAS,
|
||||
'DARTS': NASNetworkDARTS}
|
||||
nasnet_super_nets = {"GDAS": NASNetworkGDAS,
|
||||
"DARTS": NASNetworkDARTS}
|
||||
|
@@ -1,5 +1,5 @@
|
||||
####################
|
||||
# DARTS, ICLR 2019 #
|
||||
# DARTS, ICLR 2019 #
|
||||
####################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -11,7 +11,8 @@ from .search_cells import NASNetSearchCell as SearchCell
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkDARTS(nn.Module):
|
||||
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
|
||||
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
super(NASNetworkDARTS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
@@ -6,14 +6,15 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkSETN(nn.Module):
|
||||
|
||||
def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
|
||||
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
super(NASNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
@@ -45,6 +46,16 @@ class NASNetworkSETN(nn.Module):
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self.mode = 'urs'
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
    """Switch the calculation mode and store the cell used by 'dynamic' mode.

    Args:
        mode: one of 'urs', 'joint', 'select' or 'dynamic'.
        dynamic_cell: kept as a deep copy when ``mode == 'dynamic'``;
            ignored (the stored cell is reset to None) for every other mode.

    Raises:
        AssertionError: if ``mode`` is not one of the four accepted names.
    """
    assert mode in ['urs', 'joint', 'select', 'dynamic']
    self.mode = mode
    # A private copy is stored so later changes to the caller's object do not leak in.
    self.dynamic_cell = deepcopy(dynamic_cell) if mode == 'dynamic' else None
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||
@@ -70,6 +81,24 @@ class NASNetworkSETN(nn.Module):
|
||||
def extra_repr(self):
    """Return a one-line summary built from the instance's configuration attributes.

    The fields ``_C``, ``_layerN``, ``_steps``, ``_multiplier`` and ``_Layer``
    are looked up in ``self.__dict__``.
    """
    template = '{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'
    class_name = self.__class__.__name__
    return template.format(name=class_name, **self.__dict__)
|
||||
|
||||
def dync_genotype(self, use_random=False):
    """Sample one architecture edge by edge and wrap it in a ``Structure``.

    For every node ``i`` in ``1 .. self.max_nodes - 1`` and every predecessor
    ``j < i`` one operation name is chosen for the edge ``'i<-j'``:

    * ``use_random=True``: uniformly from ``self.op_names``;
    * otherwise: by drawing one index from a multinomial over the softmax of
      the architecture parameters of that edge.

    Args:
        use_random: pick operations uniformly instead of from the learned
            distribution.

    Returns:
        ``Structure`` built from a list holding, per node, a tuple of
        ``(op_name, predecessor_index)`` pairs.

    NOTE(review): this method reads ``self.arch_parameters``, ``self.max_nodes``,
    ``self.op_names`` and ``self.edge2index``. The part of ``__init__`` visible
    next to it only creates ``arch_normal_parameters`` and
    ``arch_reduce_parameters`` -- confirm these attributes exist on this class.
    """
    # `random` is not among the module-level imports visible for this file, so it
    # is imported here; this is harmless if the module already imports it.
    import random

    # The distribution is only needed when sampling from the learned parameters.
    alphas = None
    if not use_random:
        with torch.no_grad():
            alphas = nn.functional.softmax(self.arch_parameters, dim=-1)

    genotypes = []
    for i in range(1, self.max_nodes):
        xlist = []
        for j in range(i):
            if use_random:
                op_name = random.choice(self.op_names)
            else:
                node_str = '{:}<-{:}'.format(i, j)
                weights = alphas[self.edge2index[node_str]]
                op_index = torch.multinomial(weights, 1).item()
                op_name = self.op_names[op_index]
            xlist.append((op_name, j))
        genotypes.append(tuple(xlist))
    return Structure(genotypes)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
@@ -94,9 +123,6 @@ class NASNetworkSETN(nn.Module):
|
||||
def forward(self, inputs):
|
||||
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
|
||||
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
|
||||
with torch.no_grad():
|
||||
normal_hardwts_cpu = normal_hardwts.detach().cpu()
|
||||
reduce_hardwts_cpu = reduce_hardwts.detach().cpu()
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import math, torch
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import math, torch
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import math
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import math, torch
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
@@ -1,7 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func, parse_channel_info
|
||||
from ..SharedUtils import parse_channel_info
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
58
lib/models/shape_infers/InferTinyCellNet.py
Normal file
58
lib/models/shape_infers/InferTinyCellNet.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from typing import List, Text, Any
|
||||
import torch.nn as nn
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
|
||||
|
||||
class DynamicShapeTinyNet(nn.Module):
    """Cell-based network whose width is specified layer by layer.

    ``channels`` must contain ``3 * N + 2`` entries: three groups of ``N``
    layers built from ``InferCell`` with two ``ResNetBasicblock`` layers
    between them. ``channels[0]`` also sets the width of the stem.
    """

    def __init__(self, channels: List[int], genotype: Any, num_classes: int):
        """Build the stem, the stack of cells and the classifier.

        Args:
            channels: output width of every layer, ``3 * N + 2`` values.
            genotype: cell description forwarded to every ``InferCell``.
            num_classes: size of the classifier output.

        Raises:
            ValueError: if ``len(channels) % 3 != 2``.
        """
        super(DynamicShapeTinyNet, self).__init__()
        self._channels = channels
        total_layers = len(channels)
        if total_layers % 3 != 2:
            raise ValueError('invalid number of layers : {:}'.format(total_layers))
        self._num_stage = stage_depth = total_layers // 3

        stem_width = channels[0]
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(stem_width))

        # Layout: N plain layers, one reduction, N plain, one reduction, N plain.
        plain_group = [False] * stage_depth
        layer_reductions = plain_group + [True] + plain_group + [True] + plain_group

        width_in = stem_width
        self.cells = nn.ModuleList()
        for width_out, is_reduction in zip(channels, layer_reductions):
            if is_reduction:
                layer = ResNetBasicblock(width_in, width_out, 2, True)
            else:
                layer = InferCell(genotype, width_in, width_out, 1)
            self.cells.append(layer)
            # The next layer consumes whatever this one reports as its output width.
            width_in = layer.out_dim
        self._num_layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(width_in), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(width_in, num_classes)

    def get_message(self) -> Text:
        """Return ``extra_repr()`` followed by one indexed line per cell."""
        total = len(self.cells)
        lines = [self.extra_repr()]
        for position, cell in enumerate(self.cells):
            lines.append(' {:02d}/{:02d} :: {:}'.format(position, total, cell.extra_repr()))
        return '\n'.join(lines)

    def extra_repr(self):
        """Return a one-line summary of the channels, stage depth and layer count."""
        template = '{name}(C={_channels}, N={_num_stage}, L={_num_layer})'
        return template.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs):
        """Run the network.

        Returns:
            A pair ``(out, logits)``: the pooled features flattened to
            ``(batch, -1)`` and the classifier output computed from them.
        """
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        pooled = self.global_pooling(self.lastact(feature))
        out = pooled.view(pooled.size(0), -1)
        logits = self.classifier(out)

        return out, logits
|
@@ -1,5 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from .InferCifarResNet_width import InferWidthCifarResNet
|
||||
from .InferImagenetResNet import InferImagenetResNet
|
||||
from .InferImagenetResNet import InferImagenetResNet
|
||||
from .InferCifarResNet_depth import InferDepthCifarResNet
|
||||
from .InferCifarResNet import InferCifarResNet
|
||||
from .InferMobileNetV2 import InferMobileNetV2
|
||||
from .InferCifarResNet import InferCifarResNet
|
||||
from .InferMobileNetV2 import InferMobileNetV2
|
||||
from .InferTinyCellNet import DynamicShapeTinyNet
|
@@ -7,7 +7,8 @@
|
||||
# [2020.03.08] Next version (coming soon)
|
||||
#
|
||||
#
|
||||
import os, sys, copy, random, torch, numpy as np
|
||||
import os, copy, random, torch, numpy as np
|
||||
from typing import List, Text, Union, Dict, Any
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201.
|
||||
class NASBench201API(object):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict, verbose=True):
|
||||
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
@@ -69,7 +70,7 @@ class NASBench201API(object):
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[ arch ] = idx
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int):
|
||||
return copy.deepcopy( self.meta_archs[index] )
|
||||
|
||||
def __len__(self):
|
||||
@@ -99,7 +100,7 @@ class NASBench201API(object):
|
||||
|
||||
# Overwrite all information of the 'index'-th architecture in the search space.
|
||||
# It will load its data from 'archive_root'.
|
||||
def reload(self, archive_root, index):
|
||||
def reload(self, archive_root: Text, index: int):
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||
@@ -141,7 +142,8 @@ class NASBench201API(object):
|
||||
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
# -- cifar100 : training the model on the CIFAR-100 training set.
|
||||
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
||||
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
|
||||
use_12epochs_result: bool = False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||
@@ -177,7 +179,7 @@ class NASBench201API(object):
|
||||
return best_index, highest_accuracy
|
||||
|
||||
# return the topology structure of the `index`-th architecture
|
||||
def arch(self, index):
|
||||
def arch(self, index: int):
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
@@ -238,7 +240,7 @@ class NASBench201API(object):
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of a random architecture will be returned
|
||||
# When is_random=False, the performance of all trials will be averaged.
|
||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
@@ -301,7 +303,7 @@ class NASBench201API(object):
|
||||
If the index < 0: it will loop for all architectures and print their information one by one.
|
||||
else: it will print the information of the 'index'-th architecture.
|
||||
"""
|
||||
def show(self, index=-1):
|
||||
def show(self, index: int = -1) -> None:
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
@@ -336,8 +338,8 @@ class NASBench201API(object):
|
||||
# for i, node in enumerate(arch):
|
||||
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
|
||||
@staticmethod
|
||||
def str2lists(xstr):
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
def str2lists(xstr: Text) -> List[Any]:
|
||||
# assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
|
@@ -3,6 +3,8 @@
|
||||
##################################################
|
||||
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
|
129
lib/procedures/funcs_nasbench.py
Normal file
129
lib/procedures/funcs_nasbench.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import time, torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    """Evaluate ``network`` on ``xloader`` without gradients and time each batch.

    The network is switched to eval mode; inputs and targets are moved to CUDA.
    ``network(inputs)`` must return a pair whose second item is the logits.

    A latency (``batch_time.val - data_time.val``) is recorded for the first
    batch and for every later batch with the same size as the first; when more
    than two latencies were recorded, the first one is dropped.
    NOTE(review): this presumably relies on ``AverageMeter.val`` holding the
    most recently recorded value -- confirm against ``log_utils``.

    Returns:
        ``(losses.avg, top1.avg, top5.avg, latencies)``.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    reference_batch = None
    latencies = []
    network.eval()
    with torch.no_grad():
        tick = time.time()
        for inputs, targets in xloader:
            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            data_time.update(time.time() - tick)
            # forward
            _, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - tick)
            num_samples = inputs.size(0)
            # Only batches matching the first batch size contribute a latency.
            if reference_batch is None or reference_batch == num_samples:
                reference_batch = num_samples
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            loss_meter.update(loss.item(), num_samples)
            top1_meter.update(prec1.item(), num_samples)
            top5_meter.update(prec5.item(), num_samples)
            tick = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return loss_meter.avg, top1_meter.avg, top5_meter.avg, latencies
|
||||
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    """Run one pass over ``xloader`` in either training or validation mode.

    In 'train' mode the scheduler is updated with the fractional position
    inside the epoch before each batch, and every batch performs
    zero_grad / backward / step. In 'valid' mode only the forward pass and the
    bookkeeping run, and ``scheduler`` / ``optimizer`` are never touched.
    Only ``targets`` are moved to CUDA here; ``inputs`` are handed to
    ``network`` as they come from the loader.

    Args:
        mode: 'train' or 'valid'.

    Returns:
        ``(losses.avg, top1.avg, top5.avg, batch_time.sum)``.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'valid'.
    """
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    if mode not in ('train', 'valid'):
        raise ValueError("The mode is not right : {:}".format(mode))
    is_training = mode == 'train'
    if is_training:
        network.train()
    else:
        network.eval()

    # data_time is created alongside batch_time but is never updated below.
    data_time, batch_time = AverageMeter(), AverageMeter()
    tick = time.time()
    for step, (inputs, targets) in enumerate(xloader):
        if is_training:
            scheduler.update(None, 1.0 * step / len(xloader))

        targets = targets.cuda(non_blocking=True)
        if is_training:
            optimizer.zero_grad()
        # forward
        _, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if is_training:
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        num_samples = inputs.size(0)
        loss_meter.update(loss.item(), num_samples)
        top1_meter.update(prec1.item(), num_samples)
        top5_meter.update(prec5.item(), num_samples)
        # count time
        batch_time.update(time.time() - tick)
        tick = time.time()
    return loss_meter.avg, top1_meter.avg, top5_meter.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
    """Train one architecture under a fixed seed and collect per-epoch statistics.

    Args:
        arch_config: network description passed to ``get_cell_based_tiny_net``;
            must provide ``_asdict()``.
        opt_config: optimisation settings; ``xshape``, ``epochs``, ``warmup`` and
            ``_asdict()`` are read here, the rest by ``get_optim_scheduler``.
        train_loader: loader used for the training pass of every epoch.
        valid_loaders: mapping from a name to a loader; every loader is
            evaluated after each training epoch. May be empty.
        seed: value handed to ``prepare_seed``.
        logger: object with a ``log`` method.

    Returns:
        A dict with the FLOP / parameter counts, both configs as dicts, the
        per-epoch training statistics (keyed by epoch), the validation
        statistics (keyed by ``'<name>@<epoch>'``), the learning rates, the
        final ``state_dict`` of the network, its string form and a
        ``'finish-train'`` flag.

    The model and the criterion are moved to CUDA, so a GPU is required.
    """
    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(arch_config)
    flop, param = get_model_infos(net, opt_config.xshape)
    logger.log('Network : {:}'.format(net.get_message()), False)
    logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
    logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time = time.time(), AverageMeter()
    total_epoch = opt_config.epochs + opt_config.warmup
    train_losses, train_acc1es, train_acc5es = {}, {}, {}
    valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}
    train_times, valid_times, lrs = {}, {}, {}
    epoch_template = ('{:} {:} epoch={:03d}/{:03d} :: '
                      'Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] '
                      'Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}')
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        lr = min(scheduler.get_lr())
        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, 'train')
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        lrs[epoch] = lr

        # Defaults keep the log line below well-defined when `valid_loaders` is
        # empty; otherwise they are overwritten and the log shows the values of
        # the last loader evaluated in this epoch.
        valid_loss = valid_acc1 = valid_acc5 = float('nan')
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, 'valid')
                tag = '{:}@{:}'.format(key, epoch)
                valid_losses[tag] = valid_loss
                valid_acc1es[tag] = valid_acc1
                valid_acc5es[tag] = valid_acc5
                valid_times[tag] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        remaining = convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        need_time = 'Time Left: {:}'.format(remaining)
        logger.log(epoch_template.format(
            time_string(), need_time, epoch, total_epoch,
            train_loss, train_acc1, train_acc5,
            valid_loss, valid_acc1, valid_acc5, lr))
    info_seed = {'flop': flop,
                 'param': param,
                 'arch_config': arch_config._asdict(),
                 'opt_config': opt_config._asdict(),
                 'total_epoch': total_epoch,
                 'train_losses': train_losses,
                 'train_acc1es': train_acc1es,
                 'train_acc5es': train_acc5es,
                 'train_times': train_times,
                 'valid_losses': valid_losses,
                 'valid_acc1es': valid_acc1es,
                 'valid_acc5es': valid_acc5es,
                 'valid_times': valid_times,
                 'learning_rates': lrs,
                 'net_state_dict': net.state_dict(),
                 'net_string': '{:}'.format(net),
                 'finish-train': True,
                 }
    return info_seed
|
Reference in New Issue
Block a user