update code styles

This commit is contained in:
D-X-Y
2020-01-09 22:26:23 +11:00
parent 5ac5060a33
commit ad34af9913
26 changed files with 192 additions and 81 deletions

View File

@@ -9,16 +9,25 @@ class SearchDataset(data.Dataset):
def __init__(self, name, data, train_split, valid_split, check=True):
self.datasetname = name
self.data = data
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
if check:
intersection = set(train_split).intersection(set(valid_split))
assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
if isinstance(data, (list, tuple)): # new type of SearchDataset
assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
self.train_data = data[0]
self.valid_data = data[1]
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
self.mode_str = 'V2' # new mode
else:
self.mode_str = 'V1' # old mode
self.data = data
self.train_split = train_split.copy()
self.valid_split = valid_split.copy()
if check:
intersection = set(train_split).intersection(set(valid_split))
assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
self.length = len(self.train_split)
def __repr__(self):
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split)))
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
def __len__(self):
return self.length
@@ -27,6 +36,11 @@ class SearchDataset(data.Dataset):
assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
train_index = self.train_split[index]
valid_index = random.choice( self.valid_split )
train_image, train_label = self.data[train_index]
valid_image, valid_label = self.data[valid_index]
if self.mode_str == 'V1':
train_image, train_label = self.data[train_index]
valid_image, valid_label = self.data[valid_index]
elif self.mode_str == 'V2':
train_image, train_label = self.train_data[train_index]
valid_image, valid_label = self.valid_data[valid_index]
else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
return train_image, train_label, valid_image, valid_label

View File

@@ -34,7 +34,7 @@ class PointMeta():
def get_box(self, return_diagonal=False):
if self.box is None: return None
if return_diagonal == False:
if not return_diagonal:
return self.box.clone()
else:
W = (self.box[2]-self.box[0]).item()