update codes
This commit is contained in:
@@ -10,7 +10,7 @@ from .TieredImageNet import TieredImageNet
|
||||
Dataset2Class = {'cifar10' : 10,
|
||||
'cifar100': 100,
|
||||
'tiered' : -1,
|
||||
'imagnet-1k' : 1000,
|
||||
'imagenet-1k' : 1000,
|
||||
'imagenet-100': 100}
|
||||
|
||||
|
||||
@@ -25,8 +25,8 @@ def get_datasets(name, root, cutout):
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name == 'tiered':
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name == 'imagnet-1k' or name == 'imagenet-100':
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ def get_datasets(name, root, cutout):
|
||||
if cutout > 0 : lists += [Cutout(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
elif name == 'imagnet-1k' or name == 'imagenet-100':
|
||||
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
@@ -57,15 +57,14 @@ def get_datasets(name, root, cutout):
|
||||
])
|
||||
test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
|
||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||
train_data = TieredImageNet(root, 'train-val', train_transform)
|
||||
test_data = None
|
||||
|
||||
if name == 'cifar10':
|
||||
train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(root, train=True, transform=test_transform , download=True)
|
||||
elif name == 'cifar100':
|
||||
train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(root, train=True, transform=test_transform , download=True)
|
||||
elif name == 'imagnet-1k' or name == 'imagenet-100':
|
||||
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
|
||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
@@ -4,7 +4,6 @@ from .utils import test_imagenet_data
|
||||
from .utils import print_log
|
||||
from .evaluation_utils import obtain_accuracy
|
||||
from .draw_pts import draw_points
|
||||
from .fb_transform import ApplyOffset
|
||||
from .gpu_manager import GPUManager
|
||||
|
||||
from .save_meta import Save_Meta
|
||||
|
@@ -1,14 +0,0 @@
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
class ApplyOffset(object):
|
||||
def __init__(self, offset):
|
||||
assert isinstance(offset, int), 'The offset is not right : {}'.format(offset)
|
||||
self.offset = offset
|
||||
def __call__(self, x):
|
||||
if isinstance(x, np.ndarray) and x.dtype == 'uint8':
|
||||
x = x.astype(int)
|
||||
if isinstance(x, np.ndarray) and x.size == 1:
|
||||
x = int(x)
|
||||
return x + self.offset
|
Reference in New Issue
Block a user