update scripts-cluster
This commit is contained in:
@@ -19,7 +19,7 @@ from train_utils_imagenet import main_procedure_imagenet
|
||||
from scheduler import load_config
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser = argparse.ArgumentParser("Train-CNN")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
|
||||
parser.add_argument('--arch', type=str, choices=models.keys(), help='the searched model.')
|
||||
@@ -38,6 +38,7 @@ args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
|
||||
if args.manualSeed is None:
|
||||
args.manualSeed = random.randint(1, 10000)
|
||||
random.seed(args.manualSeed)
|
||||
@@ -72,9 +73,9 @@ def main():
|
||||
# clear GPU cache
|
||||
torch.cuda.empty_cache()
|
||||
if args.dataset == 'imagenet':
|
||||
main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, log)
|
||||
main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
|
||||
else:
|
||||
main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, log)
|
||||
main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, None, log)
|
||||
log.close()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user