update scripts-cluster
This commit is contained in:
49
exps-cnn/evaluate.py
Normal file
49
exps-cnn/evaluate.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from utils import AverageMeter, time_string, convert_secs2time
|
||||
from utils import print_log, obtain_accuracy
|
||||
from utils import Cutout, count_parameters_in_MB
|
||||
from nas import model_types as models
|
||||
from train_utils import main_procedure
|
||||
from train_utils_imagenet import main_procedure_imagenet
|
||||
from scheduler import load_config
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser("Evaluate-CNN")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset.')
|
||||
parser.add_argument('--checkpoint', type=str, help='Choose between Cifar10/100 and ImageNet.')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
assert os.path.isdir( args.data_path ), 'invalid data-path : {:}'.format(args.data_path)
|
||||
assert os.path.isfile( args.checkpoint ), 'invalid checkpoint : {:}'.format(args.checkpoint)
|
||||
|
||||
checkpoint = torch.load( args.checkpoint )
|
||||
xargs = checkpoint['args']
|
||||
config = load_config(xargs.model_config)
|
||||
genotype = models[xargs.arch]
|
||||
|
||||
# clear GPU cache
|
||||
torch.cuda.empty_cache()
|
||||
if xargs.dataset == 'imagenet':
|
||||
main_procedure_imagenet(config, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None)
|
||||
else:
|
||||
main_procedure(config, xargs.dataset, args.data_path, xargs, genotype, xargs.init_channels, xargs.layers, checkpoint['state_dict'], None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -19,7 +19,7 @@ from train_utils_imagenet import main_procedure_imagenet
|
||||
from scheduler import load_config
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser = argparse.ArgumentParser("Train-CNN")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
|
||||
parser.add_argument('--arch', type=str, choices=models.keys(), help='the searched model.')
|
||||
@@ -38,6 +38,7 @@ args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
|
||||
if args.manualSeed is None:
|
||||
args.manualSeed = random.randint(1, 10000)
|
||||
random.seed(args.manualSeed)
|
||||
@@ -72,9 +73,9 @@ def main():
|
||||
# clear GPU cache
|
||||
torch.cuda.empty_cache()
|
||||
if args.dataset == 'imagenet':
|
||||
main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, log)
|
||||
main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
|
||||
else:
|
||||
main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, log)
|
||||
main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
|
||||
log.close()
|
||||
|
||||
|
||||
|
@@ -2,7 +2,7 @@ import os, sys, time
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from shutil import copyfile
|
||||
|
||||
from utils import print_log, obtain_accuracy, AverageMeter
|
||||
from utils import time_string, convert_secs2time
|
||||
@@ -11,6 +11,7 @@ from utils import Cutout
|
||||
from nas import NetworkCIFAR as Network
|
||||
from datasets import get_datasets
|
||||
|
||||
|
||||
def obtain_best(accuracies):
|
||||
if len(accuracies) == 0: return (0, 0)
|
||||
tops = [value for key, value in accuracies.items()]
|
||||
@@ -18,7 +19,7 @@ def obtain_best(accuracies):
|
||||
return s2b[-1]
|
||||
|
||||
|
||||
def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log):
|
||||
def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, pure_evaluate, log):
|
||||
|
||||
train_data, test_data, class_num = get_datasets(dataset, data_path, config.cutout)
|
||||
|
||||
@@ -57,10 +58,17 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la
|
||||
else:
|
||||
raise ValueError('Can not find the schedular type : {:}'.format(config.type))
|
||||
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset))
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load( checkpoint_path )
|
||||
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset))
|
||||
checkpoint_best = os.path.join(args.save_path, 'checkpoint-{:}-best.pth'.format(dataset))
|
||||
if pure_evaluate:
|
||||
print_log('-'*20 + 'Pure Evaluation' + '-'*20, log)
|
||||
basemodel.load_state_dict( pure_evaluate )
|
||||
with torch.no_grad():
|
||||
valid_acc1, valid_acc5, valid_los = _train(test_loader, model, criterion, optimizer, 'test', -1, config, args.print_freq, log)
|
||||
return (valid_acc1, valid_acc5)
|
||||
elif os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load( checkpoint_path )
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
@@ -96,12 +104,14 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la
|
||||
'accuracies': accuracies},
|
||||
checkpoint_path)
|
||||
best_acc = obtain_best( accuracies )
|
||||
if accuracies[epoch] == best_acc: copyfile(checkpoint_path, checkpoint_best)
|
||||
print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log)
|
||||
print_log('----> Save into {:}'.format(checkpoint_path), log)
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
return obtain_best( accuracies )
|
||||
|
||||
|
||||
def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log):
|
||||
|
@@ -3,7 +3,7 @@ from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from shutil import copyfile
|
||||
|
||||
from utils import print_log, obtain_accuracy, AverageMeter
|
||||
from utils import time_string, convert_secs2time
|
||||
@@ -37,7 +37,7 @@ class CrossEntropyLabelSmooth(nn.Module):
|
||||
return loss
|
||||
|
||||
|
||||
def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log):
|
||||
def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, pure_evaluate, log):
|
||||
|
||||
# training data and testing data
|
||||
train_data, valid_data, class_num = get_datasets('imagenet-1k', data_path, -1)
|
||||
@@ -48,8 +48,6 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
valid_data, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)
|
||||
|
||||
class_num = 1000
|
||||
|
||||
print_log('-------------------------------------- main-procedure', log)
|
||||
print_log('config : {:}'.format(config), log)
|
||||
print_log('genotype : {:}'.format(genotype), log)
|
||||
@@ -84,9 +82,16 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
||||
|
||||
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-imagenet-model.pth')
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load( checkpoint_path )
|
||||
checkpoint_best = os.path.join(args.save_path, 'checkpoint-imagenet-best.pth')
|
||||
|
||||
if pure_evaluate:
|
||||
print_log('-'*20 + 'Pure Evaluation' + '-'*20, log)
|
||||
basemodel.load_state_dict( pure_evaluate )
|
||||
with torch.no_grad():
|
||||
valid_acc1, valid_acc5, valid_los = _train(valid_queue, model, criterion, None, 'test' , -1, config, args.print_freq, log)
|
||||
return (valid_acc1, valid_acc5)
|
||||
elif os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load( checkpoint_path )
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
@@ -122,12 +127,14 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
||||
'accuracies': accuracies},
|
||||
checkpoint_path)
|
||||
best_acc = obtain_best( accuracies )
|
||||
if accuracies[epoch] == best_acc: copyfile(checkpoint_path, checkpoint_best)
|
||||
print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log)
|
||||
print_log('----> Save into {:}'.format(checkpoint_path), log)
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
return obtain_best( accuracies )
|
||||
|
||||
|
||||
def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log):
|
||||
|
Reference in New Issue
Block a user