update scripts
This commit is contained in:
@@ -12,6 +12,7 @@ from utils import count_parameters_in_MB
|
||||
from utils import print_FLOPs
|
||||
from utils import Cutout
|
||||
from nas import NetworkImageNet as Network
|
||||
from datasets import get_datasets
|
||||
|
||||
|
||||
def obtain_best(accuracies):
|
||||
@@ -40,30 +41,7 @@ class CrossEntropyLabelSmooth(nn.Module):
|
||||
def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log):
|
||||
|
||||
# training data and testing data
|
||||
traindir = os.path.join(data_path, 'train')
|
||||
validdir = os.path.join(data_path, 'val')
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_data = dset.ImageFolder(
|
||||
traindir,
|
||||
transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.2),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
valid_data = dset.ImageFolder(
|
||||
validdir,
|
||||
transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers)
|
||||
@@ -73,7 +51,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
||||
|
||||
class_num = 1000
|
||||
|
||||
|
||||
print_log('-------------------------------------- main-procedure', log)
|
||||
print_log('config : {:}'.format(config), log)
|
||||
print_log('genotype : {:}'.format(genotype), log)
|
||||
@@ -98,8 +75,7 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
||||
criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda()
|
||||
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
#optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True)
|
||||
optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nestero=True)
|
||||
if config.type == 'cosine':
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs))
|
||||
elif config.type == 'steplr':
|
||||
|
Reference in New Issue
Block a user