update code styles

D-X-Y
2020-01-09 22:26:23 +11:00
parent 5ac5060a33
commit ad34af9913
26 changed files with 192 additions and 81 deletions

View File

@@ -9,16 +9,25 @@ class SearchDataset(data.Dataset):
 
   def __init__(self, name, data, train_split, valid_split, check=True):
     self.datasetname = name
-    self.data = data
-    self.train_split = train_split.copy()
-    self.valid_split = valid_split.copy()
-    if check:
-      intersection = set(train_split).intersection(set(valid_split))
-      assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
+    if isinstance(data, (list, tuple)): # new type of SearchDataset
+      assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
+      self.train_data = data[0]
+      self.valid_data = data[1]
+      self.train_split = train_split.copy()
+      self.valid_split = valid_split.copy()
+      self.mode_str = 'V2' # new mode
+    else:
+      self.mode_str = 'V1' # old mode
+      self.data = data
+      self.train_split = train_split.copy()
+      self.valid_split = valid_split.copy()
+      if check:
+        intersection = set(train_split).intersection(set(valid_split))
+        assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
     self.length = len(self.train_split)
 
   def __repr__(self):
-    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split)))
+    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
 
   def __len__(self):
     return self.length
@@ -27,6 +36,11 @@ class SearchDataset(data.Dataset):
     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
     train_index = self.train_split[index]
     valid_index = random.choice( self.valid_split )
-    train_image, train_label = self.data[train_index]
-    valid_image, valid_label = self.data[valid_index]
+    if self.mode_str == 'V1':
+      train_image, train_label = self.data[train_index]
+      valid_image, valid_label = self.data[valid_index]
+    elif self.mode_str == 'V2':
+      train_image, train_label = self.train_data[train_index]
+      valid_image, valid_label = self.valid_data[valid_index]
+    else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
     return train_image, train_label, valid_image, valid_label
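
A minimal sketch (not part of the commit) of the two construction modes this change introduces; the CIFAR-10 dataset, transforms and split indices below are made up for illustration, and note that the intersection check now only runs in V1 mode:

    import torchvision.datasets as dset
    import torchvision.transforms as transforms
    from datasets import SearchDataset   # import path assumed

    cifar = dset.CIFAR10('./data', train=True, transform=transforms.ToTensor(), download=True)
    train_split, valid_split = list(range(0, 25000)), list(range(25000, 50000))

    # V1: a single dataset shared by both splits
    search_v1 = SearchDataset('cifar10', cifar, train_split, valid_split)

    # V2: a (train_data, valid_data) pair, e.g. the same images with different transforms
    search_v2 = SearchDataset('cifar10', (cifar, cifar), train_split, valid_split)

    train_img, train_lab, valid_img, valid_lab = search_v2[0]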

View File

@@ -34,7 +34,7 @@ class PointMeta():
 
   def get_box(self, return_diagonal=False):
     if self.box is None: return None
-    if return_diagonal == False:
+    if not return_diagonal:
       return self.box.clone()
     else:
       W = (self.box[2]-self.box[0]).item()

View File

@@ -1,4 +1,3 @@
-import torch
 import torch.nn as nn
 from copy import deepcopy
 from ..cell_operations import OPS

View File

@@ -68,7 +68,7 @@ class Structure:
     for i, node_info in enumerate(self.nodes):
       sums = []
       for op, xin in node_info:
-        if op == 'none' or nodes[xin] == False: x = False
+        if op == 'none' or nodes[xin] is False: x = False
         else: x = True
         sums.append( x )
       nodes[i+1] = sum(sums) > 0
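
Several hunks in this commit replace '== False' with 'is False' or 'not ...'. The entries stored in 'nodes' here are plain bools, so the two forms agree and the change is purely stylistic (PEP 8 prefers 'is False' / 'not x'), but they are not interchangeable for other value types. A standalone illustration, not part of the commit:

    import numpy as np

    x = 0
    print(x == False)   # True  -- 0 compares equal to False
    print(x is False)   # False -- 0 is not the singleton False object

    flags = np.array([True, False])
    print(flags == False)   # element-wise: [False  True], not a single bool
    print(flags is False)   # always False for an array object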

View File

@@ -85,7 +85,7 @@ class SearchCell(nn.Module):
           candidates = self.edges[node_str]
           select_op = random.choice(candidates)
           sops.append( select_op )
-          if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
+          if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True
         if has_non_zero: break
       inter_nodes = []
       for j, select_op in enumerate(sops):

View File

@@ -1,4 +1,4 @@
-import math, torch
+import math
 import torch.nn as nn
 import torch.nn.functional as F
 from ..initialization import initialize_resnet

View File

@@ -70,6 +70,9 @@ class NASBench102API(object):
   def __repr__(self):
     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs)))
 
+  def random(self):
+    return random.randint(0, len(self.meta_archs)-1)
+
   def query_index_by_arch(self, arch):
     if isinstance(arch, str):
       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
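
A hedged usage sketch of the new random() helper; the import line and checkpoint path are assumptions for illustration, while the get_more_info() call and its return keys are copied from the nas_utils.py file added later in this commit:

    from nas_102_api import NASBench102API   # import path assumed

    api = NASBench102API('NAS-Bench-102-v1_0.pth')   # hypothetical benchmark file
    arch_index = api.random()                        # uniformly sample an architecture index
    info = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
    print(arch_index, info['valid-accuracy'])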

View File

@@ -1,7 +1,7 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-import os, sys, time, torch, random, PIL, copy, numpy as np
+import os, sys, torch, random, PIL, copy, numpy as np
 from os import path as osp
 from shutil import copyfile

View File

@@ -1,3 +1,5 @@
 from .evaluation_utils import obtain_accuracy
 from .gpu_manager import GPUManager
 from .flop_benchmark import get_model_infos
+from .affine_utils import normalize_points, denormalize_points
+from .affine_utils import identity2affine, solve2theta, affine2image
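
With the affine helpers re-exported from this package's __init__, callers can import them from utils directly; a small sketch, assuming lib/ is on sys.path as the new nas_utils.py below already does:

    from utils import normalize_points, denormalize_points
    from utils import identity2affine, solve2theta, affine2image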

View File

@@ -1,10 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-#
-#
 # functions for affine transformation
 import math, torch
 import numpy as np

View File

@@ -1,4 +1,4 @@
-import copy, torch
+import torch
 import torch.nn as nn
 import numpy as np

View File

@@ -27,7 +27,7 @@ class GPUManager():
       find = False
       for gpu in all_gpus:
         if gpu['index'] == CUDA_VISIBLE_DEVICE:
-          assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
+          assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE)
           find = True
           selected_gpus.append( gpu.copy() )
           selected_gpus[-1]['index'] = '{}'.format(idx)

lib/utils/nas_utils.py (new file, 52 lines added)
View File

@@ -0,0 +1,52 @@
+# This file is for experimental usage
+import os, sys, torch, random
+import numpy as np
+from copy import deepcopy
+from tqdm import tqdm
+import torch.nn as nn
+
+from utils import obtain_accuracy
+from models import CellStructure
+from log_utils import time_string
+
+
+def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
+  weights = deepcopy(model.state_dict())
+  model.train(cal_mode)
+  with torch.no_grad():
+    logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
+    archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
+    probs, accuracies, gt_accs = [], [], []
+    loader_iter = iter(xloader)
+    random.seed(seed)
+    random.shuffle(archs)
+    for idx, arch in enumerate(archs):
+      arch_index = api.query_index_by_arch( arch )
+      metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False)
+      gt_accs.append( metrics['valid-accuracy'] )
+      select_logits = []
+      for i, node_info in enumerate(arch.nodes):
+        for op, xin in node_info:
+          node_str = '{:}<-{:}'.format(i+1, xin)
+          op_index = model.op_names.index(op)
+          select_logits.append( logits[model.edge2index[node_str], op_index] )
+      cur_prob = sum(select_logits).item()
+      probs.append( cur_prob )
+    cor_prob = np.corrcoef(probs, gt_accs)[0,1]
+    print ('correlation for probabilities : {:}'.format(cor_prob))
+    for idx, arch in enumerate(archs):
+      model.set_cal_mode('dynamic', arch)
+      try:
+        inputs, targets = next(loader_iter)
+      except:
+        loader_iter = iter(xloader)
+        inputs, targets = next(loader_iter)
+      _, logits = model(inputs.cuda())
+      _, preds = torch.max(logits, dim=-1)
+      correct = (preds == targets.cuda() ).float()
+      accuracies.append( correct.mean().item() )
+      if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10):
+        cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1]
+        print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch))
+  model.load_state_dict(weights)
+  return archs, probs, accuracies
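
A hedged sketch of how evaluate_one_shot might be driven; the search model, loader and API objects below are assumptions, only the function signature and the attributes it reads (arch_parameters, op_names, max_nodes, edge2index, set_cal_mode) come from the file above:

    from utils.nas_utils import evaluate_one_shot   # module path assumed from lib/utils/nas_utils.py

    # search_model : a one-shot NAS supernet exposing the attributes listed above
    # valid_loader : a torch DataLoader over the validation split
    # api          : a NASBench102API instance
    archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, cal_mode=False)
    # probs      : summed log-softmax scores of each sampled architecture
    # accuracies : one-shot accuracies, to be correlated with the benchmark's ground truth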

View File

@@ -1 +0,0 @@
-from .affine_utils import normalize_points, denormalize_points