diff --git a/cotracker/predictor.py b/cotracker/predictor.py index cb0945b..6c62889 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module): visibilities = visibilities[:, :, : -self.support_grid_size ** 2] thr = 0.9 visibilities = visibilities > thr + + # correct query-point predictions + # see https://github.com/facebookresearch/co-tracker/issues/28 + + # TODO: batchify + for i in range(len(queries)): + queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) + arange = torch.arange(0, len(queries_t)) + + # overwrite the predictions with the query points + tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] + + # correct visibilities, the query points should be visible + visibilities[i, queries_t, arange] = True + tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) return tracks, visibilities