From ab0ce3c97795222f528c52c53dcbb0bf95e9b785 Mon Sep 17 00:00:00 2001 From: nikitakaraevv Date: Fri, 21 Jul 2023 13:41:52 -0700 Subject: [PATCH] add cpu-only mode --- cotracker/evaluation/core/evaluator.py | 10 ++++-- cotracker/evaluation/evaluate.py | 2 ++ cotracker/models/core/cotracker/cotracker.py | 10 +++--- cotracker/models/evaluation_predictor.py | 21 ++++++------ cotracker/predictor.py | 2 +- notebooks/demo.ipynb | 34 ++++++-------------- train.py | 2 ++ 7 files changed, 39 insertions(+), 42 deletions(-) diff --git a/cotracker/evaluation/core/evaluator.py b/cotracker/evaluation/core/evaluator.py index 9f4053b..423f965 100644 --- a/cotracker/evaluation/core/evaluator.py +++ b/cotracker/evaluation/core/evaluator.py @@ -185,7 +185,11 @@ class Evaluator: if not all(gotit): print("batch is None") continue - dataclass_to_cuda_(sample) + if torch.cuda.is_available(): + dataclass_to_cuda_(sample) + device = torch.device("cuda") + else: + device = torch.device("cpu") if ( not train_mode @@ -205,7 +209,7 @@ class Evaluator: queries[:, :, 1], ], dim=2, - ) + ).to(device) else: queries = torch.cat( [ @@ -213,7 +217,7 @@ class Evaluator: sample.trajectory[:, 0], ], dim=2, - ) + ).to(device) pred_tracks = model(sample.video, queries) if "strided" in dataset_name: diff --git a/cotracker/evaluation/evaluate.py b/cotracker/evaluation/evaluate.py index 1995629..cd2e00e 100644 --- a/cotracker/evaluation/evaluate.py +++ b/cotracker/evaluation/evaluate.py @@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig): single_point=cfg.single_point, n_iters=cfg.n_iters, ) + if torch.cuda.is_available(): + predictor.model = predictor.model.cuda() # Setting the random seeds torch.manual_seed(cfg.seed) diff --git a/cotracker/models/core/cotracker/cotracker.py b/cotracker/models/core/cotracker/cotracker.py index c9eca1f..c25e7db 100644 --- a/cotracker/models/core/cotracker/cotracker.py +++ b/cotracker/models/core/cotracker/cotracker.py @@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import ( torch.manual_seed(0) -def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): +def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"): if grid_size == 1: - return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ + return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[ None, None - ].cuda() + ] grid_y, grid_x = meshgrid2d( - 1, grid_size, grid_size, stack=False, norm=False, device="cuda" + 1, grid_size, grid_size, stack=False, norm=False, device=device ) step = interp_shape[1] // 64 if grid_center[0] != 0 or grid_center[1] != 0: @@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): grid_y = grid_y + grid_center[0] grid_x = grid_x + grid_center[1] - xy = torch.stack([grid_x, grid_y], dim=-1).cuda() + xy = torch.stack([grid_x, grid_y], dim=-1).to(device) return xy diff --git a/cotracker/models/evaluation_predictor.py b/cotracker/models/evaluation_predictor.py index b50233a..074fcdb 100644 --- a/cotracker/models/evaluation_predictor.py +++ b/cotracker/models/evaluation_predictor.py @@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module): self.n_iters = n_iters self.model = cotracker_model - self.model.to("cuda") self.model.eval() def forward(self, video, queries): - queries = queries.clone().cuda() + queries = queries.clone() B, T, C, H, W = video.shape B, N, D = queries.shape @@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module): rgbs = video.reshape(B * T, C, H, W) rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear") - rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda() + rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) + + device = rgbs.device queries[:, :, 1] *= self.interp_shape[1] / W queries[:, :, 2] *= self.interp_shape[0] / H if self.single_point: - traj_e = torch.zeros((B, T, N, 2)).cuda() - vis_e = torch.zeros((B, T, N)).cuda() + traj_e = torch.zeros((B, T, N, 2), device=device) + vis_e = torch.zeros((B, T, N), device=device) for pind in range((N)): query = queries[:, pind : pind + 1] @@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module): vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] else: if self.grid_size > 0: - xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) - xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() # + xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device) + xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to( + device + ) # queries = torch.cat([queries, xy], dim=1) # traj_e, __, vis_e, __ = self.model( @@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module): query = torch.cat([query, xy_target], dim=1).to(device) # if self.grid_size > 0: - xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) - xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() # + xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device) + xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # query = torch.cat([query, xy], dim=1).to(device) # # crop the video to start from the queried frame query[0, 0, 0] = 0 diff --git a/cotracker/predictor.py b/cotracker/predictor.py index e664240..f36b500 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -116,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module): queries[:, :, 1] *= self.interp_shape[1] / W queries[:, :, 2] *= self.interp_shape[0] / H elif grid_size > 0: - grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) + grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) if segm_mask is not None: segm_mask = F.interpolate( segm_mask, tuple(self.interp_shape), mode="nearest" diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb index 7291235..55e0cc2 100644 --- a/notebooks/demo.ipynb +++ b/notebooks/demo.ipynb @@ -65,26 +65,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/private/home/nikitakaraev/dev/neurips_2023/co-tracker\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/private/home/nikitakaraev/.conda/envs/stereoformer/lib/python3.8/site-packages/requests/__init__.py:109: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (None)/charset_normalizer (3.2.0) doesn't match a supported version!\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "%cd ..\n", "import os\n", @@ -105,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd", "metadata": {}, "outputs": [], @@ -116,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87", "metadata": {}, "outputs": [ @@ -129,7 +113,7 @@ "" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -175,8 +159,8 @@ "outputs": [], "source": [ "if torch.cuda.is_available():\n", - " model=model.cuda()\n", - " video=video.cuda()" + " model = model.cuda()\n", + " video = video.cuda()" ] }, { @@ -282,7 +266,9 @@ " [10., 600., 500.], # frame number 10\n", " [20., 750., 600.], # ...\n", " [30., 900., 200.]\n", - "]).cuda()" + "])\n", + "if torch.cuda.is_available():\n", + " queries = queries.cuda()" ] }, { diff --git a/train.py b/train.py index 57b9f87..13ad8ef 100644 --- a/train.py +++ b/train.py @@ -138,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step): single_point=False, n_iters=6, ) + if torch.cuda.is_available(): + predictor.model = predictor.model.cuda() metrics = evaluator.evaluate_sequence( model=predictor,