Compare commits

...

10 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Nikita Karaev | 19767a9d65 | Update README.md | 2024-06-14 15:06:32 +01:00 |
| Iurii Makarov | e29e938311 | readme.md update, demo flexible save path (#83) | 2024-05-11 15:34:09 +01:00 |
| Nikita Karaev | 0f9d32869a | Update README.md | 2024-01-22 11:59:03 +00:00 |
| Nikita Karaev | 9460eefecc | Update README.md | 2024-01-09 16:00:07 +00:00 |
| Patrick Pfreundschuh | 9921cf0895 | fix ignored input video argument (#57) | 2024-01-07 15:14:28 +00:00 |
| Nikita Karaev | 941c24fd40 | add meta copyright | 2024-01-05 16:17:50 +00:00 |
| Nikita Karaev | fac27989b3 | fixed a small online processing bug | 2024-01-05 14:55:54 +00:00 |
| Nikita Karaev | f084a93f28 | fix multi-batch inference | 2024-01-04 16:53:22 +00:00 |
| Nikita Karaev | 3716e36249 | fix online demo | 2023-12-29 16:12:42 +00:00 |
| Nikita Karaev | 721fcc237b | remove assert B==1 | 2023-12-28 17:27:30 +00:00 |
6 changed files with 58 additions and 17 deletions

View File

@@ -4,7 +4,7 @@
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
[[`Paper`](https://arxiv.org/abs/2307.07635)] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
### [Project Page](https://co-tracker.github.io/) | [Paper](https://arxiv.org/abs/2307.07635) | [X Thread](https://twitter.com/n_karaev/status/1742638906355470772) | [BibTeX](#citing-cotracker)
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -26,6 +26,7 @@ CoTracker can track:
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
**Updates:**
- [June 14, 2024] 📣 We have released the code for [VGGSfM](https://github.com/facebookresearch/vggsfm), a model for recovering camera poses and 3D structure from any image sequences based on point tracking! VGGSfM is the first fully differentiable SfM framework that unlocks scalability and outperforms conventional SfM methods on standard benchmarks.
- [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
@@ -119,7 +120,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support,
git clone https://github.com/facebookresearch/co-tracker
cd co-tracker
pip install -e .
pip install matplotlib flow_vis tqdm tensorboard
pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg]
```
You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
@@ -132,6 +133,11 @@ cd ..
```
For old checkpoints, see [this section](#previous-version).
After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`):
```bash
python demo.py --checkpoint checkpoints/cotracker2.pth
```
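Beyond the command-line demo, a minimal sketch of programmatic offline inference is shown below. The video path, device handling, and `grid_size` value are illustrative assumptions, and the import path and checkpoint location are assumed from the repository layout; the `CoTrackerPredictor` call signature follows the diffs further down this page.
```python
# Hedged sketch: offline tracking of a regular grid of points.
# Paths and grid_size are placeholders; adapt them to your setup.
import imageio.v3 as iio
import numpy as np
import torch

from cotracker.predictor import CoTrackerPredictor  # import path assumed from the repo layout

device = "cuda" if torch.cuda.is_available() else "cpu"

frames = iio.imread("./assets/apple.mp4", plugin="FFMPEG")           # (T, H, W, 3) uint8 frames
video = torch.tensor(np.asarray(frames)).permute(0, 3, 1, 2)[None]   # (B, T, 3, H, W)
video = video.float().to(device)

model = CoTrackerPredictor(checkpoint="checkpoints/cotracker2.pth").to(device)
pred_tracks, pred_visibility = model(video, grid_size=10)
# pred_tracks: (B, T, N, 2) pixel coordinates, pred_visibility: (B, T, N)
```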
## Evaluation
To reproduce the results presented in the paper, download the following datasets:
@@ -203,6 +209,15 @@ make -C docs html
## Previous version
You can use CoTracker v1 directly via pytorch hub:
```python
import torch
import einops
import timm
import tqdm
cotracker = torch.hub.load("facebookresearch/co-tracker:v1.0", "cotracker_w8")
```
The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
You can also download the corresponding checkpoints:
```bash

View File

@@ -38,7 +38,6 @@ class EvaluationPredictor(torch.nn.Module):
B, N, D = queries.shape
assert D == 3
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
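Because `F.interpolate` expects a 4D `(N, C, H, W)` tensor, the predictor folds the batch and time dimensions together before resizing (and reshapes back afterwards in the full code). A standalone sketch of that pattern, with made-up sizes and an example target resolution:
```python
import torch
import torch.nn.functional as F

B, T, C, H, W = 2, 8, 3, 480, 640   # illustrative batch, frames, channels, resolution
interp_shape = (384, 512)           # illustrative model input resolution

video = torch.rand(B, T, C, H, W)
flat = video.reshape(B * T, C, H, W)                                   # fold B and T together
flat = F.interpolate(flat, interp_shape, mode="bilinear", align_corners=True)
video = flat.reshape(B, T, C, *interp_shape)                           # unfold back to (B, T, C, H', W')
```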

View File

@@ -23,11 +23,11 @@ class CoTrackerPredictor(torch.nn.Module):
@torch.no_grad()
def forward(
self,
video, # (1, T, 3, H, W)
video, # (B, T, 3, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
@@ -92,7 +92,6 @@ class CoTrackerPredictor(torch.nn.Module):
backward_tracking=False,
):
B, T, C, H, W = video.shape
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
@@ -121,13 +120,14 @@ class CoTrackerPredictor(torch.nn.Module):
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
).repeat(B, 1, 1)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
@@ -174,7 +174,7 @@ class CoTrackerPredictor(torch.nn.Module):
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
@@ -201,6 +201,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
grid_query_frame: int = 0,
add_support_grid=False,
):
B, T, C, H, W = video_chunk.shape
# Initialize online video processing and save queried points
# This needs to be done before processing *each new video*
if is_first_step:
@@ -231,7 +232,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
queries = torch.cat([queries, grid_pts], dim=1)
self.queries = queries
return (None, None)
B, T, C, H, W = video_chunk.shape
video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
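With the `assert B == 1` removed and the query and support grids repeated along the batch dimension, the predictor accepts batched inputs. A hedged sketch of a batched call with explicit queries in the `(t, x, y)` format described above; all tensor values and the checkpoint path are illustrative:
```python
import torch

from cotracker.predictor import CoTrackerPredictor  # import path assumed from the repo layout

model = CoTrackerPredictor(checkpoint="checkpoints/cotracker2.pth")

video = torch.rand(2, 24, 3, 384, 512)   # (B, T, 3, H, W), made-up content
queries = torch.tensor(
    [
        [[0.0, 100.0, 150.0], [5.0, 200.0, 250.0]],   # batch 0: (t, x, y) per query point
        [[0.0, 120.0, 160.0], [3.0, 220.0, 260.0]],   # batch 1
    ]
)                                         # (B, N, 3)

tracks, visibility = model(video, queries=queries, backward_tracking=True)
# tracks: (B, T, N, 2) pixel coordinates, visibility: (B, T, N)
```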

View File

@@ -83,11 +83,12 @@ if __name__ == "__main__":
print("computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(
video,
pred_tracks,
pred_visibility,
query_frame=0 if args.backward_tracking else args.grid_query_frame,
filename=seq_name,
)
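Commit e29e938311 (the "demo flexible save path" change, #83) strips the extension from the input video's base name before passing it as `filename` to the `Visualizer`, so `./assets/apple.mp4` ends up at `./saved_videos/apple.mp4`, matching the updated README above. A tiny sketch of the renaming with an illustrative path:
```python
import os

video_path = "./assets/apple.mp4"                            # illustrative input
seq_name = os.path.splitext(video_path.split("/")[-1])[0]    # "apple"
# vis.visualize(video, pred_tracks, pred_visibility, filename=seq_name)
```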

View File

@@ -1,3 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import gradio as gr
@@ -22,7 +29,12 @@ def cotracker_demo(
model = model.cuda()
load_video = load_video.cuda()
model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
model(
video_chunk=load_video,
is_first_step=True,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
)
for ind in range(0, load_video.shape[1] - model.step, model.step):
pred_tracks, pred_visibility = model(
video_chunk=load_video[:, ind : ind + model.step * 2]

View File

@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import argparse
import imageio.v3 as iio
@@ -44,6 +45,9 @@ if __name__ == "__main__":
args = parser.parse_args()
if not os.path.isfile(args.video_path):
raise ValueError("Video file does not exist")
if args.checkpoint is not None:
model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
else:
@@ -52,25 +56,33 @@ if __name__ == "__main__":
window_frames = []
def _process_step(window_frames, is_first_step, grid_size):
def _process_step(window_frames, is_first_step, grid_size, grid_query_frame):
video_chunk = (
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
.float()
.permute(0, 3, 1, 2)[None]
) # (1, T, 3, H, W)
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
return model(
video_chunk,
is_first_step=is_first_step,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
)
# Iterating over video frames, processing one window at a time:
is_first_step = True
for i, frame in enumerate(
iio.imiter(
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
args.video_path,
plugin="FFMPEG",
)
):
if i % model.step == 0 and i != 0:
pred_tracks, pred_visibility = _process_step(
window_frames, is_first_step, grid_size=args.grid_size
window_frames,
is_first_step,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
)
is_first_step = False
window_frames.append(frame)
@@ -79,12 +91,13 @@ if __name__ == "__main__":
window_frames[-(i % model.step) - model.step - 1 :],
is_first_step,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
)
print("Tracks are computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)
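Both demos above follow the same sliding-window online pattern: one initialization call with `is_first_step=True` (which only registers the queries and returns `(None, None)`), then a call on an overlapping chunk of the last `2 * model.step` frames every `model.step` frames. A condensed, hedged sketch of that loop; the video source, `grid_size`, and checkpoint path are illustrative, and the leftover tail frames would still need one final call as in the code above:
```python
# Hedged sketch of the online (streaming) tracking loop.
import imageio.v3 as iio
import numpy as np
import torch

from cotracker.predictor import CoTrackerOnlinePredictor  # import path assumed from the repo layout

model = CoTrackerOnlinePredictor(checkpoint="checkpoints/cotracker2.pth")

window_frames, is_first_step = [], True
pred_tracks = pred_visibility = None
for i, frame in enumerate(iio.imiter("./assets/apple.mp4", plugin="FFMPEG")):
    if i % model.step == 0 and i != 0:
        chunk = (
            torch.tensor(np.stack(window_frames[-model.step * 2 :]))
            .float()
            .permute(0, 3, 1, 2)[None]
        )  # (1, T, 3, H, W)
        pred_tracks, pred_visibility = model(
            chunk, is_first_step=is_first_step, grid_size=10, grid_query_frame=0
        )
        is_first_step = False  # after initialization, every call returns updated tracks
    window_frames.append(frame)
# One final call on the remaining frames (as in the last _process_step above) completes the video.
```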