Refine lib -> xautodl

This commit is contained in:
D-X-Y
2021-05-19 08:10:42 +00:00
parent bd407ac4dc
commit 1c6c3e7166
12 changed files with 83 additions and 53 deletions

View File

@@ -25,7 +25,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)

View File

@@ -470,7 +470,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
main(
save_dir,

View File

@@ -340,7 +340,7 @@ def train_single_model(
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True
torch.set_num_threads(workers)
# torch.set_num_threads(workers)
save_dir = (
Path(save_dir)
@@ -675,7 +675,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers if args.workers > 0 else 1)
# torch.set_num_threads(args.workers if args.workers > 0 else 1)
main(
save_dir,

View File

@@ -132,7 +132,7 @@ def select_action(policy):
def main(xargs, api):
torch.set_num_threads(4)
# torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)

View File

@@ -204,7 +204,7 @@ def main(xargs):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
# torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)

View File

@@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, obtain_basic_args as obtain_args
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets
from models import obtain_model
from nas_infer_model import obtain_nas_infer_model
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.datasets import get_datasets
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.models import obtain_model
from xautodl.nas_infer_model import obtain_nas_infer_model
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):
@@ -26,7 +23,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)

View File

@@ -10,21 +10,17 @@ import numpy as np
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("lib_dir : {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import (
from xautodl.config_utils import (
load_config,
configure2str,
obtain_search_single_args as obtain_args,
)
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets, SearchDataset
from models import obtain_search_model, obtain_model, change_key
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.models import obtain_search_model, obtain_model, change_key
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):
@@ -32,7 +28,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)