Add MPS / CPU device support to CoTrackerPredictor (replace hard-coded "cuda" with a configurable device)

This commit is contained in:
JunkyByte
2023-07-25 16:17:29 +02:00
parent c6878420f5
commit 5890fbd16d
4 changed files with 23 additions and 12 deletions

View File

@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
):
super().__init__()
self.interp_shape = (384, 512)
@@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module):
model = build_cotracker(checkpoint)
self.model = model
self.model.to("cuda")
self.device = device or 'cuda'
self.model.to(self.device)
self.model.eval()
@torch.no_grad()
@@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module):
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
@@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).cuda()
).to(self.device)
if queries is not None:
queries = queries.clone()
@@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module):
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module):
)
if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)