Merge branch 'main' of github.com:JunkyByte/co-tracker

This commit is contained in:
JunkyByte
2023-07-25 16:23:37 +02:00
12 changed files with 236 additions and 138 deletions

View File

@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
model = build_cotracker(checkpoint)
self.model = model
self.device = device or 'cuda'
self.model.to(self.device)
self.model.eval()
@torch.no_grad()
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).to(self.device)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
queries = queries.clone()
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"