try to run the graph, commented sampling codes
This commit is contained in:
@@ -13,9 +13,11 @@ from metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchKL, NLL
|
||||
import utils
|
||||
|
||||
class Graph_DiT(pl.LightningModule):
|
||||
def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
# def __init__(self, cfg, dataset_infos, train_metrics, sampling_metrics, visualization_tools):
|
||||
def __init__(self, cfg, dataset_infos, visualization_tools):
|
||||
|
||||
super().__init__()
|
||||
self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
# self.save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])
|
||||
self.test_only = cfg.general.test_only
|
||||
self.guidance_target = getattr(cfg.dataset, 'guidance_target', None)
|
||||
|
||||
@@ -55,8 +57,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.test_E_logp = SumExceptBatchMetric()
|
||||
self.test_y_collection = []
|
||||
|
||||
self.train_metrics = train_metrics
|
||||
self.sampling_metrics = sampling_metrics
|
||||
# self.train_metrics = train_metrics
|
||||
# self.sampling_metrics = sampling_metrics
|
||||
|
||||
self.visualization_tools = visualization_tools
|
||||
self.max_n_nodes = dataset_infos.max_n_nodes
|
||||
@@ -171,7 +173,7 @@ class Graph_DiT(pl.LightningModule):
|
||||
self.val_E_kl.reset()
|
||||
self.val_X_logp.reset()
|
||||
self.val_E_logp.reset()
|
||||
self.sampling_metrics.reset()
|
||||
# self.sampling_metrics.reset()
|
||||
self.val_y_collection = []
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -239,14 +241,15 @@ class Graph_DiT(pl.LightningModule):
|
||||
samples_left_to_generate -= to_generate
|
||||
chains_left_to_save -= chains_save
|
||||
|
||||
print(f"Computing sampling metrics", ' ...')
|
||||
valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
|
||||
print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
|
||||
# print(f"Computing sampling metrics", ' ...')
|
||||
# valid_smiles = self.sampling_metrics(samples, all_ys, self.name, self.current_epoch, val_counter=-1, test=False)
|
||||
# print(f'Done. Sampling took {time.time() - start:.2f} seconds\n')
|
||||
|
||||
current_path = os.getcwd()
|
||||
result_path = os.path.join(current_path,
|
||||
f'graphs/{self.name}/epoch{self.current_epoch}_b0/')
|
||||
self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
|
||||
self.sampling_metrics.reset()
|
||||
# self.visualization_tools.visualize_by_smiles(result_path, valid_smiles, self.cfg.general.samples_to_save)
|
||||
# self.sampling_metrics.reset()
|
||||
|
||||
def on_test_epoch_start(self) -> None:
|
||||
print("Starting test...")
|
||||
|
Reference in New Issue
Block a user