Compare commits: 82299e5213 ... a7f7010da7 (5 commits)

Commits (newest first): a7f7010da7, 14186fa97f, a222c514d9, 062a27b83f, 0c7c525680
@@ -127,4 +127,19 @@ class AbstractDatasetInfos:
         print('input dims')
         print(self.input_dims)
         print('output dims')
+        print(self.output_dims)
+
+    def compute_graph_input_output_dims(self, datamodule):
+        example_batch = datamodule.example_batch()
+        example_batch_x = torch.nn.functional.one_hot(example_batch.x, num_classes=8).float()[:, self.active_index]
+        example_batch_edge_attr = torch.nn.functional.one_hot(example_batch.edge_attr, num_classes=2).float()
+
+        self.input_dims = {'X': example_batch_x.size(1),
+                           'E': example_batch_edge_attr.size(1),
+                           'y': example_batch['y'].size(1)}
+        self.output_dims = {'X': example_batch_x.size(1),
+                            'E': example_batch_edge_attr.size(1),
+                            'y': example_batch['y'].size(1)}
+        print('input dims')
+        print(self.input_dims)
+        print('output dims')
         print(self.output_dims)
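
Note: the new compute_graph_input_output_dims derives the model's input/output sizes from one example batch, one-hot encoding node labels over 8 classes (restricted to self.active_index) and edge labels over 2 classes. A minimal standalone sketch of that dimension logic; the batch sizes and active_index values below are illustrative assumptions, not taken from the repository:

import torch
import torch.nn.functional as F

# Hypothetical example batch: 6 node labels in [0, 8) and 15 edge labels in {0, 1}.
x = torch.randint(0, 8, (6,))
edge_attr = torch.randint(0, 2, (15,))
active_index = [0, 1, 2, 3, 4, 5, 6]  # assumed subset of the 8 node classes

example_x = F.one_hot(x, num_classes=8).float()[:, active_index]
example_e = F.one_hot(edge_attr, num_classes=2).float()

# input_dims / output_dims in the new method are just these trailing sizes.
print(example_x.size(1), example_e.size(1))  # 7 2
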
@@ -50,12 +50,12 @@ class DataModule(AbstractDataModule):
 
     def prepare_data(self) -> None:
         target = getattr(self.cfg.dataset, 'guidance_target', None)
-        print("target", target)
+        print("target", target) # nasbench-201
         # try:
         #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
         # except NameError:
         #     base_path = pathlib.Path(os.getcwd()).parent[2]
-        base_path = '/home/stud/hanzhang/Graph-Dit'
+        base_path = '/home/stud/hanzhang/nasbenchDiT'
         root_path = os.path.join(base_path, self.datadir)
         self.root_path = root_path
 
@@ -68,13 +68,16 @@ class DataModule(AbstractDataModule):
         # Dataset has target property, root path, and transform
         source = './NAS-Bench-201-v1_1-096897.pth'
         dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
+        self.dataset = dataset
+        self.api = dataset.api
 
         # if len(self.task.split('-')) == 2:
         #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
         # else:
         train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
 
-        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
+        self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
+            train_index, val_index, test_index, unlabeled_index)
         train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
         if len(unlabeled_index) > 0:
             train_index = torch.cat([train_index, unlabeled_index], dim=0)
@@ -175,6 +178,27 @@ class DataModule(AbstractDataModule):
             smiles = Chem.MolToSmiles(mol)
             return smiles
 
+    def get_train_graphs(self):
+        train_graphs = []
+        test_graphs = []
+        for graph in self.train_dataset:
+            train_graphs.append(graph)
+        for graph in self.test_dataset:
+            test_graphs.append(graph)
+        return train_graphs, test_graphs
+
+
+    # def get_train_smiles(self):
+    #     filename = f'{self.task}.csv.gz'
+    #     df = pd.read_csv(f'{self.root_path}/raw/{filename}')
+    #     df_test = df.iloc[self.test_index]
+    #     df = df.iloc[self.train_index]
+    #     smiles_list = df['smiles'].tolist()
+    #     smiles_list_test = df_test['smiles'].tolist()
+    #     smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list]
+    #     smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in smiles_list_test]
+    #     return smiles_list, smiles_list_test
+
     def get_train_smiles(self):
         train_smiles = []
         test_smiles = []
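
Note: get_train_graphs is the graph counterpart of the now-commented SMILES accessor; the two loops simply materialize the train and test datasets into lists (equivalently, list(self.train_dataset) and list(self.test_dataset)). A hedged usage sketch, assuming a composed Hydra cfg and the repo's dataset module, mirroring the call added in main.py further down:

# Sketch only; requires the repo's dataset module and a valid cfg.
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
train_graphs, reference_graphs = datamodule.get_train_graphs()
print(len(train_graphs), len(reference_graphs))
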
@@ -477,14 +501,17 @@ def graphs_to_json(graphs, filename):
 class Dataset(InMemoryDataset):
     def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
         self.target_prop = target_prop
-        source = '/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
+        source = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
         self.source = source
+        super().__init__(root, transform, pre_transform, pre_filter)
+        print(self.processed_paths[0]) #/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth.pt
         self.api = API(source) # Initialize NAS-Bench-201 API
         print('API loaded')
-        super().__init__(root, transform, pre_transform, pre_filter)
         print('Dataset initialized')
-        print(self.processed_paths[0])
         self.data, self.slices = torch.load(self.processed_paths[0])
+        self.data.edge_attr = self.data.edge_attr.squeeze()
+        self.data.idx = torch.arange(len(self.data.y))
+        print(f"self.data={self.data}, self.slices={self.slices}")
 
     @property
     def raw_file_names(self):
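
Note: moving super().__init__ ahead of the API load matters because InMemoryDataset.__init__ is what runs process() (when the processed file is missing) and makes processed_paths meaningful; only then is torch.load on processed_paths[0] safe. A rough, hedged sketch of that PyTorch Geometric pattern with a toy dataset (the class name and toy graph are mine, not the repo's):

import torch
from torch_geometric.data import Data, InMemoryDataset

class TinyDataset(InMemoryDataset):
    # Sketch: processed_paths is only usable after super().__init__ has run.
    def __init__(self, root):
        super().__init__(root)  # triggers process() if data.pt is missing
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # One toy graph, collated and saved where __init__ will load it from.
        g = Data(x=torch.zeros(4, 3),
                 edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
                 y=torch.zeros(1, 1))
        torch.save(self.collate([g]), self.processed_paths[0])

The added edge_attr.squeeze() and idx fields then just normalize what the collated file contains.
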
@@ -676,7 +703,7 @@ def create_adj_matrix_and_ops(nodes, edges):
             adj_matrix[src][dst] = 1
     return adj_matrix, nodes
 class DataInfos(AbstractDatasetInfos):
-    def __init__(self, datamodule, cfg):
+    def __init__(self, datamodule, cfg, dataset):
         tasktype_dict = {
             'hiv_b': 'classification',
             'bace_b': 'classification',
@@ -689,6 +716,7 @@ class DataInfos(AbstractDatasetInfos):
         self.task = task_name
         self.task_type = tasktype_dict.get(task_name, "regression")
         self.ensure_connected = cfg.model.ensure_connected
+        self.api = dataset.api
 
         datadir = cfg.dataset.datadir
 
@@ -699,9 +727,9 @@ class DataInfos(AbstractDatasetInfos):
         length = 15625
         ops_type = {}
         len_ops = set()
-        api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
+        # api = API('/home/stud/hanzhang/Graph-DiT/graph_dit/NAS-Bench-201-v1_1-096897.pth')
         for i in range(length):
-            arch_info = api.query_meta_info_by_index(i)
+            arch_info = self.api.query_meta_info_by_index(i)
             nodes, edges = parse_architecture_string(arch_info.arch_str)
             adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
             if i < 5:
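
Note: with the duplicate API(...) load commented out, the loop reuses the handle created by Dataset instead of loading the checkpoint a second time. For context, NAS-Bench-201 architecture strings encode a cell as |op~input| segments joined by '+'. A hedged illustration of that format; parse_architecture_string itself is not shown in this diff, so the parsing below is only an assumption of what it might produce:

# Hypothetical NAS-Bench-201 arch string: |op~source_node|, '+' separates node groups.
arch_str = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'

# Rough idea only: one (op, source_node) pair per segment.
segments = [seg for group in arch_str.split('+') for seg in group.strip('|').split('|')]
ops = [seg.split('~')[0] for seg in segments]
srcs = [int(seg.split('~')[1]) for seg in segments]
print(ops)   # ['nor_conv_3x3', 'none', 'nor_conv_3x3', 'skip_connect', 'none', 'nor_conv_3x3']
print(srcs)  # [0, 0, 1, 0, 1, 2]
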
@@ -716,7 +744,6 @@ class DataInfos(AbstractDatasetInfos):
             graphs.append((adj_matrix, ops))
-
         meta_dict = graphs_to_json(graphs, 'nasbench-201')
 
         self.base_path = base_path
         self.active_atoms = meta_dict['active_atoms']
         self.max_n_nodes = meta_dict['max_node']
@@ -930,4 +957,4 @@ def compute_meta(root, source_name, train_index, test_index):
 
 
 if __name__ == "__main__":
-    pass
+    dataset = Dataset(source='nasbench', root='/home/stud/hanzhang/nasbenchDiT/graph-dit', target_prop='Class', transform=None)
@@ -78,16 +78,20 @@ def main(cfg: DictConfig):
 
     datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
-    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
+    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
     # train_smiles, reference_smiles = datamodule.get_train_smiles()
+    train_graphs, reference_graphs = datamodule.get_train_graphs()
 
     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
-    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
 
     # sampling_metrics = SamplingMolecularMetrics(
     #     dataset_infos, train_smiles, reference_smiles
     # )
+    sampling_metrics = SamplingGraphMetrics(
+        dataset_infos, train_graphs, reference_graphs
+    )
     visualization_tools = MolecularVisualization(dataset_infos)
 
     model_kwargs = {
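
Note: passing dataset=datamodule.dataset into DataInfos means prepare_data() must run first, since that is where self.dataset (and its NAS-Bench-201 API handle) gets created; the sampling metrics likewise switch from SMILES lists to graph lists. A condensed sketch of the resulting setup order, using only names that appear in this diff:

datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()  # creates datamodule.dataset and its API handle
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
train_graphs, reference_graphs = datamodule.get_train_graphs()
dataset_infos.compute_input_output_dims(datamodule=datamodule)
sampling_metrics = SamplingGraphMetrics(dataset_infos, train_graphs, reference_graphs)
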
@@ -135,5 +139,16 @@ def main(cfg: DictConfig):
     else:
         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)
 
+@hydra.main(
+    version_base="1.1", config_path="../configs", config_name="config"
+)
+def test(cfg: DictConfig):
+    datamodule = dataset.DataModule(cfg)
+    datamodule.prepare_data()
+    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
+    train_graphs, reference_graphs = datamodule.get_train_graphs()
+
+    dataset_infos.compute_input_output_dims(datamodule=datamodule)
+
 if __name__ == "__main__":
-    main()
+    test()
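
Note: test() is a second Hydra entry point; since __main__ now calls test() instead of main(), a default run only builds the datamodule and dataset infos and never trains. A minimal hedged sketch of the @hydra.main pattern used here (the printed key is grounded in the cfg.dataset.datadir access elsewhere in this diff, but the body is illustrative only):

import hydra
from omegaconf import DictConfig

@hydra.main(version_base="1.1", config_path="../configs", config_name="config")
def test(cfg: DictConfig):
    # Hydra composes ../configs/config.yaml plus any CLI overrides into cfg.
    print(cfg.dataset.datadir)

if __name__ == "__main__":
    test()
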
graph_dit/workingdoc.md (new file, 0 lines)