initial commit
This commit is contained in:
90
demo.py
Normal file
90
demo.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import sys
|
||||
sys.path.append('core')
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
import datasets
|
||||
from utils import flow_viz
|
||||
from raft import RAFT
|
||||
|
||||
|
||||
DEVICE = 'cuda'
|
||||
|
||||
def pad8(img):
|
||||
"""pad image such that dimensions are divisible by 8"""
|
||||
ht, wd = img.shape[2:]
|
||||
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
|
||||
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
|
||||
pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
|
||||
pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]
|
||||
|
||||
img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
|
||||
return img
|
||||
|
||||
def load_image(imfile):
|
||||
img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||
return pad8(img[None]).to(DEVICE)
|
||||
|
||||
|
||||
def display(image1, image2, flow):
|
||||
image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
|
||||
image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
|
||||
|
||||
flow = flow.permute(1, 2, 0).cpu().numpy()
|
||||
flow_image = flow_viz.flow_to_image(flow)
|
||||
flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
|
||||
|
||||
|
||||
cv2.imshow('image1', image1[..., ::-1])
|
||||
cv2.imshow('image2', image2[..., ::-1])
|
||||
cv2.imshow('flow', flow_image[..., ::-1])
|
||||
cv2.waitKey()
|
||||
|
||||
|
||||
def demo(args):
|
||||
model = RAFT(args)
|
||||
model = torch.nn.DataParallel(model)
|
||||
model.load_state_dict(torch.load(args.model))
|
||||
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
# sintel images
|
||||
image1 = load_image('images/sintel_0.png')
|
||||
image2 = load_image('images/sintel_1.png')
|
||||
|
||||
flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
|
||||
display(image1[0], image2[0], flow_predictions[-1][0])
|
||||
|
||||
# kitti images
|
||||
image1 = load_image('images/kitti_0.png')
|
||||
image2 = load_image('images/kitti_1.png')
|
||||
|
||||
flow_predictions = model(image1, image2, iters=16)
|
||||
display(image1[0], image2[0], flow_predictions[-1][0])
|
||||
|
||||
# davis images
|
||||
image1 = load_image('images/davis_0.jpg')
|
||||
image2 = load_image('images/davis_1.jpg')
|
||||
|
||||
flow_predictions = model(image1, image2, iters=16)
|
||||
display(image1[0], image2[0], flow_predictions[-1][0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', help="restore checkpoint")
|
||||
parser.add_argument('--small', action='store_true', help='use small model')
|
||||
parser.add_argument('--iters', type=int, default=12)
|
||||
|
||||
args = parser.parse_args()
|
||||
demo(args)
|
Reference in New Issue
Block a user