From d6df5d248fece287a001dc6113e5af015169df76 Mon Sep 17 00:00:00 2001
From: JunkyByte
Date: Tue, 25 Jul 2023 16:28:48 +0200
Subject: [PATCH] Allows MPS inference. Fix visualization args

---
 cotracker/predictor.py |  2 +-
 demo.py                | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/cotracker/predictor.py b/cotracker/predictor.py
index 541a129..8dcf3ca 100644
--- a/cotracker/predictor.py
+++ b/cotracker/predictor.py
@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
 
 class CoTrackerPredictor(torch.nn.Module):
     def __init__(
-        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
+        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
     ):
         super().__init__()
         self.interp_shape = (384, 512)
diff --git a/demo.py b/demo.py
index fd24fed..5b72948 100644
--- a/demo.py
+++ b/demo.py
@@ -14,6 +14,9 @@ from PIL import Image
 from cotracker.utils.visualizer import Visualizer, read_video_from_path
 from cotracker.predictor import CoTrackerPredictor
 
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+                  'mps' if torch.backends.mps.is_available() else
+                  'cpu')
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -55,11 +58,8 @@ if __name__ == "__main__":
         segm_mask = torch.from_numpy(segm_mask)[None, None]
 
     model = CoTrackerPredictor(checkpoint=args.checkpoint)
-    if torch.cuda.is_available():
-        model = model.cuda()
-        video = video.cuda()
-    else:
-        print("CUDA is not available!")
+    model = model.to(DEFAULT_DEVICE)
+    video = video.to(DEFAULT_DEVICE)
 
     pred_tracks, pred_visibility = model(
         video,
@@ -73,4 +73,4 @@ if __name__ == "__main__":
     # save a video with predicted tracks
     seq_name = args.video_path.split("/")[-1]
     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
-    vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
+    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
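
Not part of the patch itself: a minimal sketch of how the DEFAULT_DEVICE fallback added to demo.py is expected to resolve at runtime. torch.cuda.is_available() and torch.backends.mps.is_available() are the standard PyTorch availability checks; the tiny model and tensor below are illustrative stand-ins, not CoTracker code.

    import torch

    # Same fallback order as the patched demo.py: CUDA first, then Apple MPS, then CPU.
    DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
                      'mps' if torch.backends.mps.is_available() else
                      'cpu')

    # .to(DEFAULT_DEVICE) works for both nn.Module and Tensor, so the model and the
    # input stay on the same device regardless of backend (this is what replaces the
    # CUDA-only .cuda() calls removed in the diff).
    model = torch.nn.Linear(8, 2).to(DEFAULT_DEVICE)   # stand-in for CoTrackerPredictor
    video = torch.rand(1, 3, 8, 8).to(DEFAULT_DEVICE)  # stand-in for the input video
    print(DEFAULT_DEVICE, next(model.parameters()).device, video.device)

On a machine with neither CUDA nor MPS this prints 'cpu' for all three, so the demo still runs, just without acceleration.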