{ "cells": [ { "cell_type": "markdown", "id": "60a7e08e-93c6-4370-9778-3bb102dce78b", "metadata": {}, "source": [ "Copyright (c) Meta Platforms, Inc. and affiliates." ] }, { "cell_type": "markdown", "id": "3081cd8f-f6f9-4a1a-8c36-8a857b0c3b03", "metadata": {}, "source": [ "\n", " \"Open\n", "" ] }, { "cell_type": "markdown", "id": "f9f3240f-0354-4802-b8b5-9070930fc957", "metadata": {}, "source": [ "# CoTracker: It is Better to Track Together\n", "This is a demo for CoTracker, a model that can track any point in a video." ] }, { "cell_type": "markdown", "id": "36ff1fd0-572e-47fb-8221-1e73ac17cfd1", "metadata": {}, "source": [ "\"Logo\"" ] }, { "cell_type": "markdown", "id": "88c6db31", "metadata": {}, "source": [ "Don't forget to turn on GPU support if you're running this demo in Colab. \n", "\n", "**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**\n", "\n", "Let's install dependencies for Colab:" ] }, { "cell_type": "code", "execution_count": null, "id": "278876a7", "metadata": {}, "outputs": [], "source": [ "!git clone https://github.com/facebookresearch/co-tracker\n", "%cd co-tracker\n", "!pip install -e .\n", "!pip install opencv-python einops timm matplotlib moviepy flow_vis\n", "!mkdir checkpoints\n", "%cd checkpoints\n", "!wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth" ] }, { "cell_type": "code", "execution_count": null, "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4", "metadata": {}, "outputs": [], "source": [ "%cd ..\n", "import os\n", "import torch\n", "\n", "from base64 import b64encode\n", "from cotracker.utils.visualizer import Visualizer, read_video_from_path\n", "from IPython.display import HTML" ] }, { "cell_type": "markdown", "id": "7894bd2d-2099-46fa-8286-f0c56298ecd1", "metadata": {}, "source": [ "Read a video from CO3D:" ] }, { "cell_type": "code", "execution_count": null, "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd", "metadata": {}, "outputs": [], "source": [ "video = read_video_from_path('./assets/apple.mp4')\n", "video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()" ] }, { "cell_type": "code", "execution_count": null, "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87", "metadata": {}, "outputs": [], "source": [ "def show_video(video_path):\n", " video_file = open(video_path, \"r+b\").read()\n", " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", " return HTML(f\"\"\"\"\"\")\n", " \n", "show_video(\"./assets/apple.mp4\")" ] }, { "cell_type": "markdown", "id": "6f89ae18-54d0-4384-8a79-ca9247f5f31a", "metadata": {}, "source": [ "Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:" ] }, { "cell_type": "code", "execution_count": null, "id": "d59ac40b-bde8-46d4-bd57-4ead939f22ca", "metadata": {}, "outputs": [], "source": [ "from cotracker.predictor import CoTrackerPredictor\n", "\n", "model = CoTrackerPredictor(\n", " checkpoint=os.path.join(\n", " './checkpoints/cotracker_stride_4_wind_8.pth'\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "3f2a4485", "metadata": {}, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " model = model.cuda()\n", " video = video.cuda()\n", "elif torch.backends.mps.is_available():\n", " model = model.to('mps')\n", " video = video.to('mps')" ] }, { "cell_type": "markdown", "id": "e8398155-6dae-4ff0-95f3-dbb52ac70d20", "metadata": {}, "source": [ "Track points sampled on a regular grid of size 30\\*30 on the first frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "17fcaae9-7b3c-474c-977a-cce08a09d580", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=30)" ] }, { "cell_type": "markdown", "id": "50a58521-a9ba-4f8b-be02-cfdaf79613a2", "metadata": {}, "source": [ "Visualize and save the result: " ] }, { "cell_type": "code", "execution_count": null, "id": "7e793ce0-7b77-46ca-a629-155a6a146000", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(save_dir='./videos', pad_value=100)\n", "vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser');" ] }, { "cell_type": "code", "execution_count": null, "id": "2d0733ba-8fe1-4cd4-b963-2085202fba13", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/teaser_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "73d88a5f-057c-4b9f-828d-ee0b97d1e72f", "metadata": {}, "source": [ "## Tracking manually selected points" ] }, { "cell_type": "markdown", "id": "a75bca85-b872-4f4e-be19-ff16f0984037", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "We will start by tracking points queried manually.\n", "We define a queried point as: [time, x coord, y coord] \n", "\n", "So, the code below defines points with different x and y coordinates sampled on frames 0, 10, 20, and 30:" ] }, { "cell_type": "code", "execution_count": null, "id": "c6422e7c-8c6f-4269-92c3-245344afe35b", "metadata": {}, "outputs": [], "source": [ "queries = torch.tensor([\n", " [0., 400., 350.], # point tracked from the first frame\n", " [10., 600., 500.], # frame number 10\n", " [20., 750., 600.], # ...\n", " [30., 900., 200.]\n", "])\n", "if torch.cuda.is_available():\n", " queries = queries.cuda()" ] }, { "cell_type": "markdown", "id": "13697a2a-7304-4d18-93be-bfbebf3dc12a", "metadata": {}, "source": [ "That's what our queried points look like:" ] }, { "cell_type": "code", "execution_count": null, "id": "d7141079-d7e0-40b3-b031-a28879c4bd6d", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Create a list of frame numbers corresponding to each point\n", "frame_numbers = queries[:,0].int().tolist()\n", "\n", "fig, axs = plt.subplots(2, 2)\n", "axs = axs.flatten()\n", "\n", "for i, (query, frame_number) in enumerate(zip(queries, frame_numbers)):\n", " ax = axs[i]\n", " ax.plot(query[1].item(), query[2].item(), 'ro') \n", " \n", " ax.set_title(\"Frame {}\".format(frame_number))\n", " ax.set_xlim(0, video.shape[4])\n", " ax.set_ylim(0, video.shape[3])\n", " ax.invert_yaxis()\n", " \n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "aec7693b-9d74-48b3-b612-360290ff1e7a", "metadata": {}, "source": [ "We pass these points as input to the model and track them:" ] }, { "cell_type": "code", "execution_count": null, "id": "09008ca9-6a87-494f-8b05-6370cae6a600", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, queries=queries[None])" ] }, { "cell_type": "markdown", "id": "b00d2a35-3daf-482d-b40b-b6d4f548ca40", "metadata": {}, "source": [ "Finally, we visualize the results with tracks leaving traces from the frame where the tracking starts.\n", "Color encodes time:" ] }, { "cell_type": "code", "execution_count": null, "id": "01467f8d-667c-4f41-b418-93132584c659", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(\n", " save_dir='./videos',\n", " linewidth=6,\n", " mode='cool',\n", " tracks_leave_trace=-1\n", ")\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='queries');" ] }, { "cell_type": "code", "execution_count": null, "id": "fe23d210-ed90-49f1-8311-b7e354c7a9f6", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/queries_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "87f2a3b4-a8b3-4aeb-87d2-28f056c624ba", "metadata": {}, "source": [ "## Points on a regular grid" ] }, { "cell_type": "markdown", "id": "a9aac679-19f8-4b78-9cc9-d934c6f83b01", "metadata": {}, "source": [ "### Tracking forward from the frame number x" ] }, { "cell_type": "markdown", "id": "0aeabca9-cc34-4d0f-8b2d-e6a6f797cb20", "metadata": {}, "source": [ "Let's now sample points on a regular grid and start tracking from the frame number 20 with a grid of 30\\*30. " ] }, { "cell_type": "code", "execution_count": null, "id": "c880f3ca-cf42-4f64-9df6-a0e8de6561dc", "metadata": {}, "outputs": [], "source": [ "grid_size = 30\n", "grid_query_frame = 20" ] }, { "cell_type": "code", "execution_count": null, "id": "3cd58820-7b23-469e-9b6d-5fa81257981f", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)" ] }, { "cell_type": "code", "execution_count": null, "id": "25a85a1d-dce0-4e6b-9f7a-aaf31ade0600", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(save_dir='./videos', pad_value=100)\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='grid_query_20',\n", " query_frame=grid_query_frame);" ] }, { "cell_type": "markdown", "id": "ce0fb5b8-d249-4f4e-b59a-51b4f03972c4", "metadata": {}, "source": [ "Note that tracking starts only from points sampled on a frame in the middle of the video. This is different from the grid in the first example:" ] }, { "cell_type": "code", "execution_count": null, "id": "f0b01d51-9222-472b-a714-188c38d83ad9", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/grid_query_20_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "10baad8f-0cb8-4118-9e69-3fb24575715c", "metadata": {}, "source": [ "### Tracking forward **and backward** from the frame number x" ] }, { "cell_type": "markdown", "id": "8409e2f7-9e4e-4228-b198-56a64e2260a7", "metadata": {}, "source": [ "CoTracker is an online algorithm and tracks points only in one direction. However, we can also run it backward from the queried point to track in both directions: " ] }, { "cell_type": "code", "execution_count": null, "id": "506233dc-1fb3-4a3c-b9eb-5cbd5df49128", "metadata": {}, "outputs": [], "source": [ "grid_size = 30\n", "grid_query_frame = 20" ] }, { "cell_type": "markdown", "id": "495b5fb4-9050-41fe-be98-d757916d0812", "metadata": {}, "source": [ "Let's activate backward tracking:" ] }, { "cell_type": "code", "execution_count": null, "id": "677cf34e-6c6a-49e3-a21b-f8a4f718f916", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='grid_query_20_backward',\n", " query_frame=grid_query_frame);" ] }, { "cell_type": "markdown", "id": "585a0afa-2cfc-4a07-a6f0-f65924b9ebce", "metadata": {}, "source": [ "As you can see, we are now tracking points queried in the middle from the first frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "c8d64ab0-7e92-4238-8e7d-178652fc409c", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/grid_query_20_backward_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "fb55fb01-6d8e-4e06-9346-8b2e9ef489c2", "metadata": {}, "source": [ "## Regular grid + Segmentation mask" ] }, { "cell_type": "markdown", "id": "e93a6b0a-b173-46fa-a6d2-1661ae6e6779", "metadata": {}, "source": [ "Let's now sample points on a grid and filter them with a segmentation mask.\n", "This allows us to track points sampled densely on an object because we consume less GPU memory." ] }, { "cell_type": "code", "execution_count": null, "id": "b759548d-1eda-473e-9c90-99e5d3197e20", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from PIL import Image\n", "grid_size = 120" ] }, { "cell_type": "code", "execution_count": null, "id": "14ae8a8b-fec7-40d1-b6f2-10e333b75db4", "metadata": {}, "outputs": [], "source": [ "input_mask = './assets/apple_mask.png'\n", "segm_mask = np.array(Image.open(input_mask))" ] }, { "cell_type": "markdown", "id": "4e3a1520-64bf-4a0d-b6e9-639430e31940", "metadata": {}, "source": [ "That's a segmentation mask for the first frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "4d2efd4e-22df-4833-b9a0-a0763d59ee22", "metadata": {}, "outputs": [], "source": [ "plt.imshow((segm_mask[...,None]/255.*video[0,0].permute(1,2,0).cpu().numpy()/255.))" ] }, { "cell_type": "code", "execution_count": null, "id": "b42dce24-7952-4660-8298-4c362d6913cf", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(segm_mask)[None, None])\n", "vis = Visualizer(\n", " save_dir='./videos',\n", " pad_value=100,\n", " linewidth=2,\n", ")\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='segm_grid');" ] }, { "cell_type": "markdown", "id": "5a386308-0d20-4ba3-bbb9-98ea79823a47", "metadata": {}, "source": [ "We are now only tracking points on the object (and around):" ] }, { "cell_type": "code", "execution_count": null, "id": "1810440f-00f4-488a-a174-36be05949e42", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/segm_grid_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "a63e89e4-8890-4e1b-91ec-d5dfa3f93309", "metadata": {}, "source": [ "## Dense Tracks" ] }, { "cell_type": "markdown", "id": "4ae764d8-db7c-41c2-a712-1876e7b4372d", "metadata": {}, "source": [ "### Tracking forward **and backward** from the frame number x" ] }, { "cell_type": "markdown", "id": "0dde3237-ecad-4c9b-b100-28b1f1b3cbe6", "metadata": {}, "source": [ "CoTracker also has a mode to track **every pixel** in a video in a **dense** manner but it is much slower than in previous examples. Let's downsample the video in order to make it faster: " ] }, { "cell_type": "code", "execution_count": null, "id": "379557d9-80ea-4316-91df-4da215193b41", "metadata": {}, "outputs": [], "source": [ "video.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "c6db5cc7-351d-4d9e-9b9d-3a40f05b077a", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "video_interp = F.interpolate(video[0], [100,180], mode=\"bilinear\")[None]" ] }, { "cell_type": "markdown", "id": "7ba32cb3-97dc-46f5-b2bd-b93a094dc819", "metadata": {}, "source": [ "The video now has a much lower resolution:" ] }, { "cell_type": "code", "execution_count": null, "id": "0918f246-5556-43b8-9f6d-88013d5a487e", "metadata": {}, "outputs": [], "source": [ "video_interp.shape" ] }, { "cell_type": "markdown", "id": "bc7d3a2c-5e87-4c8d-ad10-1f9c6d2ffbed", "metadata": {}, "source": [ "Again, let's track points in both directions. This will only take a couple of minutes:" ] }, { "cell_type": "code", "execution_count": null, "id": "3b852606-5229-4abd-b166-496d35da1009", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video_interp, grid_query_frame=20, backward_tracking=True)" ] }, { "cell_type": "markdown", "id": "4143ab14-810e-4e65-93f1-5775957cf4da", "metadata": {}, "source": [ "Visualization with an optical flow color encoding:" ] }, { "cell_type": "code", "execution_count": null, "id": "5394b0ba-1fc7-4843-91d5-6113a6e86bdf", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(\n", " save_dir='./videos',\n", " pad_value=20,\n", " linewidth=1,\n", " mode='optical_flow'\n", ")\n", "vis.visualize(\n", " video=video_interp,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " query_frame=grid_query_frame,\n", " filename='dense');" ] }, { "cell_type": "code", "execution_count": null, "id": "9113c2ac-4d25-4ef2-8951-71a1c1be74dd", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/dense_pred_track.mp4\")" ] }, { "cell_type": "markdown", "id": "95e9bce0-382b-4d18-9316-7f92093ada1d", "metadata": {}, "source": [ "That's all, now you can use CoTracker in your projects!" ] }, { "cell_type": "code", "execution_count": null, "id": "54e0ba0c-b532-46a9-af6f-9508de689dd2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" } }, "nbformat": 4, "nbformat_minor": 5 }