diff --git a/cotracker/predictor.py b/cotracker/predictor.py index 8dcf3ca..cb0945b 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module): ) if add_support_grid: - grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device) + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) grid_pts = torch.cat( [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 )