update scripts-cluster
@@ -2,7 +2,7 @@ import os, sys, time
from copy import deepcopy
import torch
import torchvision.transforms as transforms

from shutil import copyfile

from utils import print_log, obtain_accuracy, AverageMeter
from utils import time_string, convert_secs2time
@@ -11,6 +11,7 @@ from utils import Cutout
from nas import NetworkCIFAR as Network
from datasets import get_datasets


def obtain_best(accuracies):
  if len(accuracies) == 0: return (0, 0)
  tops = [value for key, value in accuracies.items()]
@@ -18,7 +19,7 @@ def obtain_best(accuracies):
  return s2b[-1]


-def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log):
+def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, pure_evaluate, log):

  train_data, test_data, class_num = get_datasets(dataset, data_path, config.cutout)
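Note: the hunks above show obtain_best only partially; the line that defines s2b falls outside the changed regions. A minimal sketch of the complete helper, assuming s2b is simply the sorted list of per-epoch (top-1, top-5) tuples (the sorting step is an assumption, not the repository's exact code):

def obtain_best(accuracies):
  # `accuracies` maps epoch -> (top-1, top-5) accuracy
  if len(accuracies) == 0: return (0, 0)
  tops = [value for key, value in accuracies.items()]
  s2b = sorted(tops)   # ascending; tuples compare by top-1 first, then top-5
  return s2b[-1]       # best (top-1, top-5) pair seen so far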
@@ -57,10 +58,17 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la
  else:
    raise ValueError('Can not find the schedular type : {:}'.format(config.type))

-  checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset))
-  if os.path.isfile(checkpoint_path):
-    checkpoint = torch.load( checkpoint_path )

+  checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset))
+  checkpoint_best = os.path.join(args.save_path, 'checkpoint-{:}-best.pth'.format(dataset))
+  if pure_evaluate:
+    print_log('-'*20 + 'Pure Evaluation' + '-'*20, log)
+    basemodel.load_state_dict( pure_evaluate )
+    with torch.no_grad():
+      valid_acc1, valid_acc5, valid_los = _train(test_loader, model, criterion, optimizer, 'test', -1, config, args.print_freq, log)
+    return (valid_acc1, valid_acc5)
+  elif os.path.isfile(checkpoint_path):
+    checkpoint = torch.load( checkpoint_path )
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
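Note: the new pure_evaluate argument doubles as a flag and as the weights to evaluate (it is passed straight to basemodel.load_state_dict); anything falsy falls through to the normal checkpoint-resume branch. A self-contained sketch of that calling pattern, with illustrative names (run, net, checkpoint-demo.pth) that are not from the repository:

import os
import torch
import torch.nn as nn

def run(net, pure_evaluate, checkpoint_path):
  # pure_evaluate is either a state_dict (evaluate only) or None (train / resume)
  if pure_evaluate:                       # a state_dict was passed in: evaluate only
    net.load_state_dict(pure_evaluate)
    with torch.no_grad():
      out = net(torch.randn(2, 4))        # stand-in for the real test loop
    return out.shape
  elif os.path.isfile(checkpoint_path):   # no weights given: resume from checkpoint
    ckpt = torch.load(checkpoint_path)
    net.load_state_dict(ckpt['state_dict'])
  return None

net = nn.Linear(4, 3)
print(run(net, net.state_dict(), 'checkpoint-demo.pth'))  # evaluation path
print(run(net, None, 'checkpoint-demo.pth'))              # resume / training path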
@@ -96,12 +104,14 @@ def main_procedure(config, dataset, data_path, args, genotype, init_channels, la
                'accuracies': accuracies},
                checkpoint_path)
    best_acc = obtain_best( accuracies )
    if accuracies[epoch] == best_acc: copyfile(checkpoint_path, checkpoint_best)
    print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  return obtain_best( accuracies )


def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log):
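Note: the last hunk writes a per-epoch checkpoint and copies it to the new checkpoint-{:}-best.pth whenever the latest epoch matches the best accuracy so far. A standalone sketch of that save-and-track-best pattern, using illustrative file names (demo-model.pth, demo-best.pth) rather than the repository's paths:

import torch
from shutil import copyfile

def save_and_track_best(state, accuracies, epoch, checkpoint_path, checkpoint_best):
  # save this epoch's checkpoint; if it is the best so far, copy it to *-best.pth
  torch.save({'epoch': epoch,
              'state_dict': state,
              'accuracies': accuracies}, checkpoint_path)
  best_acc = max(accuracies.values())      # (top-1, top-5) tuples compare lexicographically
  if accuracies[epoch] == best_acc:
    copyfile(checkpoint_path, checkpoint_best)
  return best_acc

accs = {}
net_state = {'w': torch.zeros(1)}
for ep, a1 in enumerate([90.1, 92.3, 91.7]):
  accs[ep] = (a1, 99.0)
  best = save_and_track_best(net_state, accs, ep, 'demo-model.pth', 'demo-best.pth')
  print('epoch {:}: best top-1 so far = {:.2f}'.format(ep, best[0]))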