try to run the graph, commented sampling codes

This commit is contained in:
mhz
2024-06-25 00:09:27 +02:00
parent e04ad5fbe7
commit 82299e5213
5 changed files with 80 additions and 209 deletions

View File

@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
# Fetch path to this file to get base path
current_path = os.path.dirname(os.path.realpath(__file__))
root_dir = current_path.split("outputs")[0]
resume_path = os.path.join(root_dir, cfg.general.resume)
if cfg.model.type == "discrete":
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
datamodule = dataset.DataModule(cfg)
datamodule.prepare_data()
dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
train_smiles, reference_smiles = datamodule.get_train_smiles()
# train_smiles, reference_smiles = datamodule.get_train_smiles()
# get input output dimensions
dataset_infos.compute_input_output_dims(datamodule=datamodule)
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
# train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
sampling_metrics = SamplingMolecularMetrics(
dataset_infos, train_smiles, reference_smiles
)
# sampling_metrics = SamplingMolecularMetrics(
# dataset_infos, train_smiles, reference_smiles
# )
visualization_tools = MolecularVisualization(dataset_infos)
model_kwargs = {
"dataset_infos": dataset_infos,
"train_metrics": train_metrics,
"sampling_metrics": sampling_metrics,
# "train_metrics": train_metrics,
# "sampling_metrics": sampling_metrics,
"visualization_tools": visualization_tools,
}
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
model = Graph_DiT(cfg=cfg, **model_kwargs)
trainer = Trainer(
gradient_clip_val=cfg.train.clip_grad,
accelerator="gpu"
if torch.cuda.is_available() and cfg.general.gpus > 0
else "cpu",
# accelerator="gpu"
# if torch.cuda.is_available() and cfg.general.gpus > 0
# else "cpu",
accelerator="cpu",
devices=cfg.general.gpus
if torch.cuda.is_available() and cfg.general.gpus > 0
else None,