diff --git a/cotracker/models/core/cotracker/cotracker.py b/cotracker/models/core/cotracker/cotracker.py
index c9eca1f..e9e6e89 100644
--- a/cotracker/models/core/cotracker/cotracker.py
+++ b/cotracker/models/core/cotracker/cotracker.py
@@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
     if grid_size == 1:
         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
             None, None
-        ].cuda()
+        ].to(device)
 
     grid_y, grid_x = meshgrid2d(
-        1, grid_size, grid_size, stack=False, norm=False, device="cuda"
+        1, grid_size, grid_size, stack=False, norm=False, device=device
     )
     step = interp_shape[1] // 64
     if grid_center[0] != 0 or grid_center[1] != 0:
@@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
     grid_y = grid_y + grid_center[0]
     grid_x = grid_x + grid_center[1]
 
-    xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
+    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
     return xy
 
diff --git a/cotracker/predictor.py b/cotracker/predictor.py
index a4e2003..9a4bf20 100644
--- a/cotracker/predictor.py
+++ b/cotracker/predictor.py
@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
 
 class CoTrackerPredictor(torch.nn.Module):
     def __init__(
-        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
+        self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
     ):
         super().__init__()
         self.interp_shape = (384, 512)
@@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.model.to("cuda")
+        self.device = device or 'cuda'
+        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
@@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
         video = video.reshape(
             B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).cuda()
+        ).to(self.device)
 
         if queries is not None:
             queries = queries.clone()
@@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module):
             )
 
         if add_support_grid:
-            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device)
             grid_pts = torch.cat(
                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
             )
diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py
index 404436b..057738e 100644
--- a/cotracker/utils/visualizer.py
+++ b/cotracker/utils/visualizer.py
@@ -63,6 +63,7 @@ class Visualizer:
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
+        visibility: torch.Tensor,  # (B, T, N, 1) bool
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
@@ -94,6 +95,7 @@ class Visualizer:
         res_video = self.draw_tracks_on_video(
             video=video,
             tracks=tracks,
+            visibility=visibility,
             segm_mask=segm_mask,
             gt_tracks=gt_tracks,
             query_frame=query_frame,
@@ -127,6 +129,7 @@ class Visualizer:
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
+        visibility: torch.Tensor,
         segm_mask: torch.Tensor = None,
         gt_tracks=None,
         query_frame: int = 0,
@@ -228,11 +231,13 @@ class Visualizer:
                 if not compensate_for_camera_motion or (
                     compensate_for_camera_motion and segm_mask[i] > 0
                 ):
+
                     cv2.circle(
                         res_video[t],
                         coord,
                         int(self.linewidth * 2),
                         vector_colors[t, i].tolist(),
-                        -1,
+                        thickness=-1 if visibility[0, t, i] else 2 - 1,
                     )
+
 
diff --git a/demo.py b/demo.py
index 7b2e1c7..a166a35 100644
--- a/demo.py
+++ b/demo.py
@@ -32,6 +32,11 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="Device to use for inference",
+    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -54,7 +59,7 @@ if __name__ == "__main__":
         segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
         segm_mask = torch.from_numpy(segm_mask)[None, None]
 
-    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
 
     pred_tracks, pred_visibility = model(
         video,