Update xmisc with yaml

This commit is contained in:
D-X-Y
2021-06-10 02:11:27 -07:00
parent aef5c7579b
commit 1a7440d2af
11 changed files with 259 additions and 76 deletions

View File

@@ -1,35 +1,28 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
# python exps/basic/xmain.py --save_dir outputs/x #
#####################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
from xautodl.datasets import get_datasets
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.models import obtain_model
from xautodl.xmodels import obtain_model as obtain_xmodel
from xautodl.nas_infer_model import obtain_nas_infer_model
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.xmisc import nested_call_by_yaml
def main(args):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# torch.set_num_threads(args.workers)
train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
import pdb
pdb.set_trace()
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
@@ -290,5 +283,44 @@ def main(args):
if __name__ == "__main__":
args = obtain_args()
parser = argparse.ArgumentParser(
description="Train a model with a loss function.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument("--resume", type=str, help="Resume path.")
parser.add_argument("--init_model", type=str, help="The initialization model path.")
parser.add_argument("--model_config", type=str, help="The path to the model config")
parser.add_argument(
"--optim_config", type=str, help="The path to the optimizer config"
)
parser.add_argument(
"--train_data_config", type=str, help="The dataset config path."
)
parser.add_argument(
"--valid_data_config", type=str, help="The dataset config path."
)
parser.add_argument(
"--data_path", type=str, help="The path to the dataset."
)
parser.add_argument("--algorithm", type=str, help="The algorithm.")
# Optimization options
parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
parser.add_argument(
"--workers",
type=int,
default=8,
help="number of data loading workers (default: 8)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
if args.save_dir is None:
raise ValueError("The save-path argument can not be None")
main(args)