Update test weights and shapes
This commit is contained in:
@@ -88,11 +88,7 @@ def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Tex
|
||||
|
||||
def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
|
||||
srange: tuple, cover_mode: bool):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(workers)
|
||||
to_evaluate_indexes: tuple, cover_mode: bool):
|
||||
|
||||
log_dir = save_dir / 'logs'
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -103,13 +99,13 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
logger.log('-' * 100)
|
||||
|
||||
logger.log(
|
||||
'Start evaluating range =: {:06d} - {:06d} / {:06d} with cover-mode={:}'.format(srange[0], srange[1], len(nets),
|
||||
cover_mode))
|
||||
'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
|
||||
+'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
|
||||
for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
|
||||
logger.log(
|
||||
'--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
|
||||
logger.log('--->>> optimization config : {:}'.format(opt_config))
|
||||
to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
|
||||
#to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
|
||||
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for i, index in enumerate(to_evaluate_indexes):
|
||||
@@ -158,21 +154,55 @@ def traverse_net(candidates: List[int], N: int):
|
||||
return nets
|
||||
|
||||
|
||||
def filter_indexes(xlist, mode, save_dir, seeds):
|
||||
all_indexes = []
|
||||
for index in xlist:
|
||||
if mode == 'cover':
|
||||
all_indexes.append(index)
|
||||
else:
|
||||
for seed in seeds:
|
||||
temp_path = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
|
||||
if not temp_path.exists():
|
||||
all_indexes.append(index)
|
||||
break
|
||||
print('{:} [FILTER-INDEXES] : there are {:} architectures in total'.format(time_string(), len(all_indexes)))
|
||||
|
||||
SLURM_PROCID, SLURM_NTASKS = 'SLURM_PROCID', 'SLURM_NTASKS'
|
||||
if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm
|
||||
proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS])
|
||||
assert 0 <= proc_id < ntasks, 'invalid proc_id {:} vs ntasks {:}'.format(proc_id, ntasks)
|
||||
scales = [int(float(i)/ntasks*len(all_indexes)) for i in range(ntasks)] + [len(all_indexes)]
|
||||
per_job = []
|
||||
for i in range(ntasks):
|
||||
xs, xe = min(max(scales[i],0), len(all_indexes)-1), min(max(scales[i+1]-1,0), len(all_indexes)-1)
|
||||
per_job.append((xs, xe))
|
||||
for i, srange in enumerate(per_job):
|
||||
print(' -->> {:2d}/{:02d} : {:}'.format(i, ntasks, srange))
|
||||
current_range = per_job[proc_id]
|
||||
all_indexes = [all_indexes[i] for i in range(current_range[0], current_range[1]+1)]
|
||||
# set the device id
|
||||
device = proc_id % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
print(' set the device id = {:}'.format(device))
|
||||
print('{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total'.format(time_string(), len(all_indexes)))
|
||||
return all_indexes
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['new', 'cover'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--candidateC', type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
|
||||
parser.add_argument('--num_layers', type=int, default=5, help='The number of layers in a network.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['new', 'cover'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--candidateC', type=int, nargs='+', default=[8, 16, 24, 32, 40, 48, 56, 64], help='.')
|
||||
parser.add_argument('--num_layers', type=int, default=5, help='The number of layers in a network.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
# use for train the model
|
||||
parser.add_argument('--workers', type=int, default=8, help='The number of data loading workers (default: 2)')
|
||||
parser.add_argument('--srange' , type=str, required=True, help='The range of models to be evaluated')
|
||||
parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
|
||||
parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--hyper', type=str, default='12', choices=['12', '90'], help='The tag for hyper-parameters.')
|
||||
parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated')
|
||||
parser.add_argument('--workers', type=int, default=8, help='The number of data loading workers (default: 2)')
|
||||
parser.add_argument('--srange' , type=str, required=True, help='The range of models to be evaluated')
|
||||
parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
|
||||
parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.')
|
||||
parser.add_argument('--hyper', type=str, default='12', choices=['12', '90'], help='The tag for hyper-parameters.')
|
||||
parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated')
|
||||
args = parser.parse_args()
|
||||
|
||||
nets = traverse_net(args.candidateC, args.num_layers)
|
||||
@@ -182,15 +212,31 @@ if __name__ == '__main__':
|
||||
if not os.path.isfile(opt_config): raise ValueError('{:} is not a file.'.format(opt_config))
|
||||
save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not isinstance(args.srange, str) or len(args.srange.split('-')) != 2:
|
||||
if not isinstance(args.srange, str):
|
||||
raise ValueError('Invalid scheme for {:}'.format(args.srange))
|
||||
srange = args.srange.split('-')
|
||||
srange = (int(srange[0]), int(srange[1]))
|
||||
assert 0 <= srange[0] <= srange[1] < args.check_N, '{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N)
|
||||
srangestr = "".join(args.srange.split())
|
||||
to_evaluate_indexes = set()
|
||||
for srange in srangestr.split(','):
|
||||
srange = srange.split('-')
|
||||
if len(srange) != 2: raise ValueError('invalid srange : {:}'.format(srange))
|
||||
assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange)
|
||||
srange = (int(srange[0]), int(srange[1]))
|
||||
if not (0 <= srange[0] <= srange[1] < args.check_N):
|
||||
raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N))
|
||||
for i in range(srange[0], srange[1]+1):
|
||||
to_evaluate_indexes.add(i)
|
||||
|
||||
assert len(args.seeds) > 0, 'invalid length of seeds args: {:}'.format(args.seeds)
|
||||
assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
|
||||
if not (len(args.datasets) == len(args.xpaths) == len(args.splits)):
|
||||
raise ValueError('invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits)))
|
||||
assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)
|
||||
|
||||
target_indexes = filter_indexes(to_evaluate_indexes, args.mode, save_dir, args.seeds)
|
||||
|
||||
main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config,
|
||||
srange, args.mode == 'cover')
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
|
||||
main(save_dir, args.workers, args.datasets, args.xpaths, args.splits, tuple(args.seeds), nets, opt_config, target_indexes, args.mode == 'cover')
|
||||
|
||||
|
Reference in New Issue
Block a user