Update NATS-Bench (tss version 0.8)
This commit is contained in:
@@ -27,6 +27,7 @@ from procedures import bench_evaluate_for_seed
|
||||
from procedures import get_machine_info
|
||||
from datasets import get_datasets
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from utils import split_str2indexes
|
||||
|
||||
|
||||
def evaluate_all_datasets(channels: Text, datasets: List[Text], xpaths: List[Text],
|
||||
@@ -107,7 +108,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
logger.log('xargs : seeds = {:}'.format(seeds))
|
||||
logger.log('xargs : cover_mode = {:}'.format(cover_mode))
|
||||
logger.log('-' * 100)
|
||||
|
||||
logger.log(
|
||||
'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
|
||||
+'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
|
||||
@@ -115,7 +115,6 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
logger.log(
|
||||
'--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
|
||||
logger.log('--->>> optimization config : {:}'.format(opt_config))
|
||||
#to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
|
||||
|
||||
start_time, epoch_time = time.time(), AverageMeter()
|
||||
for i, index in enumerate(to_evaluate_indexes):
|
||||
@@ -136,10 +135,12 @@ def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
|
||||
logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
|
||||
has_continue = True
|
||||
continue
|
||||
results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
|
||||
results = evaluate_all_datasets(channelstr,
|
||||
datasets, xpaths, splits, opt_config, seed,
|
||||
workers, logger)
|
||||
torch.save(results, to_save_name)
|
||||
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i,
|
||||
len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
|
||||
logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(time_string(), i,
|
||||
len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
|
||||
# measure elapsed time
|
||||
if not has_continue: epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
@@ -224,20 +225,7 @@ if __name__ == '__main__':
|
||||
raise ValueError('{:} is not a file.'.format(opt_config))
|
||||
save_dir = Path(args.save_dir) / 'raw-data-{:}'.format(args.hyper)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not isinstance(args.srange, str):
|
||||
raise ValueError('Invalid scheme for {:}'.format(args.srange))
|
||||
srangestr = "".join(args.srange.split())
|
||||
to_evaluate_indexes = set()
|
||||
for srange in srangestr.split(','):
|
||||
srange = srange.split('-')
|
||||
if len(srange) != 2:
|
||||
raise ValueError('invalid srange : {:}'.format(srange))
|
||||
assert len(srange[0]) == len(srange[1]) == 5, 'invalid srange : {:}'.format(srange)
|
||||
srange = (int(srange[0]), int(srange[1]))
|
||||
if not (0 <= srange[0] <= srange[1] < args.check_N):
|
||||
raise ValueError('{:} vs {:} vs {:}'.format(srange[0], srange[1], args.check_N))
|
||||
for i in range(srange[0], srange[1]+1):
|
||||
to_evaluate_indexes.add(i)
|
||||
to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5)
|
||||
|
||||
if not len(args.seeds):
|
||||
raise ValueError('invalid length of seeds args: {:}'.format(args.seeds))
|
||||
|
Reference in New Issue
Block a user