try to run the graph, commented out the sampling code
@@ -50,7 +50,6 @@ def get_resume_adaptive(cfg, model_kwargs):
     # Fetch path to this file to get base path
     current_path = os.path.dirname(os.path.realpath(__file__))
     root_dir = current_path.split("outputs")[0]

     resume_path = os.path.join(root_dir, cfg.general.resume)

     if cfg.model.type == "discrete":
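The context above resolves the resume checkpoint by splitting the current run directory at "outputs" and joining cfg.general.resume back onto the project root. A minimal illustration of that path logic, using made-up directory names (the paths below are assumptions for illustration, not taken from the repository):

import os

# Hypothetical run directory in a Hydra-style outputs/ tree (assumption).
current_path = "/home/user/Graph-DiT/outputs/2024-01-01/12-00-00"
root_dir = current_path.split("outputs")[0]  # "/home/user/Graph-DiT/"

# cfg.general.resume is joined onto the project root; a relative checkpoint
# path like this one (made up for the example) would resolve as follows.
resume = "outputs/2024-01-01/12-00-00/checkpoints/last.ckpt"
resume_path = os.path.join(root_dir, resume)
print(resume_path)  # /home/user/Graph-DiT/outputs/2024-01-01/12-00-00/checkpoints/last.ckpt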
@@ -80,21 +79,21 @@ def main(cfg: DictConfig):
     datamodule = dataset.DataModule(cfg)
     datamodule.prepare_data()
     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg)
-    train_smiles, reference_smiles = datamodule.get_train_smiles()
+    # train_smiles, reference_smiles = datamodule.get_train_smiles()

     # get input output dimensions
     dataset_infos.compute_input_output_dims(datamodule=datamodule)
-    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
+    # train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)

-    sampling_metrics = SamplingMolecularMetrics(
-        dataset_infos, train_smiles, reference_smiles
-    )
+    # sampling_metrics = SamplingMolecularMetrics(
+    #     dataset_infos, train_smiles, reference_smiles
+    # )
     visualization_tools = MolecularVisualization(dataset_infos)

     model_kwargs = {
         "dataset_infos": dataset_infos,
-        "train_metrics": train_metrics,
-        "sampling_metrics": sampling_metrics,
+        # "train_metrics": train_metrics,
+        # "sampling_metrics": sampling_metrics,
         "visualization_tools": visualization_tools,
     }
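The hunk above disables the metrics by commenting out the constructor calls and the matching model_kwargs entries. A sketch of an alternative that keeps the same names but gates them behind a config flag; the enable_sampling_metrics flag and the build_model_kwargs helper are hypothetical, not part of this commit:

def build_model_kwargs(cfg, datamodule, dataset_infos, visualization_tools):
    """Assemble Graph_DiT kwargs; metrics are added only when the
    hypothetical cfg.general.enable_sampling_metrics flag is set."""
    model_kwargs = {
        "dataset_infos": dataset_infos,
        "visualization_tools": visualization_tools,
    }
    if cfg.general.get("enable_sampling_metrics", False):  # assumed flag
        train_smiles, reference_smiles = datamodule.get_train_smiles()
        model_kwargs["train_metrics"] = TrainMolecularMetricsDiscrete(dataset_infos)
        model_kwargs["sampling_metrics"] = SamplingMolecularMetrics(
            dataset_infos, train_smiles, reference_smiles
        )
    return model_kwargs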
@@ -110,9 +109,10 @@ def main(cfg: DictConfig):
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
-        accelerator="gpu"
-        if torch.cuda.is_available() and cfg.general.gpus > 0
-        else "cpu",
+        # accelerator="gpu"
+        # if torch.cuda.is_available() and cfg.general.gpus > 0
+        # else "cpu",
+        accelerator="cpu",
         devices=cfg.general.gpus
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
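The hunk above forces accelerator="cpu" by commenting out the GPU branch. For reference, the same switch can be computed once before building the Trainer (which, judging by the gradient_clip_val / accelerator / devices arguments, is assumed to be PyTorch Lightning's); pick_device_settings is a hypothetical helper, not part of this commit:

import torch

def pick_device_settings(cfg):
    # Hypothetical helper: choose accelerator/devices from the config.
    use_gpu = torch.cuda.is_available() and cfg.general.gpus > 0
    if use_gpu:
        return "gpu", cfg.general.gpus
    return "cpu", 1  # single CPU process; the original code passed devices=None here

accelerator, devices = pick_device_settings(cfg)
trainer = Trainer(
    gradient_clip_val=cfg.train.clip_grad,
    accelerator=accelerator,
    devices=devices,
)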