release cotracker 2.0

This commit is contained in:
Nikita Karaev
2023-12-27 12:54:02 +00:00
parent 3df96621ed
commit f8fab323c4
38 changed files with 2238 additions and 1910 deletions

View File

@@ -6,27 +6,33 @@
import torch
dependencies = ["torch", "einops", "timm", "tqdm"]
_COTRACKER_URL = (
"https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
)
_COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
from cotracker.predictor import CoTrackerPredictor
def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs):
if online:
from cotracker.predictor import CoTrackerOnlinePredictor
predictor = CoTrackerPredictor(checkpoint=None)
predictor = CoTrackerOnlinePredictor(checkpoint=None)
else:
from cotracker.predictor import CoTrackerPredictor
predictor = CoTrackerPredictor(checkpoint=None)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
_COTRACKER_URL, map_location="cpu"
)
state_dict = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu")
predictor.model.load_state_dict(state_dict)
return predictor
def cotracker_w8(*, pretrained: bool = True, **kwargs):
def cotracker2(*, pretrained: bool = True, **kwargs):
"""
CoTracker model with stride 4 and window length 8. (The main model from the paper)
CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
"""
return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
return _make_cotracker_predictor(pretrained=pretrained, online=False, **kwargs)
def cotracker2_online(*, pretrained: bool = True, **kwargs):
"""
Online CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
"""
return _make_cotracker_predictor(pretrained=pretrained, online=True, **kwargs)