From 36d15667501a5d1be816d029170ef6cdd00c8c88 Mon Sep 17 00:00:00 2001 From: Hanzhang ma Date: Wed, 10 Jul 2024 00:05:34 +0200 Subject: [PATCH] add some comments --- cotracker/predictor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index baded92..067b50d 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -17,6 +17,7 @@ class CoTrackerPredictor(torch.nn.Module): self.support_grid_size = 6 model = build_cotracker(checkpoint) self.interp_shape = model.model_resolution + print(self.interp_shape) self.model = model self.model.eval() @@ -103,12 +104,16 @@ class CoTrackerPredictor(torch.nn.Module): B, T, C, H, W = video.shape video = video.reshape(B * T, C, H, W) + # ? what is interpolate? + # 将video插值成interp_shape? video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) if queries is not None: - B, N, D = queries.shape + B, N, D = queries.shape # batch_size, number of points, (t,x,y) assert D == 3 + # query 缩放到( interp_shape - 1 ) / (W - 1) + # 插完值之后缩放 queries = queries.clone() queries[:, :, 1:] *= queries.new_tensor( [ @@ -116,6 +121,7 @@ class CoTrackerPredictor(torch.nn.Module): (self.interp_shape[0] - 1) / (H - 1), ] ) + # 生成grid elif grid_size > 0: grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) if segm_mask is not None: @@ -131,6 +137,8 @@ class CoTrackerPredictor(torch.nn.Module): dim=2, ).repeat(B, 1, 1) + # 添加支持点 + if add_support_grid: grid_pts = get_points_on_a_grid( self.support_grid_size, self.interp_shape, device=video.device