Update find_best API
This commit is contained in:
@@ -92,6 +92,10 @@ class ImageNet16(data.Dataset):
|
||||
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
|
||||
#print ('Std : {:}'.format(std_data))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets))))
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.data[index], self.targets[index] - 1
|
||||
|
||||
@@ -114,16 +118,16 @@ class ImageNet16(data.Dataset):
|
||||
return False
|
||||
return True
|
||||
|
||||
#
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
|
||||
train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)
|
||||
|
||||
print ( len(train) )
|
||||
print ( len(valid) )
|
||||
image, label = train[111]
|
||||
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
print ( len(trainX) )
|
||||
print ( len(validX) )
|
||||
#import pdb; pdb.set_trace()
|
||||
"""
|
||||
|
Reference in New Issue
Block a user