Fix small bugs
This commit is contained in:
@@ -60,14 +60,14 @@ def get_datasets(name, root, cutout):
|
||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
if name == 'cifar10':
|
||||
train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(root, train=False, transform=test_transform , download=True)
|
||||
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
|
||||
elif name == 'cifar100':
|
||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
class_num = Dataset2Class[name]
|
||||
|
@@ -80,6 +80,9 @@ class NetworkImageNet(nn.Module):
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def get_drop_path(self):
|
||||
return self.drop_path_prob
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None: return []
|
||||
else: return list( self.auxiliary_head.parameters() )
|
||||
|
Reference in New Issue
Block a user