first commit
This commit is contained in:
77
NAS-Bench-201/analysis/arch_metrics.py
Normal file
77
NAS-Bench-201/analysis/arch_metrics.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SamplingArchMetrics(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
train_ds,
|
||||
exp_name,):
|
||||
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.train_ds = train_ds
|
||||
self.train_arch_str_list = train_ds.arch_str_list_
|
||||
|
||||
|
||||
def forward(self,
|
||||
arch_list: list,
|
||||
this_sample_dir,
|
||||
check_dataname='cifar10'):
|
||||
|
||||
arch_metrics, all_arch_str = compute_arch_metrics(arch_list=arch_list,
|
||||
train_arch_str_list=self.train_arch_str_list,
|
||||
train_ds=self.train_ds,
|
||||
check_dataname=check_dataname)
|
||||
|
||||
valid_unique_arch = arch_metrics[1] # arch_str
|
||||
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
|
||||
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
|
||||
for i in range(len(valid_unique_arch)):
|
||||
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
|
||||
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
|
||||
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
|
||||
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
|
||||
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
|
||||
textfile.writelines(valid_unique_arch)
|
||||
textfile.close()
|
||||
|
||||
return arch_metrics
|
||||
|
||||
|
||||
class SamplingArchMetricsMeta(nn.Module):
|
||||
def __init__(self,
|
||||
config,
|
||||
train_ds,
|
||||
exp_name):
|
||||
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.train_ds = train_ds
|
||||
self.search_space = config.data.name
|
||||
self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]
|
||||
|
||||
|
||||
def forward(self,
|
||||
arch_list: list,
|
||||
this_sample_dir,
|
||||
check_dataname='cifar10'):
|
||||
|
||||
arch_metrics = compute_arch_metrics_meta(arch_list=arch_list,
|
||||
train_arch_str_list=self.train_arch_str_list,
|
||||
train_ds=self.train_ds,
|
||||
check_dataname=check_dataname)
|
||||
|
||||
valid_unique_arch = arch_metrics[1] # arch_str
|
||||
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
|
||||
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
|
||||
for i in range(len(valid_unique_arch)):
|
||||
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
|
||||
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
|
||||
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
|
||||
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
|
||||
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
|
||||
textfile.writelines(valid_unique_arch)
|
||||
textfile.close()
|
||||
|
||||
return arch_metrics
|
Reference in New Issue
Block a user