Updates
This commit is contained in:
@@ -20,172 +20,282 @@ import torch.utils.data as data
|
||||
|
||||
|
||||
class LandmarkDataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
transform,
|
||||
sigma,
|
||||
downsample,
|
||||
heatmap_type,
|
||||
shape,
|
||||
use_gray,
|
||||
mean_file,
|
||||
data_indicator,
|
||||
cache_images=None,
|
||||
):
|
||||
|
||||
def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
|
||||
|
||||
self.transform = transform
|
||||
self.sigma = sigma
|
||||
self.downsample = downsample
|
||||
self.heatmap_type = heatmap_type
|
||||
self.dataset_name = data_indicator
|
||||
self.shape = shape # [H,W]
|
||||
self.use_gray = use_gray
|
||||
assert transform is not None, 'transform : {:}'.format(transform)
|
||||
self.mean_file = mean_file
|
||||
if mean_file is None:
|
||||
self.mean_data = None
|
||||
warnings.warn('LandmarkDataset initialized with mean_data = None')
|
||||
else:
|
||||
assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
|
||||
self.mean_data = torch.load(mean_file)
|
||||
self.reset()
|
||||
self.cutout = None
|
||||
self.cache_images = cache_images
|
||||
print ('The general dataset initialization done : {:}'.format(self))
|
||||
warnings.simplefilter( 'once' )
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
|
||||
def set_cutout(self, length):
|
||||
if length is not None and length >= 1:
|
||||
self.cutout = CutOut( int(length) )
|
||||
else: self.cutout = None
|
||||
|
||||
|
||||
def reset(self, num_pts=-1, boxid='default', only_pts=False):
|
||||
self.NUM_PTS = num_pts
|
||||
if only_pts: return
|
||||
self.length = 0
|
||||
self.datas = []
|
||||
self.labels = []
|
||||
self.NormDistances = []
|
||||
self.BOXID = boxid
|
||||
if self.mean_data is None:
|
||||
self.mean_face = None
|
||||
else:
|
||||
self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
|
||||
assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)
|
||||
#assert self.dataset_name is not None, 'The dataset name is None'
|
||||
|
||||
|
||||
def __len__(self):
|
||||
assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
|
||||
return self.length
|
||||
|
||||
|
||||
def append(self, data, label, distance):
|
||||
assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
|
||||
self.datas.append( data ) ; self.labels.append( label )
|
||||
self.NormDistances.append( distance )
|
||||
self.length = self.length + 1
|
||||
|
||||
|
||||
def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
|
||||
if reset: self.reset(num_pts, boxindicator)
|
||||
else : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts)
|
||||
if isinstance(file_lists, str): file_lists = [file_lists]
|
||||
samples = []
|
||||
for idx, file_path in enumerate(file_lists):
|
||||
print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
|
||||
xdata = torch.load(file_path)
|
||||
if isinstance(xdata, list) : data = xdata # image or video dataset list
|
||||
elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list
|
||||
else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
|
||||
samples = samples + data
|
||||
# samples is a dict, where the key is the image-path and the value is the annotation
|
||||
# each annotation is a dict, contains 'points' (3,num_pts), and various box
|
||||
print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))
|
||||
|
||||
#for index, annotation in enumerate(samples):
|
||||
for index in tqdm( range( len(samples) ) ):
|
||||
annotation = samples[index]
|
||||
image_path = annotation['current_frame']
|
||||
points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
|
||||
label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
|
||||
if normalizeL is None: normDistance = None
|
||||
else : normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
|
||||
self.append(image_path, label, normDistance)
|
||||
|
||||
assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
|
||||
assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
|
||||
assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance))
|
||||
print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
|
||||
if self.cache_images is not None and self.datas[index] in self.cache_images:
|
||||
image = self.cache_images[ self.datas[index] ].clone()
|
||||
else:
|
||||
image = pil_loader(self.datas[index], self.use_gray)
|
||||
target = self.labels[index].copy()
|
||||
return self._process_(image, target, index)
|
||||
|
||||
|
||||
def _process_(self, image, target, index):
|
||||
|
||||
# transform the image and points
|
||||
image, target, theta = self.transform(image, target)
|
||||
(C, H, W), (height, width) = image.size(), self.shape
|
||||
|
||||
# obtain the visiable indicator vector
|
||||
if target.is_none(): nopoints = True
|
||||
else : nopoints = False
|
||||
if index == -1: __path = None
|
||||
else : __path = self.datas[index]
|
||||
if isinstance(theta, list) or isinstance(theta, tuple):
|
||||
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
|
||||
for _theta in theta:
|
||||
_affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \
|
||||
= self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path))
|
||||
affineImage.append(_affineImage)
|
||||
heatmaps.append(_heatmaps)
|
||||
mask.append(_mask)
|
||||
norm_trans_points.append(_norm_trans_points)
|
||||
THETA.append(_theta)
|
||||
transpose_theta.append(_transpose_theta)
|
||||
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
|
||||
torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta)
|
||||
else:
|
||||
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path))
|
||||
|
||||
torch_index = torch.IntTensor([index])
|
||||
torch_nopoints = torch.ByteTensor( [ nopoints ] )
|
||||
torch_shape = torch.IntTensor([H,W])
|
||||
|
||||
return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape
|
||||
|
||||
|
||||
def __process_affine(self, image, target, theta, nopoints, aux_info=None):
|
||||
image, target, theta = image.clone(), target.copy(), theta.clone()
|
||||
(C, H, W), (height, width) = image.size(), self.shape
|
||||
if nopoints: # do not have label
|
||||
norm_trans_points = torch.zeros((3, self.NUM_PTS))
|
||||
heatmaps = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
|
||||
mask = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
|
||||
transpose_theta = identity2affine(False)
|
||||
else:
|
||||
norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
|
||||
norm_trans_points = apply_boundary(norm_trans_points)
|
||||
real_trans_points = norm_trans_points.clone()
|
||||
real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
|
||||
heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
|
||||
heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
|
||||
mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
|
||||
if self.mean_face is None:
|
||||
#warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
|
||||
transpose_theta = identity2affine(False)
|
||||
else:
|
||||
if torch.sum(norm_trans_points[2,:] == 1) < 3:
|
||||
warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info))
|
||||
transpose_theta = identity2affine(False)
|
||||
self.transform = transform
|
||||
self.sigma = sigma
|
||||
self.downsample = downsample
|
||||
self.heatmap_type = heatmap_type
|
||||
self.dataset_name = data_indicator
|
||||
self.shape = shape # [H,W]
|
||||
self.use_gray = use_gray
|
||||
assert transform is not None, "transform : {:}".format(transform)
|
||||
self.mean_file = mean_file
|
||||
if mean_file is None:
|
||||
self.mean_data = None
|
||||
warnings.warn("LandmarkDataset initialized with mean_data = None")
|
||||
else:
|
||||
transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())
|
||||
assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
|
||||
self.mean_data = torch.load(mean_file)
|
||||
self.reset()
|
||||
self.cutout = None
|
||||
self.cache_images = cache_images
|
||||
print("The general dataset initialization done : {:}".format(self))
|
||||
warnings.simplefilter("once")
|
||||
|
||||
affineImage = affine2image(image, theta, self.shape)
|
||||
if self.cutout is not None: affineImage = self.cutout( affineImage )
|
||||
def __repr__(self):
|
||||
return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
|
||||
def set_cutout(self, length):
|
||||
if length is not None and length >= 1:
|
||||
self.cutout = CutOut(int(length))
|
||||
else:
|
||||
self.cutout = None
|
||||
|
||||
def reset(self, num_pts=-1, boxid="default", only_pts=False):
|
||||
self.NUM_PTS = num_pts
|
||||
if only_pts:
|
||||
return
|
||||
self.length = 0
|
||||
self.datas = []
|
||||
self.labels = []
|
||||
self.NormDistances = []
|
||||
self.BOXID = boxid
|
||||
if self.mean_data is None:
|
||||
self.mean_face = None
|
||||
else:
|
||||
self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
|
||||
assert (self.mean_face >= -1).all() and (
|
||||
self.mean_face <= 1
|
||||
).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
|
||||
# assert self.dataset_name is not None, 'The dataset name is None'
|
||||
|
||||
def __len__(self):
|
||||
assert len(self.datas) == self.length, "The length is not correct : {}".format(
|
||||
self.length
|
||||
)
|
||||
return self.length
|
||||
|
||||
def append(self, data, label, distance):
|
||||
assert osp.isfile(data), "The image path is not a file : {:}".format(data)
|
||||
self.datas.append(data)
|
||||
self.labels.append(label)
|
||||
self.NormDistances.append(distance)
|
||||
self.length = self.length + 1
|
||||
|
||||
def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
|
||||
if reset:
|
||||
self.reset(num_pts, boxindicator)
|
||||
else:
|
||||
assert (
|
||||
self.NUM_PTS == num_pts and self.BOXID == boxindicator
|
||||
), "The number of point is inconsistance : {:} vs {:}".format(
|
||||
self.NUM_PTS, num_pts
|
||||
)
|
||||
if isinstance(file_lists, str):
|
||||
file_lists = [file_lists]
|
||||
samples = []
|
||||
for idx, file_path in enumerate(file_lists):
|
||||
print(
|
||||
":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
|
||||
)
|
||||
xdata = torch.load(file_path)
|
||||
if isinstance(xdata, list):
|
||||
data = xdata # image or video dataset list
|
||||
elif isinstance(xdata, dict):
|
||||
data = xdata["datas"] # multi-view dataset list
|
||||
else:
|
||||
raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
|
||||
samples = samples + data
|
||||
# samples is a dict, where the key is the image-path and the value is the annotation
|
||||
# each annotation is a dict, contains 'points' (3,num_pts), and various box
|
||||
print("GeneralDataset-V2 : {:} samples".format(len(samples)))
|
||||
|
||||
# for index, annotation in enumerate(samples):
|
||||
for index in tqdm(range(len(samples))):
|
||||
annotation = samples[index]
|
||||
image_path = annotation["current_frame"]
|
||||
points, box = (
|
||||
annotation["points"],
|
||||
annotation["box-{:}".format(boxindicator)],
|
||||
)
|
||||
label = PointMeta2V(
|
||||
self.NUM_PTS, points, box, image_path, self.dataset_name
|
||||
)
|
||||
if normalizeL is None:
|
||||
normDistance = None
|
||||
else:
|
||||
normDistance = annotation["normalizeL-{:}".format(normalizeL)]
|
||||
self.append(image_path, label, normDistance)
|
||||
|
||||
assert (
|
||||
len(self.datas) == self.length
|
||||
), "The length and the data is not right {} vs {}".format(
|
||||
self.length, len(self.datas)
|
||||
)
|
||||
assert (
|
||||
len(self.labels) == self.length
|
||||
), "The length and the labels is not right {} vs {}".format(
|
||||
self.length, len(self.labels)
|
||||
)
|
||||
assert (
|
||||
len(self.NormDistances) == self.length
|
||||
), "The length and the NormDistances is not right {} vs {}".format(
|
||||
self.length, len(self.NormDistance)
|
||||
)
|
||||
print(
|
||||
"Load data done for LandmarkDataset, which has {:} images.".format(
|
||||
self.length
|
||||
)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
|
||||
if self.cache_images is not None and self.datas[index] in self.cache_images:
|
||||
image = self.cache_images[self.datas[index]].clone()
|
||||
else:
|
||||
image = pil_loader(self.datas[index], self.use_gray)
|
||||
target = self.labels[index].copy()
|
||||
return self._process_(image, target, index)
|
||||
|
||||
def _process_(self, image, target, index):
|
||||
|
||||
# transform the image and points
|
||||
image, target, theta = self.transform(image, target)
|
||||
(C, H, W), (height, width) = image.size(), self.shape
|
||||
|
||||
# obtain the visiable indicator vector
|
||||
if target.is_none():
|
||||
nopoints = True
|
||||
else:
|
||||
nopoints = False
|
||||
if index == -1:
|
||||
__path = None
|
||||
else:
|
||||
__path = self.datas[index]
|
||||
if isinstance(theta, list) or isinstance(theta, tuple):
|
||||
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
for _theta in theta:
|
||||
(
|
||||
_affineImage,
|
||||
_heatmaps,
|
||||
_mask,
|
||||
_norm_trans_points,
|
||||
_theta,
|
||||
_transpose_theta,
|
||||
) = self.__process_affine(
|
||||
image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path)
|
||||
)
|
||||
affineImage.append(_affineImage)
|
||||
heatmaps.append(_heatmaps)
|
||||
mask.append(_mask)
|
||||
norm_trans_points.append(_norm_trans_points)
|
||||
THETA.append(_theta)
|
||||
transpose_theta.append(_transpose_theta)
|
||||
affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
|
||||
torch.stack(affineImage),
|
||||
torch.stack(heatmaps),
|
||||
torch.stack(mask),
|
||||
torch.stack(norm_trans_points),
|
||||
torch.stack(THETA),
|
||||
torch.stack(transpose_theta),
|
||||
)
|
||||
else:
|
||||
(
|
||||
affineImage,
|
||||
heatmaps,
|
||||
mask,
|
||||
norm_trans_points,
|
||||
THETA,
|
||||
transpose_theta,
|
||||
) = self.__process_affine(
|
||||
image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path)
|
||||
)
|
||||
|
||||
torch_index = torch.IntTensor([index])
|
||||
torch_nopoints = torch.ByteTensor([nopoints])
|
||||
torch_shape = torch.IntTensor([H, W])
|
||||
|
||||
return (
|
||||
affineImage,
|
||||
heatmaps,
|
||||
mask,
|
||||
norm_trans_points,
|
||||
THETA,
|
||||
transpose_theta,
|
||||
torch_index,
|
||||
torch_nopoints,
|
||||
torch_shape,
|
||||
)
|
||||
|
||||
def __process_affine(self, image, target, theta, nopoints, aux_info=None):
|
||||
image, target, theta = image.clone(), target.copy(), theta.clone()
|
||||
(C, H, W), (height, width) = image.size(), self.shape
|
||||
if nopoints: # do not have label
|
||||
norm_trans_points = torch.zeros((3, self.NUM_PTS))
|
||||
heatmaps = torch.zeros(
|
||||
(self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
|
||||
)
|
||||
mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
|
||||
transpose_theta = identity2affine(False)
|
||||
else:
|
||||
norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
|
||||
norm_trans_points = apply_boundary(norm_trans_points)
|
||||
real_trans_points = norm_trans_points.clone()
|
||||
real_trans_points[:2, :] = denormalize_points(
|
||||
self.shape, real_trans_points[:2, :]
|
||||
)
|
||||
heatmaps, mask = generate_label_map(
|
||||
real_trans_points.numpy(),
|
||||
height // self.downsample,
|
||||
width // self.downsample,
|
||||
self.sigma,
|
||||
self.downsample,
|
||||
nopoints,
|
||||
self.heatmap_type,
|
||||
) # H*W*C
|
||||
heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
|
||||
torch.FloatTensor
|
||||
)
|
||||
mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
|
||||
if self.mean_face is None:
|
||||
# warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
|
||||
transpose_theta = identity2affine(False)
|
||||
else:
|
||||
if torch.sum(norm_trans_points[2, :] == 1) < 3:
|
||||
warnings.warn(
|
||||
"In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format(
|
||||
aux_info
|
||||
)
|
||||
)
|
||||
transpose_theta = identity2affine(False)
|
||||
else:
|
||||
transpose_theta = solve2theta(
|
||||
norm_trans_points, self.mean_face.clone()
|
||||
)
|
||||
|
||||
affineImage = affine2image(image, theta, self.shape)
|
||||
if self.cutout is not None:
|
||||
affineImage = self.cutout(affineImage)
|
||||
|
||||
return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
|
||||
|
Reference in New Issue
Block a user