Merge branch 'main' of github.com:JunkyByte/co-tracker
This commit is contained in:
@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.model = model
|
||||
self.device = device or 'cuda'
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
ox = offset % grid_step
|
||||
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
|
||||
video = video.reshape(
|
||||
B, T, 3, self.interp_shape[0], self.interp_shape[1]
|
||||
).to(self.device)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(
|
||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||
|
||||
Reference in New Issue
Block a user