release cotracker 2.0

online_demo.py (Normal file, 90 additions)
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import imageio.v3 as iio
import numpy as np
import torch

from cotracker.predictor import CoTrackerOnlinePredictor
from cotracker.utils.visualizer import Visualizer

# Unfortunately, MPS acceleration does not support all the features we require,
# but we may be able to enable it in the future.
DEFAULT_DEVICE = (
    # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="CoTracker model parameters",
    )
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )

    args = parser.parse_args()

    if args.checkpoint is not None:
        model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
    else:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
    model = model.to(DEFAULT_DEVICE)
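
    # The online predictor tracks points causally: it is fed the video in
    # overlapping windows and updates its tracks as new frames arrive.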

    window_frames = []
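
    # Helper: pack the last 2 * model.step frames into a (1, T, 3, H, W) float
    # tensor on DEFAULT_DEVICE and run one incremental tracking step.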
    def _process_step(window_frames, is_first_step, grid_size):
        video_chunk = (
            torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
            .float()
            .permute(0, 3, 1, 2)[None]
        )  # (1, T, 3, H, W)
        return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)

    # Iterating over video frames, processing one window at a time:
    is_first_step = True
    for i, frame in enumerate(
        iio.imiter(
            args.video_path,
            plugin="FFMPEG",
        )
    ):
        if i % model.step == 0 and i != 0:
            pred_tracks, pred_visibility = _process_step(
                window_frames, is_first_step, grid_size=args.grid_size
            )
            is_first_step = False
        window_frames.append(frame)
    # Processing the final video frames in case video length is not a multiple of model.step
    pred_tracks, pred_visibility = _process_step(
        window_frames[-(i % model.step) - model.step - 1 :],
        is_first_step,
        grid_size=args.grid_size,
    )
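
    # pred_tracks: (1, T, N, 2) point coordinates per frame;
    # pred_visibility: (1, T, N) per-point visibility for each frame.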

    print("Tracks are computed")

    # Save a video with the predicted tracks overlaid
    seq_name = args.video_path.split("/")[-1]
    video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
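
Usage sketch (assuming a CUDA-capable machine and the repo's bundled assets; the clip and grid size below are just the script defaults):

    python online_demo.py --video_path ./assets/apple.mp4 --grid_size 10

The rendered result is written under ./saved_videos/ by the Visualizer.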