Update xmisc with yaml
This commit is contained in:
@@ -1,35 +1,28 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
|
||||
#####################################################
|
||||
# python exps/basic/xmain.py --save_dir outputs/x #
|
||||
#####################################################
|
||||
import sys, time, torch, random, argparse
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
from xautodl.datasets import get_datasets
|
||||
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
)
|
||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
||||
from xautodl.models import obtain_model
|
||||
from xautodl.xmodels import obtain_model as obtain_xmodel
|
||||
from xautodl.nas_infer_model import obtain_nas_infer_model
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
print("LIB-DIR: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from xautodl.xmisc import nested_call_by_yaml
|
||||
|
||||
|
||||
def main(args):
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
|
||||
valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
@@ -290,5 +283,44 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = obtain_args()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a model with a loss function.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir", type=str, help="Folder to save checkpoints and log."
|
||||
)
|
||||
parser.add_argument("--resume", type=str, help="Resume path.")
|
||||
parser.add_argument("--init_model", type=str, help="The initialization model path.")
|
||||
parser.add_argument("--model_config", type=str, help="The path to the model config")
|
||||
parser.add_argument(
|
||||
"--optim_config", type=str, help="The path to the optimizer config"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_config", type=str, help="The dataset config path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--valid_data_config", type=str, help="The dataset config path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_path", type=str, help="The path to the dataset."
|
||||
)
|
||||
parser.add_argument("--algorithm", type=str, help="The algorithm.")
|
||||
# Optimization options
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="number of data loading workers (default: 8)",
|
||||
)
|
||||
# Random Seed
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
if args.save_dir is None:
|
||||
raise ValueError("The save-path argument can not be None")
|
||||
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user