update baseline NAS algos

This commit is contained in:
D-X-Y
2019-11-14 13:55:42 +11:00
parent 5c73aeb50b
commit 7843940846
13 changed files with 924 additions and 33 deletions

View File

@@ -1,6 +1,3 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, argparse, collections
from copy import deepcopy
import torch
@@ -167,7 +164,6 @@ def simplify(save_dir, meta_file, basestr, target_dir):
arch_time = AverageMeter()
for idx, arch_index in enumerate(arch_indexes):
checkpoints = list(target_directory.glob('arch-{:}-seed-*.pth'.format(arch_index)))
arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
try:
arch_info = account_one_arch(arch_index, meta_archs[int(arch_index)], checkpoints, datasets, dataloader_dict)
num_seeds[ len(checkpoints) ] += 1
@@ -181,7 +177,7 @@ def simplify(save_dir, meta_file, basestr, target_dir):
torch.save(arch_info.state_dict(), to_save_allarc / '{:}-FULL.pth'.format(arch_index))
#torch.save(arch_info, to_save_allarc / '{:}-FULL.pth'.format(arch_index))
arch_info.clear_params()
torch.save(arch_info, to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
torch.save(arch_info.state_dict(), to_save_allarc / '{:}-SIMPLE.pth'.format(arch_index))
# measure elapsed time
arch_time.update(time.time() - end_time)
end_time = time.time()
@@ -241,7 +237,7 @@ def merge_all(save_dir, meta_file, basestr):
xevalindexs = sub_ckps['evaluated_indexes']
for eval_index in xevalindexs:
assert eval_index not in evaluated_indexes and eval_index not in arch2infos
arch2infos[eval_index] = xarch2infos[eval_index]
arch2infos[eval_index] = xarch2infos[eval_index].state_dict()
evaluated_indexes.add( eval_index )
print ('{:} [{:03d}/{:03d}] merge data from {:} with {:} models.'.format(time_string(), IDX, len(subdir2archs), ckp_path, len(xevalindexs)))
else: