Add more algorithms
This commit is contained in:
116
lib/datasets/landmark_utils/point_meta.py
Normal file
116
lib/datasets/landmark_utils/point_meta.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import copy, math, torch, numpy as np
|
||||
from xvision import normalize_points
|
||||
from xvision import denormalize_points
|
||||
|
||||
|
||||
class PointMeta():
|
||||
# points : 3 x num_pts (x, y, oculusion)
|
||||
# image_size: original [width, height]
|
||||
def __init__(self, num_point, points, box, image_path, dataset_name):
|
||||
|
||||
self.num_point = num_point
|
||||
if box is not None:
|
||||
assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
|
||||
self.box = torch.Tensor(box)
|
||||
else: self.box = None
|
||||
if points is None:
|
||||
self.points = points
|
||||
else:
|
||||
assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
|
||||
self.points = torch.Tensor(points.copy())
|
||||
self.image_path = image_path
|
||||
self.datasets = dataset_name
|
||||
|
||||
def __repr__(self):
|
||||
if self.box is None: boxstr = 'None'
|
||||
else : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
|
||||
return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')
|
||||
|
||||
def get_box(self, return_diagonal=False):
|
||||
if self.box is None: return None
|
||||
if return_diagonal == False:
|
||||
return self.box.clone()
|
||||
else:
|
||||
W = (self.box[2]-self.box[0]).item()
|
||||
H = (self.box[3]-self.box[1]).item()
|
||||
return math.sqrt(H*H+W*W)
|
||||
|
||||
def get_points(self, ignore_indicator=False):
|
||||
if ignore_indicator: last = 2
|
||||
else : last = 3
|
||||
if self.points is not None: return self.points.clone()[:last, :]
|
||||
else : return torch.zeros((last, self.num_point))
|
||||
|
||||
def is_none(self):
|
||||
#assert self.box is not None, 'The box should not be None'
|
||||
return self.points is None
|
||||
#if self.box is None: return True
|
||||
#else : return self.points is None
|
||||
|
||||
def copy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def visiable_pts_num(self):
|
||||
with torch.no_grad():
|
||||
ans = self.points[2,:] > 0
|
||||
ans = torch.sum(ans)
|
||||
ans = ans.item()
|
||||
return ans
|
||||
|
||||
def special_fun(self, indicator):
|
||||
if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points.
|
||||
assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
|
||||
self.num_point = 49
|
||||
out = torch.ones((68), dtype=torch.uint8)
|
||||
out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0
|
||||
if self.points is not None: self.points = self.points.clone()[:, out]
|
||||
else:
|
||||
raise ValueError('Invalid indicator : {:}'.format( indicator ))
|
||||
|
||||
def apply_horizontal_flip(self):
|
||||
#self.points[0, :] = width - self.points[0, :] - 1
|
||||
# Mugsy spefic or Synthetic
|
||||
if self.datasets.startswith('HandsyROT'):
|
||||
ori = np.array(list(range(0, 42)))
|
||||
pos = np.array(list(range(21,42)) + list(range(0,21)))
|
||||
self.points[:, pos] = self.points[:, ori]
|
||||
elif self.datasets.startswith('face68'):
|
||||
ori = np.array(list(range(0, 68)))
|
||||
pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
|
||||
self.points[:, ori] = self.points[:, pos]
|
||||
else:
|
||||
raise ValueError('Does not support {:}'.format(self.datasets))
|
||||
|
||||
|
||||
|
||||
# shape = (H,W)
|
||||
def apply_affine2point(points, theta, shape):
|
||||
assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
|
||||
with torch.no_grad():
|
||||
ok_points = points[2,:] == 1
|
||||
assert torch.sum(ok_points).item() > 0, 'there is no visiable point'
|
||||
points[:2,:] = normalize_points(shape, points[:2,:])
|
||||
|
||||
norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
|
||||
|
||||
trans_points, ___ = torch.gesv(points[:, ok_points], theta)
|
||||
|
||||
norm_trans_points[:, ok_points] = trans_points
|
||||
|
||||
return norm_trans_points
|
||||
|
||||
|
||||
|
||||
def apply_boundary(norm_trans_points):
|
||||
with torch.no_grad():
|
||||
norm_trans_points = norm_trans_points.clone()
|
||||
oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
|
||||
oks = torch.sum(oks, dim=0) == 5
|
||||
norm_trans_points[2, :] = oks
|
||||
return norm_trans_points
|
Reference in New Issue
Block a user