added upsampling module

This commit is contained in:
Zach Teed
2020-07-25 17:36:17 -06:00
parent dc1220825d
commit a2408eab78
32 changed files with 23559 additions and 619 deletions

89
demo.py
View File

@@ -4,87 +4,76 @@ sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import datasets
from utils import flow_viz
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
DEVICE = 'cuda'
def pad8(img):
"""pad image such that dimensions are divisible by 8"""
ht, wd = img.shape[2:]
pad_ht = (((ht // 8) + 1) * 8 - ht) % 8
pad_wd = (((wd // 8) + 1) * 8 - wd) % 8
pad_ht1 = [pad_ht//2, pad_ht-pad_ht//2]
pad_wd1 = [pad_wd//2, pad_wd-pad_wd//2]
img = F.pad(img, pad_wd1 + pad_ht1, mode='replicate')
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)[..., :3]
img = torch.from_numpy(img).permute(2, 0, 1).float()
return pad8(img[None]).to(DEVICE)
def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))
images = torch.stack(images, dim=0)
images = images.to(DEVICE)
def display(image1, image2, flow):
image1 = image1.permute(1, 2, 0).cpu().numpy() / 255.0
image2 = image2.permute(1, 2, 0).cpu().numpy() / 255.0
padder = InputPadder(images.shape)
return padder.pad(images)[0]
flow = flow.permute(1, 2, 0).cpu().numpy()
flow_image = flow_viz.flow_to_image(flow)
flow_image = cv2.resize(flow_image, (image1.shape[1], image1.shape[0]))
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
cv2.imshow('image1', image1[..., ::-1])
cv2.imshow('image2', image2[..., ::-1])
cv2.imshow('flow', flow_image[..., ::-1])
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()
def demo(args):
model = RAFT(args)
model = torch.nn.DataParallel(model)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
# sintel images
image1 = load_image('images/sintel_0.png')
image2 = load_image('images/sintel_1.png')
images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
display(image1[0], image2[0], flow_predictions[-1][0])
# kitti images
image1 = load_image('images/kitti_0.png')
image2 = load_image('images/kitti_1.png')
flow_predictions = model(image1, image2, iters=16)
display(image1[0], image2[0], flow_predictions[-1][0])
# davis images
image1 = load_image('images/davis_0.jpg')
image2 = load_image('images/davis_1.jpg')
flow_predictions = model(image1, image2, iters=16)
display(image1[0], image2[0], flow_predictions[-1][0])
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
demo(args)
demo(args)