Fix path errors in TAS due to lib->xautodl

This commit is contained in:
D-X-Y
2021-05-20 10:53:57 +08:00
parent b50ad2a522
commit b4e8eae63a
15 changed files with 57 additions and 56 deletions

View File

@@ -10,6 +10,10 @@ from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,

View File

@@ -3,21 +3,16 @@
##############################################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
##############################################################################
# python ./exps/NATS-Bench/main-tss.py --mode meta #
# python ./exps/NATS-Bench/show-dataset.py #
##############################################################################
import os, sys, time, torch, random, argparse
from typing import List, Text, Dict, Any
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from datasets import get_datasets
from xautodl.config_utils import dict2config, load_config
from xautodl.datasets import get_datasets
from nats_bench import create

View File

@@ -12,16 +12,22 @@ import numpy as np
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, configure2str, obtain_search_args as obtain_args
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets, SearchDataset
from models import obtain_search_model, obtain_model, change_key
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.config_utils import (
load_config,
configure2str,
obtain_search_args as obtain_args,
)
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.models import obtain_search_model, obtain_model, change_key
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):

View File

@@ -8,16 +8,18 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, obtain_cls_kd_args as obtain_args
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets
from models import obtain_model, load_net_from_checkpoint
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.config_utils import load_config, obtain_cls_kd_args as obtain_args
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.datasets import get_datasets
from xautodl.models import obtain_model, load_net_from_checkpoint
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):

View File

@@ -6,20 +6,13 @@ from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config
from procedures import get_procedures, get_optim_scheduler
from datasets import get_datasets
from models import obtain_model
from utils import get_model_infos
from log_utils import PrintLogger, time_string
assert torch.cuda.is_available(), "torch.cuda is not available"
from xautodl.config_utils import load_config, dict2config
from xautodl.procedures import get_procedures, get_optim_scheduler
from xautodl.datasets import get_datasets
from xautodl.models import obtain_model
from xautodl.utils import get_model_infos
from xautodl.log_utils import PrintLogger, time_string
def main(args):
@@ -118,4 +111,5 @@ if __name__ == "__main__":
"--checkpoint", type=str, help="Choose between Cifar10/100 and ImageNet."
)
args = parser.parse_args()
assert torch.cuda.is_available(), "torch.cuda is not available"
main(args)