update scripts

This commit is contained in:
D-X-Y
2019-02-01 03:23:55 +11:00
parent 4eb1a5ccf9
commit 3f9b54d99e
29 changed files with 115 additions and 137 deletions

View File

@@ -1,3 +1,4 @@
# DARTS First Order, Refer to https://github.com/quark0/darts
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy

View File

@@ -13,25 +13,11 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from nas import DARTS_V1, DARTS_V2, NASNet, PNASNet, AmoebaNet, ENASNet
from nas import DMS_V1, DMS_F1, GDAS_CC
from meta_nas import META_V1, META_V2
from nas import model_types as models
from train_utils import main_procedure
from train_utils_imagenet import main_procedure_imagenet
from scheduler import load_config
models = {'DARTS_V1': DARTS_V1,
'DARTS_V2': DARTS_V2,
'NASNet' : NASNet,
'PNASNet' : PNASNet,
'ENASNet' : ENASNet,
'DMS_V1' : DMS_V1,
'DMS_F1' : DMS_F1,
'GDAS_CC' : GDAS_CC,
'META_V1' : META_V1,
'META_V2' : META_V2,
'AmoebaNet' : AmoebaNet}
parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data_path', type=str, help='Path to dataset')

View File

@@ -10,6 +10,7 @@ from utils import time_string, convert_secs2time
from utils import count_parameters_in_MB
from utils import Cutout
from nas import NetworkCIFAR as Network
from datasets import get_datasets
def obtain_best(accuracies):
if len(accuracies) == 0: return (0, 0)
@@ -17,38 +18,10 @@ def obtain_best(accuracies):
s2b = sorted( tops )
return s2b[-1]
def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log):
# Mean + Std
if dataset == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif dataset == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
else:
raise TypeError("Unknow dataset : {:}".format(dataset))
# Dataset Transformation
if dataset == 'cifar10' or dataset == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if config.cutout > 0 : lists += [Cutout(config.cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
else:
raise TypeError("Unknow dataset : {:}".format(dataset))
# Dataset Defination
if dataset == 'cifar10':
train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True)
test_data = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True)
class_num = 10
elif dataset == 'cifar100':
train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True)
test_data = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True)
class_num = 100
else:
raise TypeError("Unknow dataset : {:}".format(dataset))
train_data, test_data, class_num = get_datasets(dataset, data_path, args.cutout)
print_log('-------------------------------------- main-procedure', log)
print_log('config : {:}'.format(config), log)

View File

@@ -12,6 +12,7 @@ from utils import count_parameters_in_MB
from utils import print_FLOPs
from utils import Cutout
from nas import NetworkImageNet as Network
from datasets import get_datasets
def obtain_best(accuracies):
@@ -40,30 +41,7 @@ class CrossEntropyLabelSmooth(nn.Module):
def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log):
# training data and testing data
traindir = os.path.join(data_path, 'train')
validdir = os.path.join(data_path, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_data = dset.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
]))
valid_data = dset.ImageFolder(
validdir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1)
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers)
@@ -73,7 +51,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
class_num = 1000
print_log('-------------------------------------- main-procedure', log)
print_log('config : {:}'.format(config), log)
print_log('genotype : {:}'.format(genotype), log)
@@ -98,8 +75,7 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda()
optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay)
#optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True)
optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True)
if config.type == 'cosine':
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs))
elif config.type == 'steplr':