add config path
This commit is contained in:
@@ -24,7 +24,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||
|
||||
from nas_201_api import NASBench201API as API
|
||||
self.api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
|
||||
self.api = API(cfg.general.nas_201)
|
||||
|
||||
input_dims = dataset_infos.input_dims
|
||||
output_dims = dataset_infos.output_dims
|
||||
@@ -44,7 +44,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.args.batch_size = 128
|
||||
self.args.GPU = '0'
|
||||
self.args.dataset = 'cifar10-valid'
|
||||
self.args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
|
||||
self.args.api_loc = cfg.general.nas_201
|
||||
self.args.data_loc = '../cifardata/'
|
||||
self.args.seed = 777
|
||||
self.args.init = ''
|
||||
@@ -177,7 +177,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
rewards = []
|
||||
if reward_model == 'swap':
|
||||
import csv
|
||||
with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
|
||||
with open(self.cfg.general.swap_result, 'r') as f:
|
||||
reader = csv.reader(f)
|
||||
header = next(reader)
|
||||
data = [row for row in reader]
|
||||
@@ -345,10 +345,15 @@ class Graph_DiT(pl.LightningModule):
|
||||
num_examples = self.val_y_collection.size(0)
|
||||
batch_y = self.val_y_collection[start_index:start_index + to_generate]
|
||||
all_ys.append(batch_y)
|
||||
samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||
cur_sample, logprobs = self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||
save_final=to_save,
|
||||
keep_chain=chains_save,
|
||||
number_chain_steps=self.number_chain_steps))
|
||||
number_chain_steps=self.number_chain_steps)
|
||||
samples.extend(cur_sample)
|
||||
# samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y,
|
||||
# save_final=to_save,
|
||||
# keep_chain=chains_save,
|
||||
# number_chain_steps=self.number_chain_steps))
|
||||
ident += to_generate
|
||||
start_index += to_generate
|
||||
|
||||
@@ -423,7 +428,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
|
||||
cur_sample, log_probs = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||
samples.append(cur_sample)
|
||||
samples.extend(cur_sample)
|
||||
|
||||
all_ys.append(batch_y)
|
||||
batch_id += to_generate
|
||||
|
Reference in New Issue
Block a user