update scripts

This commit is contained in:
D-X-Y
2019-02-01 03:23:55 +11:00
parent 4eb1a5ccf9
commit 3f9b54d99e
29 changed files with 115 additions and 137 deletions

View File

@@ -1,3 +1,4 @@
from .MetaBatchSampler import MetaBatchSampler
from .TieredImageNet import TieredImageNet
from .LanguageDataset import Corpus
from .get_dataset_with_transform import get_datasets

View File

@@ -0,0 +1,74 @@
import os, sys, torch
import os.path as osp
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from utils import Cutout
from .TieredImageNet import TieredImageNet
Dataset2Class = {'cifar10' : 10,
'cifar100': 100,
'tiered' : -1,
'imagnet-1k' : 1000,
'imagenet-100': 100}
def get_datasets(name, root, cutout):
# Mean + Std
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name == 'tiered':
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name == 'imagnet-1k' or name == 'imagenet-100':
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
else: raise TypeError("Unknow dataset : {:}".format(name))
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if cutout > 0 : lists += [Cutout(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name == 'tiered':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
if cutout > 0 : lists += [Cutout(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
elif name == 'imagnet-1k' or name == 'imagenet-100':
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
])
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
else: raise TypeError("Unknow dataset : {:}".format(name))
train_data = TieredImageNet(root, 'train-val', train_transform)
test_data = None
if name == 'cifar10':
train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True)
elif name == 'cifar100':
train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True)
elif name == 'imagnet-1k' or name == 'imagenet-100':
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
else: raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name]
return train_data, test_data, class_num

View File

@@ -1,4 +0,0 @@
rm -rf pytorch
git clone https://github.com/pytorch/pytorch.git
cp -r ./pytorch/torch/nn xnn
rm -rf pytorch

View File

@@ -11,8 +11,6 @@ from .CifarNet import NetworkCIFAR
from .ImageNet import NetworkImageNet
# genotypes
from .genotypes import DARTS_V1, DARTS_V2
from .genotypes import NASNet, PNASNet, AmoebaNet, ENASNet
from .genotypes import DMS_V1, DMS_F1, GDAS_CC
from .genotypes import model_types
from .construct_utils import return_alphas_str

View File

@@ -179,7 +179,7 @@ ENASNet = Genotype(
DARTS = DARTS_V2
# Search by normal and reduce
DMS_V1 = Genotype(
GDAS_V1 = Genotype(
normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
normal_concat=range(2, 6),
reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
@@ -187,7 +187,7 @@ DMS_V1 = Genotype(
)
# Search by normal and fixing reduction
DMS_F1 = Genotype(
GDAS_F1 = Genotype(
normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
normal_concat=[2, 3, 4, 5],
reduce=None,
@@ -201,3 +201,13 @@ GDAS_CC = Genotype(
reduce=None,
reduce_concat=range(2, 6)
)
model_types = {'DARTS_V1': DARTS_V1,
'DARTS_V2': DARTS_V2,
'NASNet' : NASNet,
'PNASNet' : PNASNet,
'AmoebaNet': AmoebaNet,
'ENASNet' : ENASNet,
'GDAS_V1' : GDAS_V1,
'GDAS_F1' : GDAS_F1,
'GDAS_CC' : GDAS_CC}