Update NATS-Bench (tss version 0.8)

This commit is contained in:
D-X-Y
2020-08-28 08:31:53 +00:00
parent c68458f66c
commit 2c86d6aa67
5 changed files with 172 additions and 111 deletions

View File

@@ -27,6 +27,7 @@ from procedures import bench_evaluate_for_seed
from procedures import get_machine_info
from datasets import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from utils import split_str2indexes
def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text],
@@ -107,7 +108,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
logger.log('xargs : seeds = {:}'.format(seeds))
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
logger.log('-' * 100)
logger.log(
'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
+'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
@@ -115,7 +115,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
logger.log(
'--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
logger.log('--->>> optimization config : {:}'.format(opt_config))
#to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
start_time, epoch_time = time.time(), AverageMeter()
for i, index in enumerate(to_evaluate_indexes):
@@ -136,10 +135,12 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
has_continue = True
continue
results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
results = evaluate_all_datasets(channelstr,
datasets, xpaths, splits, opt_config, seed,
workers, logger)
torch.save(results, to_save_name)
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i,
len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i,
len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
# measure elapsed time
if not has_continue: epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -224,20 +225,7 @@ if __name__ == '__main__':
raise ValueError('{:} is not a file.'.format(opt_config))
save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
save_dir.mkdir(parents=True, exist_ok=True)
if not isinstance(args.srange, str):
raise ValueError('Invalid scheme for {:}'.format(args.srange))
srangestr = "".join(args.srange.split())
to_evaluate_indexes = set()
for srange in srangestr.split(','):
srange = srange.split('-')
if len(srange) != 2:
raise ValueError('invalid srange : {:}'.format(srange))
assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange)
srange = (int(srange[0]), int(srange[1]))
if not (0 <= srange[0] <= srange[1] < args.check_N):
raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N))
for i in range(srange[0], srange[1]+1):
to_evaluate_indexes.add(i)
to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5)
if not len(args.seeds):
raise ValueError('invalid length of seeds args: {:}'.format(args.seeds))