update scripts

This commit is contained in:
Xuanyi Dong
2019-03-30 02:10:20 +08:00
parent 3734384b68
commit c8dddf9cf9
9 changed files with 61 additions and 23 deletions

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import multiprocessing
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
print ('lib-dir : {:}'.format(lib_dir))
@@ -29,7 +30,7 @@ parser.add_argument('--config_path', type=str, help='the training configur
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--threads', type=int, default=10, help='the number of threads')
parser.add_argument('--threads', type=int, default=4, help='the number of threads')
args = parser.parse_args()
assert torch.cuda.is_available(), 'torch.cuda is not available'
@@ -50,7 +51,7 @@ def main():
if not os.path.isdir(args.save_path):
os.makedirs(args.save_path)
log = open(os.path.join(args.save_path, 'log-seed-{:}-{:}.txt'.format(args.manualSeed, time_file_str())), 'w')
print_log('save path : {}'.format(args.save_path), log)
print_log('save path : {:}'.format(args.save_path), log)
state = {k: v for k, v in args._get_kwargs()}
print_log(state, log)
print_log("Random Seed: {}".format(args.manualSeed), log)
@@ -59,6 +60,7 @@ def main():
print_log("CUDA version : {}".format(torch.version.cuda), log)
print_log("cuDNN version : {}".format(cudnn.version()), log)
print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
print_log("Num of CPUs : {}".format(multiprocessing.cpu_count()), log)
config = load_config( args.config_path )
genotype = Networks[ args.arch ]