change batch size in DARTS-NASNet to 64; add some type checking

commit 1efe3cbccf
parent 923b0fe9cf
Author: D-X-Y
Date:   2020-02-07 10:15:58 +11:00

4 changed files with 16 additions and 10 deletions

@@ -2,6 +2,7 @@
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
 from os import path as osp
+from typing import List, Text
 __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
@@ -42,7 +43,7 @@ def get_cell_based_tiny_net(config):
 # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
-def get_search_spaces(xtype, name):
+def get_search_spaces(xtype, name) -> List[Text]:
   if xtype == 'cell':
     from .cell_operations import SearchSpaceNames
     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())

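For context: with the new -> List[Text] annotation, a static checker such as mypy can verify call sites of get_search_spaces. A minimal usage sketch, assuming the repository's lib directory is on sys.path and that 'darts' is a registered space name (both are assumptions, not shown in this diff):

import sys
sys.path.insert(0, 'lib')              # assumed repo layout
from models import get_search_spaces   # assumed import path

# The annotation tells the checker that `space` is List[Text],
# i.e. a list of operation names for the chosen cell search space.
space = get_search_spaces('cell', 'darts')   # 'darts' is an assumed space name
for op_name in space:
  print(op_name)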

@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 from copy import deepcopy
+from typing import List, Text, Dict
 from .search_cells import NASNetSearchCell as SearchCell
 from .genotypes import Structure
@@ -11,7 +12,7 @@ from .genotypes import Structure
 # The macro structure is based on NASNet
 class NASNetworkDARTS(nn.Module):

-  def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
+  def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int, num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
     super(NASNetworkDARTS, self).__init__()
     self._C = C
     self._layerN = N
@@ -44,31 +45,31 @@ class NASNetworkDARTS(nn.Module):
     self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
     self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )

-  def get_weights(self):
+  def get_weights(self) -> List[torch.nn.Parameter]:
     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
     xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
     xlist+= list( self.classifier.parameters() )
     return xlist

-  def get_alphas(self):
+  def get_alphas(self) -> List[torch.nn.Parameter]:
     return [self.arch_normal_parameters, self.arch_reduce_parameters]

-  def show_alphas(self):
+  def show_alphas(self) -> Text:
     with torch.no_grad():
       A = 'arch-normal-parameters :\n{:}'.format( nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu() )
       B = 'arch-reduce-parameters :\n{:}'.format( nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu() )
     return '{:}\n{:}'.format(A, B)

-  def get_message(self):
+  def get_message(self) -> Text:
     string = self.extra_repr()
     for i, cell in enumerate(self.cells):
       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
     return string

-  def extra_repr(self):
+  def extra_repr(self) -> Text:
     return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

-  def genotype(self):
+  def genotype(self) -> Dict[Text, List]:
     def _parse(weights):
       gene = []
       for i in range(self._steps):
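The excerpt cuts off inside genotype(), which softmaxes the architecture parameters returned by get_alphas and keeps the strongest operations per edge. A standalone sketch of that DARTS-style selection; the operation list and edge count below are illustrative placeholders, and the parsing is simplified relative to the repo's actual _parse:

import torch
import torch.nn as nn

ops = ['none', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5']   # placeholder space
num_edge = 14                                                    # placeholder edge count
alphas = nn.Parameter(1e-3 * torch.randn(num_edge, len(ops)))    # as in __init__ above

with torch.no_grad():
  probs = nn.functional.softmax(alphas, dim=-1)   # per-edge distribution over ops
  probs[:, ops.index('none')] = -1.0              # never pick 'none', as in DARTS parsing
  selected = [ops[int(i)] for i in probs.argmax(dim=-1)]
print(selected)  # one operation name per edge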