Reformat via black
This commit is contained in:
122
exps/prepare.py
122
exps/prepare.py
@@ -4,74 +4,78 @@
|
||||
import sys, time, torch, random, argparse
|
||||
from collections import defaultdict
|
||||
import os.path as osp
|
||||
from PIL import ImageFile
|
||||
from PIL import ImageFile
|
||||
|
||||
# Pillow setting: per the Pillow docs this makes truncated image files load
# instead of raising; presumably needed for the ImageFolder dataset — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import torchvision
|
||||
import torchvision.datasets as dset
|
||||
|
||||
# Put the sibling ``lib`` directory (one level above this script's folder) at
# the front of ``sys.path`` unless it is already listed there.
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

# Command-line interface. Arguments are parsed at import time and read by
# ``main()`` through the module-level ``args`` namespace.
parser = argparse.ArgumentParser(
    description="Prepare splits for searching",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--name", type=str, help="The dataset name.")
parser.add_argument("--root", type=str, help="The directory to the dataset.")
parser.add_argument("--save", type=str, help="The save path.")
# The help text used to be a copy of the ``--save`` help; the value is the
# per-class fraction of samples that goes into the training split.
parser.add_argument(
    "--ratio",
    type=float,
    help="The fraction of samples in each class assigned to the training split.",
)
args = parser.parse_args()
|
||||
|
||||
|
||||
def _stratified_split(targets, ratio, seed=111):
    """Split sample indices into train/valid lists, class by class.

    Args:
        targets: sequence whose i-th element is the class label of sample i.
        ratio: fraction of each class sampled into the training list; the
            per-class count is ``int(len(class_indices) * ratio)``.
        seed: value passed to ``random.seed`` before sampling (this reseeds
            the global ``random`` generator).

    Returns:
        ``(train, valid)``: two sorted lists of sample indices. Every index
        of ``targets`` appears in exactly one of them.
    """
    class2index = defaultdict(list)
    for index, cls in enumerate(targets):
        class2index[cls].append(index)

    train, valid = [], []
    random.seed(seed)
    # Classes are visited in sorted order so the sequence of random draws,
    # and therefore the resulting split, does not depend on label order.
    for cls in sorted(class2index):
        xlist = class2index[cls]
        xtrain = random.sample(xlist, int(len(xlist) * ratio))
        chosen = set(xtrain)
        train += xtrain
        valid += [index for index in xlist if index not in chosen]
    train.sort()
    valid.sort()
    return train, valid


def main():
    """Create a per-class train/valid index split and save it with torch.

    Reads the module-level ``args`` namespace (``save``, ``name``, ``root``,
    ``ratio``). The saved object is a dict with the keys ``train``, ``valid``,
    ``class2numTrain`` and ``class2numValid``.

    Raises:
        AssertionError: if the save path already exists.
        TypeError: if ``args.name`` is not one of the supported datasets.
        ValueError: if the dataset exposes none of the known label attributes.
    """
    save_path = Path(args.save)
    save_dir = save_path.parent
    name = args.name
    save_dir.mkdir(parents=True, exist_ok=True)
    # Raised explicitly instead of using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let torch.save overwrite an
    # existing split file. Exception type and message are unchanged.
    if save_path.exists():
        raise AssertionError("{:} already exists".format(save_path))
    print("torchvision version : {:}".format(torchvision.__version__))

    if name == "cifar10":
        dataset = dset.CIFAR10(args.root, train=True)
    elif name == "cifar100":
        dataset = dset.CIFAR100(args.root, train=True)
    elif name == "imagenet-1k":
        dataset = dset.ImageFolder(osp.join(args.root, "train"))
    else:
        raise TypeError("Unknow dataset : {:}".format(name))

    # The labels live under different attribute names depending on the
    # dataset class (presumably also on the torchvision version — confirm).
    if hasattr(dataset, "targets"):
        targets = dataset.targets
    elif hasattr(dataset, "train_labels"):
        targets = dataset.train_labels
    elif hasattr(dataset, "imgs"):
        targets = [x[1] for x in dataset.imgs]
    else:
        raise ValueError("invalid pattern")
    print("There are {:} samples in this dataset.".format(len(targets)))

    train, valid = _stratified_split(targets, args.ratio)

    ## for statistics
    class2numT, class2numV = defaultdict(int), defaultdict(int)
    for index in train:
        class2numT[targets[index]] += 1
    for index in valid:
        class2numV[targets[index]] += 1
    # Convert to plain dicts before saving.
    class2numT, class2numV = dict(class2numT), dict(class2numV)
    torch.save(
        {
            "train": train,
            "valid": valid,
            "class2numTrain": class2numT,
            "class2numValid": class2numV,
        },
        save_path,
    )
    print("-" * 80)


if __name__ == "__main__":
    main()
|
||||
|
Reference in New Issue
Block a user