Refine lib -> xautodl

This commit is contained in:
D-X-Y
2021-05-19 07:19:20 +00:00
parent bda202ce87
commit 5b9a028e60
19 changed files with 46 additions and 46 deletions

View File

@@ -27,8 +27,8 @@ from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_
from xautodl.lfna_utils import lfna_setup, train_model, TimeData
from xautodl.lfna_meta_model import LFNA_Meta
from lfna_utils import lfna_setup, train_model, TimeData
from lfna_meta_model import LFNA_Meta
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):

View File

@@ -4,8 +4,8 @@
import copy
import torch
from tqdm import tqdm
from procedures import prepare_seed, prepare_logger
from datasets.synthetic_core import get_synthetic_env
from xautodl.procedures import prepare_seed, prepare_logger
from xautodl.datasets.synthetic_core import get_synthetic_env
def lfna_setup(args):

View File

@@ -665,7 +665,7 @@ if __name__ == "__main__":
len(args.datasets), len(args.xpaths), len(args.splits)
)
)
if args.workers <= 0:
if args.workers < 0:
raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes(
@@ -675,7 +675,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
torch.set_num_threads(args.workers if args.workers > 0 else 1)
main(
save_dir,

View File

@@ -1,6 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
#####################################################
import sys, time, torch, random, argparse
from collections import defaultdict
import os.path as osp
@@ -12,9 +16,6 @@ from pathlib import Path
import torchvision
import torchvision.datasets as dset
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
parser = argparse.ArgumentParser(
description="Prepare splits for searching",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -35,9 +36,9 @@ def main():
print("torchvision version : {:}".format(torchvision.__version__))
if name == "cifar10":
dataset = dset.CIFAR10(args.root, train=True)
dataset = dset.CIFAR10(args.root, train=True, download=True)
elif name == "cifar100":
dataset = dset.CIFAR100(args.root, train=True)
dataset = dset.CIFAR100(args.root, train=True, download=True)
elif name == "imagenet-1k":
dataset = dset.ImageFolder(osp.join(args.root, "train"))
else: