From 5890fbd16df766cd763155c479f051ace6582edf Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Tue, 25 Jul 2023 16:17:29 +0200 Subject: [PATCH 1/4] mps / cpu support --- cotracker/models/core/cotracker/cotracker.py | 8 ++++---- cotracker/predictor.py | 15 ++++++++------- cotracker/utils/visualizer.py | 5 +++++ demo.py | 7 ++++++- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/cotracker/models/core/cotracker/cotracker.py b/cotracker/models/core/cotracker/cotracker.py index c9eca1f..e9e6e89 100644 --- a/cotracker/models/core/cotracker/cotracker.py +++ b/cotracker/models/core/cotracker/cotracker.py @@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import ( torch.manual_seed(0) -def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): +def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'): if grid_size == 1: return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ None, None - ].cuda() + ].to(device) grid_y, grid_x = meshgrid2d( - 1, grid_size, grid_size, stack=False, norm=False, device="cuda" + 1, grid_size, grid_size, stack=False, norm=False, device=device ) step = interp_shape[1] // 64 if grid_center[0] != 0 or grid_center[1] != 0: @@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): grid_y = grid_y + grid_center[0] grid_x = grid_x + grid_center[1] - xy = torch.stack([grid_x, grid_y], dim=-1).cuda() + xy = torch.stack([grid_x, grid_y], dim=-1).to(device) return xy diff --git a/cotracker/predictor.py b/cotracker/predictor.py index a4e2003..9a4bf20 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import ( class CoTrackerPredictor(torch.nn.Module): def __init__( - self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" + self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None ): super().__init__() self.interp_shape = (384, 512) @@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module): model = build_cotracker(checkpoint) self.model = model - self.model.to("cuda") + self.device = device or 'cuda' + self.model.to(self.device) self.model.eval() @torch.no_grad() @@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module): grid_width = W // grid_step grid_height = H // grid_step tracks = visibilities = None - grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") + grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device) grid_pts[0, :, 0] = grid_query_frame for offset in tqdm(range(grid_step * grid_step)): ox = offset % grid_step @@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module): assert B == 1 video = video.reshape(B * T, C, H, W) - video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() + video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device) video = video.reshape( B, T, 3, self.interp_shape[0], self.interp_shape[1] - ).cuda() + ).to(self.device) if queries is not None: queries = queries.clone() @@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module): queries[:, :, 1] *= self.interp_shape[1] / W queries[:, :, 2] *= self.interp_shape[0] / H elif grid_size > 0: - grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) + grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device) if segm_mask is not None: segm_mask = F.interpolate( segm_mask, tuple(self.interp_shape), mode="nearest" @@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module): ) if add_support_grid: - grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device) grid_pts = torch.cat( [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 ) diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py index 404436b..057738e 100644 --- a/cotracker/utils/visualizer.py +++ b/cotracker/utils/visualizer.py @@ -63,6 +63,7 @@ class Visualizer: self, video: torch.Tensor, # (B,T,C,H,W) tracks: torch.Tensor, # (B,T,N,2) + visibility: torch.Tensor, # (B, T, N, 1) bool gt_tracks: torch.Tensor = None, # (B,T,N,2) segm_mask: torch.Tensor = None, # (B,1,H,W) filename: str = "video", @@ -94,6 +95,7 @@ class Visualizer: res_video = self.draw_tracks_on_video( video=video, tracks=tracks, + visibility=visibility, segm_mask=segm_mask, gt_tracks=gt_tracks, query_frame=query_frame, @@ -127,6 +129,7 @@ class Visualizer: self, video: torch.Tensor, tracks: torch.Tensor, + visibility: torch.Tensor, segm_mask: torch.Tensor = None, gt_tracks=None, query_frame: int = 0, @@ -228,11 +231,13 @@ class Visualizer: if not compensate_for_camera_motion or ( compensate_for_camera_motion and segm_mask[i] > 0 ): + cv2.circle( res_video[t], coord, int(self.linewidth * 2), vector_colors[t, i].tolist(), + thickness=-1 if visibility[0, t, i] else 2 -1, ) diff --git a/demo.py b/demo.py index 7b2e1c7..a166a35 100644 --- a/demo.py +++ b/demo.py @@ -32,6 +32,11 @@ if __name__ == "__main__": default="./checkpoints/cotracker_stride_4_wind_8.pth", help="cotracker model", ) + parser.add_argument( + "--device", + default="cuda", + help="Device to use for inference", + ) parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") parser.add_argument( "--grid_query_frame", @@ -54,7 +59,7 @@ if __name__ == "__main__": segm_mask = np.array(Image.open(os.path.join(args.mask_path))) segm_mask = torch.from_numpy(segm_mask)[None, None] - model = CoTrackerPredictor(checkpoint=args.checkpoint) + model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device) pred_tracks, pred_visibility = model( video, From d6df5d248fece287a001dc6113e5af015169df76 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Tue, 25 Jul 2023 16:28:48 +0200 Subject: [PATCH 2/4] Allows MPS inference. Fix visualization args --- cotracker/predictor.py | 2 +- demo.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index 541a129..8dcf3ca 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import ( class CoTrackerPredictor(torch.nn.Module): def __init__( - self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None + self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" ): super().__init__() self.interp_shape = (384, 512) diff --git a/demo.py b/demo.py index fd24fed..5b72948 100644 --- a/demo.py +++ b/demo.py @@ -14,6 +14,9 @@ from PIL import Image from cotracker.utils.visualizer import Visualizer, read_video_from_path from cotracker.predictor import CoTrackerPredictor +DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else + 'mps' if torch.backends.mps.is_available() else + 'cpu') if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -55,11 +58,8 @@ if __name__ == "__main__": segm_mask = torch.from_numpy(segm_mask)[None, None] model = CoTrackerPredictor(checkpoint=args.checkpoint) - if torch.cuda.is_available(): - model = model.cuda() - video = video.cuda() - else: - print("CUDA is not available!") + model = model.to(DEFAULT_DEVICE) + video = video.to(DEFAULT_DEVICE) pred_tracks, pred_visibility = model( video, @@ -73,4 +73,4 @@ if __name__ == "__main__": # save a video with predicted tracks seq_name = args.video_path.split("/")[-1] vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) - vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) + vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) From 51175e006acd8a0a5c507cc2a4bd205228c8fbc1 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Tue, 25 Jul 2023 16:30:38 +0200 Subject: [PATCH 3/4] allow no visibility --- cotracker/utils/visualizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py index 696bc85..704b591 100644 --- a/cotracker/utils/visualizer.py +++ b/cotracker/utils/visualizer.py @@ -62,7 +62,7 @@ class Visualizer: self, video: torch.Tensor, # (B,T,C,H,W) tracks: torch.Tensor, # (B,T,N,2) - visibility: torch.Tensor, # (B, T, N, 1) bool + visibility: torch.Tensor = None, # (B, T, N, 1) bool gt_tracks: torch.Tensor = None, # (B,T,N,2) segm_mask: torch.Tensor = None, # (B,1,H,W) filename: str = "video", @@ -128,7 +128,7 @@ class Visualizer: self, video: torch.Tensor, tracks: torch.Tensor, - visibility: torch.Tensor, + visibility: torch.Tensor = None, segm_mask: torch.Tensor = None, gt_tracks=None, query_frame: int = 0, From 4a9286e17f59bae9f9b15219ce8aabce3f10dd5f Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Tue, 25 Jul 2023 16:31:44 +0200 Subject: [PATCH 4/4] fix --- cotracker/predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index 8dcf3ca..cb0945b 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module): ) if add_support_grid: - grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device) + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) grid_pts = torch.cat( [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 )