# These imports are tricky because they use C++, do not move them
from tqdm import tqdm
import os, shutil
import warnings

import torch
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer

import utils
from datasets import dataset
from diffusion_model import Graph_DiT
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from metrics.molecular_metrics_sampling import SamplingGraphMetrics
from analysis.visualization import MolecularVisualization
from analysis.visualization import GraphVisualization

warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")


def remove_folder(folder):
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print("Failed to delete %s. Reason: %s" % (file_path, e))


def get_resume(cfg, model_kwargs):
    """Resume a run. Loads the previous config without allowing keys to be updated (used for testing)."""
    saved_cfg = cfg.copy()
    name = cfg.general.name + "_resume"
    resume = cfg.general.test_only
    batch_size = cfg.train.batch_size

    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)
    cfg = model.cfg
    cfg.general.test_only = resume
    cfg.general.name = name
    cfg.train.batch_size = batch_size
    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg, model


def get_resume_adaptive(cfg, model_kwargs):
    """Resume a run. Loads the previous config but allows some keys to be overridden (used for resuming training)."""
    saved_cfg = cfg.copy()
    # Fetch the path of this file to get the base path
    current_path = os.path.dirname(os.path.realpath(__file__))
    root_dir = current_path.split("outputs")[0]
    resume_path = os.path.join(root_dir, cfg.general.resume)

    if cfg.model.type == "discrete":
        model = Graph_DiT.load_from_checkpoint(resume_path, **model_kwargs)
    else:
        raise NotImplementedError("Unknown model")

    new_cfg = model.cfg
    for category in cfg:
        for arg in cfg[category]:
            new_cfg[category][arg] = cfg[category][arg]

    new_cfg.general.resume = resume_path
    new_cfg.general.name = new_cfg.general.name + "_resume"

    new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
    return new_cfg, model


@hydra.main(version_base="1.1", config_path="../configs", config_name="config")
def main(cfg: DictConfig):
    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_smiles, reference_smiles = datamodule.get_train_smiles()
    # train_graphs, reference_graphs = datamodule.get_train_graphs()

    # get input/output dimensions
    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
    # train_metrics = TrainGraphMetricsDiscrete(dataset_infos)

    sampling_metrics = SamplingMolecularMetrics(
        dataset_infos, train_smiles, reference_smiles
    )
    # sampling_metrics = SamplingGraphMetrics(
    #     dataset_infos, train_graphs, reference_graphs
    # )

    visualization_tools = MolecularVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        # "train_metrics": train_metrics,
        # "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    if cfg.general.test_only:
        # When testing, the previous configuration is fully loaded
        cfg, _ = get_resume(cfg, model_kwargs)
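        # switch to the run directory that contains the checkpoints/ folder,
        # presumably so relative output paths resolve next to the loaded checkpoint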
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        # When resuming, some parts of the previous configuration can be overridden
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    model = Graph_DiT(cfg=cfg, **model_kwargs)
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        # accelerator="gpu"
        # if torch.cuda.is_available() and cfg.general.gpus > 0
        # else "cpu",
        accelerator="cpu",
        devices=cfg.general.gpus
        if torch.cuda.is_available() and cfg.general.gpus > 0
        else None,
        max_epochs=cfg.train.n_epochs,
        enable_checkpointing=False,
        check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
        val_check_interval=cfg.train.val_check_interval,
        strategy="ddp" if cfg.general.gpus > 1 else "auto",
        enable_progress_bar=cfg.general.enable_progress_bar,
        callbacks=[],
        reload_dataloaders_every_n_epochs=0,
        logger=[],
    )

    if not cfg.general.test_only:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
        if cfg.general.save_model:
            trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
        trainer.test(model, datamodule=datamodule)
    else:
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)


from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration


@hydra.main(version_base="1.1", config_path="../configs", config_name="config")
def test(cfg: DictConfig):
    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
        automatic_checkpoint_naming=True,
        total_limit=cfg.general.number_checkpoint_limit,
    )
    accelerator = Accelerator(
        mixed_precision=cfg.mixed_precision,
        project_config=accelerator_config,
        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.n_epochs,
    )
    set_seed(cfg.train.seed, device_specific=True)

    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_graphs, reference_graphs = datamodule.get_train_graphs()

    # get input/output dimensions
    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
    sampling_metrics = SamplingGraphMetrics(
        dataset_infos, train_graphs, reference_graphs
    )
    visualization_tools = GraphVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    if cfg.general.test_only:
        cfg, _ = get_resume(cfg, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
    model = Graph_DiT(cfg=cfg, **model_kwargs)
    graph_dit_model = model

    inference_dtype = torch.float32
    graph_dit_model.to(accelerator.device, dtype=inference_dtype)

    # optional: freeze the model
    # graph_dit_model.model.requires_grad_(True)

    import torch.nn.functional as F

    optimizer = graph_dit_model.configure_optimizers()

    # start training
    for epoch in range(cfg.train.n_epochs):
        graph_dit_model.train()  # set the model to training mode
        for batch_data in datamodule.train_dataloader():  # fetch one batch from the data loader
            # one-hot encode node and edge features
            data_x = F.one_hot(batch_data.x, num_classes=12).float()[:, graph_dit_model.active_index]  # node features
            data_edge_attr = F.one_hot(batch_data.edge_attr, num_classes=2).float()  # edge features

            # convert to dense format before passing to Graph_DiT
            dense_data, node_mask = utils.to_dense(
                data_x, batch_data.edge_index, data_edge_attr, batch_data.batch, graph_dit_model.max_n_nodes
            )
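            # mask(node_mask) is expected to zero out entries for padded nodes/edges
            # so that only real graph elements contribute to the loss below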
            dense_data = dense_data.mask(node_mask)
            X, E = dense_data.X, dense_data.E  # node and edge features
            y = batch_data.y  # labels

            # forward pass and loss computation
            pred = graph_dit_model(dense_data)  # feed into the Graph_DiT model
            loss = graph_dit_model.train_loss(pred, X, E, y, node_mask)

            # optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # start sampling
    samples = []
    for i in tqdm(
        range(cfg.general.n_samples),
        desc="Sampling",
        disable=not cfg.general.enable_progress_bar,
    ):
        batch_size = cfg.train.batch_size
        num_steps = cfg.model.diffusion_steps
        y = torch.ones(batch_size, num_steps, 1, 1, device=accelerator.device, dtype=inference_dtype)

        # sample from the model
        samples_batch = graph_dit_model.sample_batch(
            batch_id=i,
            batch_size=batch_size,
            y=y,
            keep_chain=1,
            number_chain_steps=num_steps,
            save_final=batch_size,
        )
        samples.append(samples_batch)

    # trainer = Trainer(
    #     gradient_clip_val=cfg.train.clip_grad,
    #     # accelerator="cpu",
    #     accelerator="gpu"
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else "cpu",
    #     devices=[cfg.general.gpu_number]
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else None,
    #     max_epochs=cfg.train.n_epochs,
    #     enable_checkpointing=False,
    #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
    #     val_check_interval=cfg.train.val_check_interval,
    #     strategy="ddp" if cfg.general.gpus > 1 else "auto",
    #     enable_progress_bar=cfg.general.enable_progress_bar,
    #     callbacks=[],
    #     reload_dataloaders_every_n_epochs=0,
    #     logger=[],
    # )

    # if not cfg.general.test_only:
    #     print("start testing fit method")
    #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
    #     if cfg.general.save_model:
    #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
    #     trainer.test(model, datamodule=datamodule)


if __name__ == "__main__":
    test()
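
# Usage sketch (the file name "main.py" and these overrides are illustrative; adjust to your setup):
#   python main.py general.name=graph_dit_run train.n_epochs=100
# Hydra loads ../configs/config.yaml relative to this file and applies any
# command-line overrides such as the ones above.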