update the gpu id
This commit is contained in:
@@ -175,6 +175,7 @@ def test(cfg: DictConfig):
|
||||
elif cfg.general.resume is not None:
|
||||
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
|
||||
os.chdir(cfg.general.resume.split("checkpoints")[0])
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
|
||||
model = Graph_DiT(cfg=cfg, **model_kwargs)
|
||||
trainer = Trainer(
|
||||
gradient_clip_val=cfg.train.clip_grad,
|
||||
@@ -182,7 +183,7 @@ def test(cfg: DictConfig):
|
||||
accelerator="gpu"
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else "cpu",
|
||||
devices=cfg.general.gpus
|
||||
devices=[cfg.general.gpu_number]
|
||||
if torch.cuda.is_available() and cfg.general.gpus > 0
|
||||
else None,
|
||||
max_epochs=cfg.train.n_epochs,
|
||||
|
Reference in New Issue
Block a user