diff --git a/cotracker/models/build_cotracker.py b/cotracker/models/build_cotracker.py
index 40b19b5..b922d6c 100644
--- a/cotracker/models/build_cotracker.py
+++ b/cotracker/models/build_cotracker.py
@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
 def build_cotracker(
     checkpoint: str,
 ):
+    if checkpoint is None:
+        return build_cotracker_stride_4_wind_8()
     model_name = checkpoint.split("/")[-1].split(".")[0]
     if model_name == "cotracker_stride_4_wind_8":
         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
diff --git a/cotracker/predictor.py b/cotracker/predictor.py
index a4e2003..e664240 100644
--- a/cotracker/predictor.py
+++ b/cotracker/predictor.py
@@ -25,7 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.model.to("cuda")
         self.model.eval()
 
     @torch.no_grad()
@@ -72,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -107,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
-        video = video.reshape(
-            B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).cuda()
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
             queries = queries.clone()
diff --git a/demo.py b/demo.py
index 7b2e1c7..fd24fed 100644
--- a/demo.py
+++ b/demo.py
@@ -55,6 +55,11 @@ if __name__ == "__main__":
         segm_mask = torch.from_numpy(segm_mask)[None, None]
 
     model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    if torch.cuda.is_available():
+        model = model.cuda()
+        video = video.cuda()
+    else:
+        print("CUDA is not available!")
 
     pred_tracks, pred_visibility = model(
         video,
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000..c9ceac4
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+dependencies = ["torch", "einops", "timm", "tqdm"]
+
+_COTRACKER_URL = (
+    "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
+)
+
+
+def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
+    from cotracker.predictor import CoTrackerPredictor
+
+    predictor = CoTrackerPredictor(checkpoint=None)
+    if pretrained:
+        state_dict = torch.hub.load_state_dict_from_url(
+            _COTRACKER_URL, map_location="cpu"
+        )
+        predictor.model.load_state_dict(state_dict)
+    return predictor
+
+
+def cotracker_w8(*, pretrained: bool = True, **kwargs):
+    """
+    CoTracker model with stride 4 and window length 8. (The main model from the paper)
+    """
+    return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb
index 3367090..7291235 100644
--- a/notebooks/demo.ipynb
+++ b/notebooks/demo.ipynb
@@ -37,9 +37,13 @@
   },
   {
    "cell_type": "markdown",
-   "id": "6757bfa3-d663-4a54-9722-3e1a7da3307c",
+   "id": "88c6db31",
    "metadata": {},
    "source": [
+    "Don't forget to turn on GPU support if you're running this demo in Colab. \n",
+    "\n",
+    "**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**\n",
+    "\n",
     "Let's install dependencies for Colab:"
    ]
   },
@@ -61,10 +65,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 1,
    "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/private/home/nikitakaraev/dev/neurips_2023/co-tracker\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/private/home/nikitakaraev/.conda/envs/stereoformer/lib/python3.8/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (None)/charset_normalizer (3.2.0) doesn't match a supported version!\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
    "source": [
     "%cd ..\n",
     "import os\n",
@@ -85,7 +105,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 2,
    "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
    "metadata": {},
    "outputs": [],
@@ -96,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 3,
    "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
    "metadata": {},
    "outputs": [
@@ -109,7 +129,7 @@
        ""
       ]
      },
-     "execution_count": 16,
+     "execution_count": 3,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -133,7 +153,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 5,
    "id": "d59ac40b-bde8-46d4-bd57-4ead939f22ca",
    "metadata": {},
    "outputs": [],
@@ -147,6 +167,18 @@
     ")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "3f2a4485",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "    model=model.cuda()\n",
+    "    video=video.cuda()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "e8398155-6dae-4ff0-95f3-dbb52ac70d20",
@@ -157,7 +189,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 7,
    "id": "17fcaae9-7b3c-474c-977a-cce08a09d580",
    "metadata": {},
    "outputs": [],
@@ -175,7 +207,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 8,
    "id": "7e793ce0-7b77-46ca-a629-155a6a146000",
    "metadata": {},
    "outputs": [
@@ -194,20 +226,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 9,
    "id": "2d0733ba-8fe1-4cd4-b963-2085202fba13",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 20,
+     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -240,7 +272,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 10,
    "id": "c6422e7c-8c6f-4269-92c3-245344afe35b",
    "metadata": {},
    "outputs": [],
@@ -263,7 +295,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 11,
    "id": "d7141079-d7e0-40b3-b031-a28879c4bd6d",
    "metadata": {},
    "outputs": [
@@ -309,7 +341,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 12,
    "id": "09008ca9-6a87-494f-8b05-6370cae6a600",
    "metadata": {},
    "outputs": [],
@@ -328,7 +360,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 13,
    "id": "01467f8d-667c-4f41-b418-93132584c659",
    "metadata": {},
    "outputs": [
@@ -355,20 +387,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 14,
    "id": "fe23d210-ed90-49f1-8311-b7e354c7a9f6",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 25,
+     "execution_count": 14,
     "metadata": {},
    "output_type": "execute_result"
    }
@@ -403,7 +435,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 15,
    "id": "c880f3ca-cf42-4f64-9df6-a0e8de6561dc",
    "metadata": {},
    "outputs": [],
@@ -414,7 +446,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 16,
    "id": "3cd58820-7b23-469e-9b6d-5fa81257981f",
    "metadata": {},
    "outputs": [],
@@ -424,7 +456,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 17,
    "id": "25a85a1d-dce0-4e6b-9f7a-aaf31ade0600",
    "metadata": {},
    "outputs": [
@@ -450,25 +482,25 @@
    "id": "ce0fb5b8-d249-4f4e-b59a-51b4f03972c4",
    "metadata": {},
    "source": [
-    "Notice that tracking starts only from points sampled on a frame in the middle of the video. This is different from the grid in the first example:"
+    "Note that tracking starts only from points sampled on a frame in the middle of the video. This is different from the grid in the first example:"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 18,
    "id": "f0b01d51-9222-472b-a714-188c38d83ad9",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 29,
+     "execution_count": 18,
     "metadata": {},
    "output_type": "execute_result"
    }
@@ -495,7 +527,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 19,
    "id": "506233dc-1fb3-4a3c-b9eb-5cbd5df49128",
    "metadata": {},
    "outputs": [],
@@ -514,7 +546,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 20,
    "id": "677cf34e-6c6a-49e3-a21b-f8a4f718f916",
    "metadata": {},
    "outputs": [
@@ -545,20 +577,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 21,
    "id": "c8d64ab0-7e92-4238-8e7d-178652fc409c",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 32,
+     "execution_count": 21,
     "metadata": {},
    "output_type": "execute_result"
    }
@@ -586,7 +618,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 22,
    "id": "b759548d-1eda-473e-9c90-99e5d3197e20",
    "metadata": {},
    "outputs": [],
@@ -598,7 +630,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 23,
    "id": "14ae8a8b-fec7-40d1-b6f2-10e333b75db4",
    "metadata": {},
    "outputs": [],
@@ -617,17 +649,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 24,
    "id": "4d2efd4e-22df-4833-b9a0-a0763d59ee22",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       ""
+       ""
      ]
     },
-     "execution_count": 35,
+     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    },
@@ -648,7 +680,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 25,
    "id": "b42dce24-7952-4660-8298-4c362d6913cf",
    "metadata": {},
    "outputs": [
@@ -683,20 +715,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 26,
    "id": "1810440f-00f4-488a-a174-36be05949e42",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 41,
+     "execution_count": 26,
     "metadata": {},
    "output_type": "execute_result"
    }
@@ -731,7 +763,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 27,
    "id": "379557d9-80ea-4316-91df-4da215193b41",
    "metadata": {},
    "outputs": [
@@ -741,7 +773,7 @@
       "torch.Size([1, 48, 3, 719, 1282])"
      ]
     },
-     "execution_count": 38,
+     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -752,13 +784,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 28,
    "id": "c6db5cc7-351d-4d9e-9b9d-3a40f05b077a",
    "metadata": {},
    "outputs": [],
    "source": [
     "import torch.nn.functional as F\n",
-    "video_interp = F.interpolate(video[0], [100,180], mode=\"bilinear\")[None].cuda()"
+    "video_interp = F.interpolate(video[0], [100,180], mode=\"bilinear\")[None]"
    ]
   },
   {
@@ -771,7 +803,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 29,
    "id": "0918f246-5556-43b8-9f6d-88013d5a487e",
    "metadata": {},
    "outputs": [
@@ -781,7 +813,7 @@
       "torch.Size([1, 48, 3, 100, 180])"
      ]
     },
-     "execution_count": 40,
+     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -800,7 +832,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 30,
    "id": "3b852606-5229-4abd-b166-496d35da1009",
    "metadata": {},
    "outputs": [
@@ -808,7 +840,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|██████████| 9/9 [02:05<00:00, 13.95s/it]\n"
+     "100%|██████████| 36/36 [02:01<00:00, 3.38s/it]\n"
     ]
    }
   ],
@@ -826,7 +858,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 31,
    "id": "5394b0ba-1fc7-4843-91d5-6113a6e86bdf",
    "metadata": {},
    "outputs": [
@@ -854,20 +886,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 32,
    "id": "9113c2ac-4d25-4ef2-8951-71a1c1be74dd",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/html": [
-       ""
+       ""
      ],
      "text/plain": [
       ""
      ]
     },
-     "execution_count": 44,
+     "execution_count": 32,
     "metadata": {},
    "output_type": "execute_result"
   }