Update find_best API

This commit is contained in:
D-X-Y
2020-11-20 09:52:29 +08:00
parent a9eec30b05
commit 8949d0b18e
3 changed files with 60 additions and 6 deletions

View File

@@ -92,6 +92,10 @@ class ImageNet16(data.Dataset):
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
#print ('Std : {:}'.format(std_data))
def __repr__(self):
return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets))))
def __getitem__(self, index):
img, target = self.data[index], self.targets[index] - 1
@@ -114,16 +118,16 @@ class ImageNet16(data.Dataset):
return False
return True
#
"""
if __name__ == '__main__':
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)
valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)
print ( len(train) )
print ( len(valid) )
image, label = train[111]
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200)
validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200)
print ( len(trainX) )
print ( len(validX) )
#import pdb; pdb.set_trace()
"""