first commit
This commit is contained in:
305
NAS-Bench-201/main_exp/transfer_nag/nag.py
Normal file
305
NAS-Bench-201/main_exp/transfer_nag/nag.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import os
|
||||
import gc
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from nag_utils import mean_confidence_interval
|
||||
from nag_utils import restore_checkpoint
|
||||
from nag_utils import load_graph_config
|
||||
from nag_utils import load_model
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from nas_bench_201 import train_single_model
|
||||
from unnoised_model import MetaSurrogateUnnoisedModel
|
||||
from diffusion.run_lib import generate_archs_meta
|
||||
from diffusion.run_lib import get_sampling_fn_meta
|
||||
from diffusion.run_lib import get_score_model
|
||||
from diffusion.run_lib import get_surrogate
|
||||
from loader import MetaTestDataset
|
||||
from logger import Logger
|
||||
from all_path import *
|
||||
|
||||
|
||||
class NAG:
    """Meta-test driver for transfer NAG on NAS-Bench-201.

    For a target dataset the driver generates architecture strings with the
    diffusion sampler (guided by the meta-surrogate model), ranks them with
    the unnoised meta-surrogate model, and records the ranked architectures
    together with their test accuracies under ``args.exp_name``.
    """

    # Seeds whose ``seed-0<seed>.pth`` result files are read per architecture.
    _SEEDS = (777, 888, 999)
    # Epoch index used in the ``x-test@<epoch>`` result key.
    _FINAL_EPOCH = 199
    # Maximum number of architectures per ``run_multi_proc.py`` launch
    # (hard-coded for available GPUs).
    _MAX_CAP = 5

    def __init__(self, args):
        """Build the models, sampler and logger from the parsed ``args``.

        Args:
            args: namespace providing at least ``data_name``, ``num_class``,
                ``num_sample``, ``graph_data_name``, ``nvt``, ``exp_name`` and
                the sampling options read in ``load_diffusion_model``.
        """
        self.args = args
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        ## Target dataset information
        self.raw_data_path = RAW_DATA_PATH
        self.data_path = DATA_PATH
        self.data_name = args.data_name
        self.num_class = args.num_class
        self.num_sample = args.num_sample

        graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
        self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
        load_model(model=self.meta_surrogate_unnoised_model,
                   ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
        self.meta_surrogate_unnoised_model.to(self.device)

        ## Load pre-trained meta-surrogate model
        self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH

        ## Load score network model (base diffusion model)
        self.load_diffusion_model(args=args)

        ## Check config (needs self.config, which load_diffusion_model sets)
        self.check_config()

        ## Set logger
        self.logger = Logger(
            log_dir=args.exp_name,
            write_textfile=True
        )
        self.logger.update_config(args, is_args=True)
        self.logger.write_str(str(vars(args)))
        self.logger.write_str('-' * 100)

    def check_config(self):
        """
        Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.

        Raises:
            AssertionError: if any compared field differs between the two
                checkpoint configs.
        """
        # Only the 'config' entry is read, so map tensors to the CPU instead
        # of materialising both checkpoints on their saved device.
        scorenet_config = torch.load(
            self.config.scorenet_ckpt_path, map_location='cpu')['config']
        meta_surrogate_config = torch.load(
            self.meta_surrogate_ckpt_path, map_location='cpu')['config']

        compared_fields = (
            ('model', 'sigma_min'),
            ('model', 'sigma_max'),
            ('training', 'sde'),
            ('training', 'continuous'),
            ('data', 'centered'),
            ('data', 'max_node'),
            ('data', 'n_vocab'),
        )
        for section, field in compared_fields:
            scorenet_value = getattr(getattr(scorenet_config, section), field)
            surrogate_value = getattr(getattr(meta_surrogate_config, section), field)
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            if not (scorenet_value == surrogate_value):
                raise AssertionError(
                    f'config mismatch for {section}.{field}: '
                    f'score network has {scorenet_value!r}, '
                    f'meta surrogate has {surrogate_value!r}')

    def forward(self, x, arch):
        """Predict with the unnoised meta-surrogate model.

        Args:
            x: dataset input; moved to ``self.device`` before encoding.
            arch: architecture input passed to ``graph_encode``.

        Returns:
            The output of the model's ``predict`` for the two encodings.
        """
        D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
        G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
        y_pred = self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)
        return y_pred

    def meta_test(self):
        """Run the meta-test on ``self.data_name``, or on every dataset if it is ``'all'``."""
        if self.data_name == 'all':
            for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
                self.meta_test_per_dataset(data_name)
        else:
            self.meta_test_per_dataset(self.data_name)

    def meta_test_per_dataset(self, data_name):
        """Generate, rank and evaluate architectures for one dataset.

        Writes ``architecture.txt``, ``accuracy.txt`` and ``results.pt`` into
        ``self.args.exp_name``.

        NOTE: the output paths do not depend on ``data_name``, so with
        ``data_name == 'all'`` each dataset overwrites the previous files.

        Args:
            data_name: name of the target dataset. Names containing
                ``'cifar'`` read accuracies from the loaded benchmark; other
                datasets are trained (in-process or via ``run_multi_proc.py``).
        """
        ## Load NASBench201
        self.nasbench201 = torch.load(NASBENCH201)
        all_arch_str = np.array(self.nasbench201['arch']['str'])

        ## Load meta-test dataset
        self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)

        ## Set save path
        meta_test_path = os.path.join(META_TEST_PATH, data_name)
        os.makedirs(meta_test_path, exist_ok=True)
        arch_str_path = os.path.join(self.args.exp_name, 'architecture.txt')
        arch_acc_path = os.path.join(self.args.exp_name, 'accuracy.txt')

        # Context managers guarantee both files are flushed and closed, even
        # if generation or training raises part-way through.
        with open(arch_str_path, 'w') as f_arch_str, open(arch_acc_path, 'w') as f_arch_acc:
            ## Generate architectures
            gen_arch_str = self.get_gen_arch_str()
            gen_arch_igraph = self.get_items(
                full_target=self.nasbench201['arch']['igraph'],
                full_source=self.nasbench201['arch']['str'],
                source=gen_arch_str)

            ## Sort with unnoised meta-surrogate model
            y_pred_all = []
            self.meta_surrogate_unnoised_model.eval()
            self.meta_surrogate_unnoised_model.to(self.device)
            with torch.no_grad():
                for arch_igraph in gen_arch_igraph:
                    x, g = self.collect_data(arch_igraph)
                    y_pred = self.forward(x, g)
                    y_pred = torch.mean(y_pred)
                    y_pred_all.append(y_pred.cpu().detach().item())
            sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)

            ## Record the information of the architecture generated in sorted order
            for arch_str in sorted_arch_lst:
                f_arch_str.write(f'{arch_str}\n')
            arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
            arch_str_lst = []
            arch_acc_lst = []

            ## Get the accuracy of the architecture
            if 'cifar' in data_name:
                sorted_acc_lst = self.get_items(
                    full_target=self.nasbench201['test-acc'][data_name],
                    full_source=self.nasbench201['arch']['str'],
                    source=sorted_arch_lst)
                arch_str_lst += sorted_arch_lst
                arch_acc_lst += sorted_acc_lst
                for acc in sorted_acc_lst:
                    # NOTE(review): ':4f' is a width of 4, not 4 decimals;
                    # kept byte-identical so the output format is unchanged.
                    msg = f'Avg {acc:4f} (%)'
                    f_arch_acc.write(msg + '\n')
            elif self.args.multi_proc:
                ## Run multiple processes in parallel
                self._launch_multi_proc_runs(data_name, arch_idx_lst, meta_test_path)
                acc_runs_lst = self._wait_for_results(data_name, arch_idx_lst, meta_test_path)
                for acc_runs in acc_runs_lst:
                    print(np.mean(acc_runs))
                for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
                    self._write_run_accs(f_arch_acc, acc_runs)
                    arch_acc_lst.append(np.mean(acc_runs))
                    arch_str_lst.append(all_arch_str[arch_idx])
            else:
                for arch_idx in arch_idx_lst:
                    # The architecture strings live under ['arch']['str'],
                    # as in every other lookup in this method.
                    acc_runs = self.train_single_arch(
                        data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)
                    self._write_run_accs(f_arch_acc, acc_runs)
                    arch_acc_lst.append(np.mean(acc_runs))
                    arch_str_lst.append(all_arch_str[arch_idx])

        # Save results
        results_path = os.path.join(self.args.exp_name, 'results.pt')
        torch.save({
            'arch_idx_lst': arch_idx_lst,
            'arch_str_lst': arch_str_lst,
            'arch_acc_lst': arch_acc_lst
        }, results_path)
        print(f">>> Save the results at {results_path}...")

    def _launch_multi_proc_runs(self, data_name, arch_idx_lst, meta_test_path):
        """Launch ``run_multi_proc.py`` for architectures without a result directory.

        Architectures whose ``<meta_test_path>/<arch_idx>`` already exists are
        skipped. The rest are submitted in batches of at most ``_MAX_CAP``;
        each launch blocks until the child process exits.
        """
        run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
        last_pos = len(arch_idx_lst) - 1
        pending = []
        for pos, arch_idx in enumerate(arch_idx_lst):
            if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
                pending.append(arch_idx)
            if len(pending) == self._MAX_CAP or pos == last_pos:
                # Nothing left to train in this batch: do not spawn a child
                # with an empty --arch_idx_lst.
                if pending:
                    # One split per (architecture, seed) pair.
                    num_split = len(self._SEEDS) * len(pending)
                    # An argv list (no shell) keeps paths with spaces or
                    # shell metacharacters intact.
                    cmd = [
                        'python', run_file,
                        '--num_split', str(num_split),
                        '--arch_idx_lst', ','.join(str(i) for i in pending),
                        '--meta_test_path', str(meta_test_path),
                        '--data_name', str(data_name),
                        '--raw_data_path', str(self.raw_data_path),
                    ]
                    subprocess.run(cmd)
                pending = []

    def _wait_for_results(self, data_name, arch_idx_lst, meta_test_path, poll_interval=1.0):
        """Poll until every architecture has loadable results for all seeds.

        Returns:
            One list of per-seed test accuracies per entry of ``arch_idx_lst``.

        NOTE: as before, this waits indefinitely if a result file never
        becomes readable; it can now be interrupted with Ctrl-C.
        """
        import time

        while True:
            try:
                return [
                    self._load_test_accs(os.path.join(meta_test_path, str(arch_idx)), data_name)
                    for arch_idx in arch_idx_lst
                ]
            except Exception:
                # Missing or partially written result files: retry after a
                # short pause instead of spinning. KeyboardInterrupt and
                # SystemExit are deliberately not caught.
                time.sleep(poll_interval)

    def _load_test_accs(self, save_dir, data_name):
        """Read the final-epoch test accuracy of every seed from ``save_dir``."""
        test_acc_lst = []
        for seed in self._SEEDS:
            result = torch.load(os.path.join(save_dir, f'seed-0{seed}.pth'))
            test_acc_lst.append(
                result[data_name]['valid_acc1es'][f'x-test@{self._FINAL_EPOCH}'])
        return test_acc_lst

    @staticmethod
    def _write_run_accs(f_arch_acc, acc_runs):
        """Write one line per run plus the mean and confidence interval."""
        for r, acc in enumerate(acc_runs):
            msg = f'run {r+1} {acc:.2f} (%)'
            f_arch_acc.write(msg + '\n')
        m, h = mean_confidence_interval(acc_runs)
        msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
        f_arch_acc.write(msg + '\n')

    def train_single_arch(self, data_name, arch_str, meta_test_path):
        """Train one architecture for every seed and return its test accuracies.

        Args:
            data_name: target dataset name.
            arch_str: architecture string; also used as the result
                sub-directory name under ``meta_test_path``.
            meta_test_path: directory that holds per-architecture results.

        Returns:
            List with one final-epoch test accuracy per seed.
        """
        save_path = os.path.join(meta_test_path, arch_str)
        train_single_model(save_dir=save_path,
                           workers=24,
                           datasets=[data_name],
                           xpaths=[f'{self.raw_data_path}/{data_name}'],
                           splits=[0],
                           use_less=False,
                           seeds=self._SEEDS,
                           model_str=arch_str,
                           arch_config={'channel': 16, 'num_cells': 5})
        return self._load_test_accs(save_path, data_name)

    def sort_arch(self, data_name, y_pred_all, gen_arch_str):
        """Return ``gen_arch_str`` ordered by descending predicted score.

        Args:
            data_name: unused; kept for interface compatibility.
            y_pred_all: 1-D tensor of predictions aligned with ``gen_arch_str``.
            gen_arch_str: list of architecture strings.
        """
        _, sorted_idx = torch.sort(y_pred_all, descending=True)
        return [gen_arch_str[idx] for idx in sorted_idx.tolist()]

    def collect_data_only(self):
        """Return a batch holding one item of the meta-test dataset on ``self.device``."""
        x_batch = [self.test_dataset[0]]
        return torch.stack(x_batch).to(self.device)

    def collect_data(self, arch_igraph):
        """Build a batch of 10 dataset items paired with the same architecture.

        ``self.test_dataset[0]`` is fetched once per entry rather than reused
        (presumably the dataset samples on access -- verify in the loader).

        Returns:
            Tuple of the stacked items on ``self.device`` and a list
            containing ``arch_igraph`` ten times.
        """
        x_batch, g_batch = [], []
        for _ in range(10):
            x_batch.append(self.test_dataset[0])
            g_batch.append(arch_igraph)
        return torch.stack(x_batch).to(self.device), g_batch

    def get_items(self, full_target, full_source, source):
        """Map each element of ``source`` to the ``full_target`` entry at its index in ``full_source``.

        Raises:
            ValueError: if an element of ``source`` is not in ``full_source``.
        """
        return [full_target[full_source.index(item)] for item in source]

    def load_diffusion_model(self, args):
        """Load the sampling config and build the sampler and score model.

        Sets ``self.config``, ``self.sampling_fn``, ``self.sde``,
        ``self.score_model``, ``self.score_ema`` and ``self.score_config``.
        """
        self.config = torch.load('./configs/transfer_nag_config.pt')
        # Fall back to the CPU when CUDA is unavailable, as __init__ does.
        self.config.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.config.data.label_list = ['meta-acc']
        self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        self.config.sampling.classifier_scale = args.classifier_scale
        self.config.eval.batch_size = args.eval_batch_size
        self.config.sampling.predictor = args.predictor
        self.config.sampling.corrector = args.corrector
        self.config.sampling.check_dataname = self.data_name
        self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
        self.score_model, self.score_ema, self.score_config = get_score_model(self.config)

    def get_gen_arch_str(self):
        """Generate architecture strings for the current meta-test dataset.

        Requires ``self.test_dataset`` to be set (see
        ``meta_test_per_dataset``).

        Returns:
            The value returned by ``generate_archs_meta``.
        """
        ## Load meta-surrogate model
        meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
        meta_surrogate_model = get_surrogate(meta_surrogate_config)
        meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
        meta_surrogate_state = restore_checkpoint(
            self.meta_surrogate_ckpt_path,
            meta_surrogate_state,
            device=self.config.device,
            resume=True)

        ## Get dataset embedding, x
        with torch.no_grad():
            x = self.collect_data_only()

        ## Generate architectures
        generated_arch_str = generate_archs_meta(
            config=self.config,
            sampling_fn=self.sampling_fn,
            score_model=self.score_model,
            score_ema=self.score_ema,
            meta_surrogate_model=meta_surrogate_model,
            num_samples=self.args.n_gen_samples,
            args=self.args,
            task=x)

        ## Clean up
        # Drop both references to the surrogate (the state dict holds one too)
        # before collecting.
        meta_surrogate_state = None
        meta_surrogate_model = None
        gc.collect()

        return generated_arch_str
|
Reference in New Issue
Block a user