added cuda extension for efficent implementation

This commit is contained in:
Zach Teed
2020-08-22 18:49:24 -06:00
parent 5b1f510d6b
commit c86b3dc8f3
13 changed files with 519 additions and 191 deletions

View File

@@ -61,7 +61,7 @@ def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape)
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
@@ -139,7 +139,7 @@ def validate_kitti(model, iters=24):
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
@@ -172,6 +172,7 @@ if __name__ == '__main__':
parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
args = parser.parse_args()
model = torch.nn.DataParallel(RAFT(args))