mps / cpu support
This commit is contained in:
@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.interp_shape = (384, 512)
|
||||
@@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.model = model
|
||||
self.model.to("cuda")
|
||||
self.device = device or 'cuda'
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
ox = offset % grid_step
|
||||
@@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
|
||||
video = video.reshape(
|
||||
B, T, 3, self.interp_shape[0], self.interp_shape[1]
|
||||
).cuda()
|
||||
).to(self.device)
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
@@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(
|
||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||
@@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
)
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
|
||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device)
|
||||
grid_pts = torch.cat(
|
||||
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user