Refine lib -> xautodl
This commit is contained in:
@@ -25,7 +25,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
@@ -470,7 +470,7 @@ if __name__ == "__main__":
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
main(
|
||||
save_dir,
|
||||
|
@@ -340,7 +340,7 @@ def train_single_model(
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = True
|
||||
torch.set_num_threads(workers)
|
||||
# torch.set_num_threads(workers)
|
||||
|
||||
save_dir = (
|
||||
Path(save_dir)
|
||||
@@ -675,7 +675,7 @@ if __name__ == "__main__":
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||
# torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||
|
||||
main(
|
||||
save_dir,
|
||||
|
@@ -132,7 +132,7 @@ def select_action(policy):
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
torch.set_num_threads(4)
|
||||
# torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
|
@@ -204,7 +204,7 @@ def main(xargs):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(xargs.workers)
|
||||
# torch.set_num_threads(xargs.workers)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
|
@@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, obtain_basic_args as obtain_args
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from procedures import get_optim_scheduler, get_procedures
|
||||
from datasets import get_datasets
|
||||
from models import obtain_model
|
||||
from nas_infer_model import obtain_nas_infer_model
|
||||
from utils import get_model_infos
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.datasets import get_datasets
|
||||
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
|
||||
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
||||
from xautodl.models import obtain_model
|
||||
from xautodl.nas_infer_model import obtain_nas_infer_model
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -26,7 +23,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
@@ -10,21 +10,17 @@ import numpy as np
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||
print("lib_dir : {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import (
|
||||
from xautodl.config_utils import (
|
||||
load_config,
|
||||
configure2str,
|
||||
obtain_search_single_args as obtain_args,
|
||||
)
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from procedures import get_optim_scheduler, get_procedures
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from models import obtain_search_model, obtain_model, change_key
|
||||
from utils import get_model_infos
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.models import obtain_search_model, obtain_model, change_key
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -32,7 +28,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
Reference in New Issue
Block a user