From 4f297a92fe1a684b1b0980da138b706d62e45472 Mon Sep 17 00:00:00 2001 From: Ernie Chu <51432514+ernestchu@users.noreply.github.com> Date: Thu, 14 Sep 2023 18:20:02 +0800 Subject: [PATCH] correct query-point predictions (#32) --- cotracker/predictor.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/cotracker/predictor.py b/cotracker/predictor.py index cb0945b..6c62889 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module): visibilities = visibilities[:, :, : -self.support_grid_size ** 2] thr = 0.9 visibilities = visibilities > thr + + # correct query-point predictions + # see https://github.com/facebookresearch/co-tracker/issues/28 + + # TODO: batchify + for i in range(len(queries)): + queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) + arange = torch.arange(0, len(queries_t)) + + # overwrite the predictions with the query points + tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] + + # correct visibilities, the query points should be visible + visibilities[i, queries_t, arange] = True + tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) return tracks, visibilities