diff --git a/online_demo.py b/online_demo.py index e05ed41..8802849 100644 --- a/online_demo.py +++ b/online_demo.py @@ -52,25 +52,33 @@ if __name__ == "__main__": window_frames = [] - def _process_step(window_frames, is_first_step, grid_size): + def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): video_chunk = ( torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) .float() .permute(0, 3, 1, 2)[None] ) # (1, T, 3, H, W) - return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size) + return model( + video_chunk, + is_first_step=is_first_step, + grid_size=grid_size, + grid_query_frame=grid_query_frame, + ) # Iterating over video frames, processing one window at a time: is_first_step = True for i, frame in enumerate( iio.imiter( - "https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4", + "./assets/apple.mp4", plugin="FFMPEG", ) ): if i % model.step == 0 and i != 0: pred_tracks, pred_visibility = _process_step( - window_frames, is_first_step, grid_size=args.grid_size + window_frames, + is_first_step, + grid_size=args.grid_size, + grid_query_frame=args.grid_query_frame, ) is_first_step = False window_frames.append(frame) @@ -79,6 +87,7 @@ if __name__ == "__main__": window_frames[-(i % model.step) - model.step - 1 :], is_first_step, grid_size=args.grid_size, + grid_query_frame=args.grid_query_frame, ) print("Tracks are computed")