update scripts
This commit is contained in:
@@ -10,6 +10,7 @@ from utils import time_string, convert_secs2time
|
||||
from utils import count_parameters_in_MB
|
||||
from utils import Cutout
|
||||
from nas import NetworkCIFAR as Network
|
||||
from datasets import get_datasets
|
||||
|
||||
def obtain_best(accuracies):
|
||||
if len(accuracies) == 0: return (0, 0)
|
||||
@@ -17,38 +18,10 @@ def obtain_best(accuracies):
|
||||
s2b = sorted( tops )
|
||||
return s2b[-1]
|
||||
|
||||
|
||||
def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log):
|
||||
|
||||
# Mean + Std
|
||||
if dataset == 'cifar10':
|
||||
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||
elif dataset == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(dataset))
|
||||
# Dataset Transformation
|
||||
if dataset == 'cifar10' or dataset == 'cifar100':
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std)]
|
||||
if config.cutout > 0 : lists += [Cutout(config.cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(dataset))
|
||||
# Dataset Defination
|
||||
if dataset == 'cifar10':
|
||||
train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True)
|
||||
class_num = 10
|
||||
elif dataset == 'cifar100':
|
||||
train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True)
|
||||
class_num = 100
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(dataset))
|
||||
|
||||
train_data, test_data, class_num = get_datasets(dataset, data_path, args.cutout)
|
||||
|
||||
print_log('-------------------------------------- main-procedure', log)
|
||||
print_log('config : {:}'.format(config), log)
|
||||
|
Reference in New Issue
Block a user