Upgrade NAS-API to v2.0:
we use an abstract class NASBenchMetaAPI to define the spec of an API; it can be inherited to support different kinds of NAS API, while keep the query interface the same.
This commit is contained in:
@@ -22,7 +22,7 @@ def create_result_count(used_seed: int, dataset: Text, arch_config: Dict[Text, A
|
||||
results: Dict[Text, Any], dataloader_dict: Dict[Text, Any]) -> ResultsCount:
|
||||
xresult = ResultsCount(dataset, results['net_state_dict'], results['train_acc1es'], results['train_losses'],
|
||||
results['param'], results['flop'], arch_config, used_seed, results['total_epoch'], None)
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes':arch_config['class_num']}, None)
|
||||
net_config = dict2config({'name': 'infer.tiny', 'C': arch_config['channel'], 'N': arch_config['num_cells'], 'genotype': CellStructure.str2structure(arch_config['arch_str']), 'num_classes': arch_config['class_num']}, None)
|
||||
network = get_cell_based_tiny_net(net_config)
|
||||
network.load_state_dict(xresult.get_net_param())
|
||||
if 'train_times' in results: # new version
|
||||
@@ -126,7 +126,6 @@ def correct_time_related_info(arch_index: int, arch_info_full: ArchResults, arch
|
||||
arch_info.reset_pseudo_eval_times('ImageNet16-120', None, 'ori-test', eval_per_sample * nums['ImageNet16-120-test'])
|
||||
# arch_info_full.debug_test()
|
||||
# arch_info_less.debug_test()
|
||||
# import pdb; pdb.set_trace()
|
||||
return arch_info_full, arch_info_less
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user