release cotracker 2.0
This commit is contained in:
30
demo.py
30
demo.py
@@ -5,7 +5,6 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
@@ -14,9 +13,18 @@ from PIL import Image
|
||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
|
||||
'mps' if torch.backends.mps.is_available() else
|
||||
'cpu')
|
||||
# Unfortunately MPS acceleration does not support all the features we require,
|
||||
# but we may be able to enable it in the future
|
||||
|
||||
DEFAULT_DEVICE = (
|
||||
# "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# if DEFAULT_DEVICE == "mps":
|
||||
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -32,15 +40,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
default="./checkpoints/cotracker_stride_4_wind_8.pth",
|
||||
help="cotracker model",
|
||||
# default="./checkpoints/cotracker.pth",
|
||||
default=None,
|
||||
help="CoTracker model parameters",
|
||||
)
|
||||
parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
|
||||
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
|
||||
parser.add_argument(
|
||||
"--grid_query_frame",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Compute dense and grid tracks starting from this frame ",
|
||||
help="Compute dense and grid tracks starting from this frame",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -57,7 +66,10 @@ if __name__ == "__main__":
|
||||
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
|
||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
if args.checkpoint is not None:
|
||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||
else:
|
||||
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
|
||||
model = model.to(DEFAULT_DEVICE)
|
||||
video = video.to(DEFAULT_DEVICE)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user