try to run the graph, commented sampling codes

This commit is contained in:
mhz
2024-06-25 00:09:27 +02:00
parent e04ad5fbe7
commit 82299e5213
5 changed files with 80 additions and 209 deletions

View File

@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
import utils
class Graph_DiT(pl.LightningModule):
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
# def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
def __init__(self, cfg, dataset_infos, visualization_tools):
super().__init__()
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
# self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
self.test_only = cfg.general.test_only
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
self.test_E_logp = SumExceptBatchMetric()
self.test_y_collection = []
self.train_metrics = train_metrics
self.sampling_metrics = sampling_metrics
# self.train_metrics = train_metrics
# self.sampling_metrics = sampling_metrics
self.visualization_tools = visualization_tools
self.max_n_nodes = dataset_infos.max_n_nodes
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
self.val_E_kl.reset()
self.val_X_logp.reset()
self.val_E_logp.reset()
self.sampling_metrics.reset()
# self.sampling_metrics.reset()
self.val_y_collection = []
@torch.no_grad()
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
samples_left_to_generate -= to_generate
chains_left_to_save -= chains_save
print(f"Computing sampling metrics", ' ...')
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
# print(f"Computing sampling metrics", ' ...')
# valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
# print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
current_path = os.getcwd()
result_path = os.path.join(current_path,
f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
self.sampling_metrics.reset()
# self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
# self.sampling_metrics.reset()
def on_test_epoch_start(self) -> None:
print("Starting test...")