diff --git a/cotracker/predictor.py b/cotracker/predictor.py index f36b500..cb0945b 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module): ) if add_support_grid: - grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) grid_pts = torch.cat( [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 ) diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py index 2187c18..704b591 100644 --- a/cotracker/utils/visualizer.py +++ b/cotracker/utils/visualizer.py @@ -62,6 +62,7 @@ class Visualizer: self, video: torch.Tensor, # (B,T,C,H,W) tracks: torch.Tensor, # (B,T,N,2) + visibility: torch.Tensor = None, # (B, T, N, 1) bool gt_tracks: torch.Tensor = None, # (B,T,N,2) segm_mask: torch.Tensor = None, # (B,1,H,W) filename: str = "video", @@ -93,6 +94,7 @@ class Visualizer: res_video = self.draw_tracks_on_video( video=video, tracks=tracks, + visibility=visibility, segm_mask=segm_mask, gt_tracks=gt_tracks, query_frame=query_frame, @@ -126,6 +128,7 @@ class Visualizer: self, video: torch.Tensor, tracks: torch.Tensor, + visibility: torch.Tensor = None, segm_mask: torch.Tensor = None, gt_tracks=None, query_frame: int = 0, @@ -227,11 +230,13 @@ class Visualizer: if not compensate_for_camera_motion or ( compensate_for_camera_motion and segm_mask[i] > 0 ): + cv2.circle( res_video[t], coord, int(self.linewidth * 2), vector_colors[t, i].tolist(), + thickness=-1 if visibility[0, t, i] else 2 -1, ) diff --git a/demo.py b/demo.py index fd24fed..5b72948 100644 --- a/demo.py +++ b/demo.py @@ -14,6 +14,9 @@ from PIL import Image from cotracker.utils.visualizer import Visualizer, read_video_from_path from cotracker.predictor import CoTrackerPredictor +DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else + 'mps' if torch.backends.mps.is_available() else + 'cpu') if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -55,11 +58,8 @@ if __name__ == "__main__": segm_mask = torch.from_numpy(segm_mask)[None, None] model = CoTrackerPredictor(checkpoint=args.checkpoint) - if torch.cuda.is_available(): - model = model.cuda() - video = video.cuda() - else: - print("CUDA is not available!") + model = model.to(DEFAULT_DEVICE) + video = video.to(DEFAULT_DEVICE) pred_tracks, pred_visibility = model( video, @@ -73,4 +73,4 @@ if __name__ == "__main__": # save a video with predicted tracks seq_name = args.video_path.split("/")[-1] vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) - vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) + vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)