Move str2bool to config_utils
This commit is contained in:
135
lib/config_utils/config_utils.py
Normal file
135
lib/config_utils/config_utils.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import os, json
|
||||
from os import path as osp
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
|
||||
support_types = ("str", "int", "bool", "float", "none")
|
||||
|
||||
|
||||
def convert_param(original_lists):
|
||||
assert isinstance(original_lists, list), "The type is not right : {:}".format(
|
||||
original_lists
|
||||
)
|
||||
ctype, value = original_lists[0], original_lists[1]
|
||||
assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
|
||||
is_list = isinstance(value, list)
|
||||
if not is_list:
|
||||
value = [value]
|
||||
outs = []
|
||||
for x in value:
|
||||
if ctype == "int":
|
||||
x = int(x)
|
||||
elif ctype == "str":
|
||||
x = str(x)
|
||||
elif ctype == "bool":
|
||||
x = bool(int(x))
|
||||
elif ctype == "float":
|
||||
x = float(x)
|
||||
elif ctype == "none":
|
||||
if x.lower() != "none":
|
||||
raise ValueError(
|
||||
"For the none type, the value must be none instead of {:}".format(x)
|
||||
)
|
||||
x = None
|
||||
else:
|
||||
raise TypeError("Does not know this type : {:}".format(ctype))
|
||||
outs.append(x)
|
||||
if not is_list:
|
||||
outs = outs[0]
|
||||
return outs
|
||||
|
||||
|
||||
def load_config(path, extra, logger):
|
||||
path = str(path)
|
||||
if hasattr(logger, "log"):
|
||||
logger.log(path)
|
||||
assert os.path.exists(path), "Can not find {:}".format(path)
|
||||
# Reading data back
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
content = {k: convert_param(v) for k, v in data.items()}
|
||||
assert extra is None or isinstance(
|
||||
extra, dict
|
||||
), "invalid type of extra : {:}".format(extra)
|
||||
if isinstance(extra, dict):
|
||||
content = {**content, **extra}
|
||||
Arguments = namedtuple("Configure", " ".join(content.keys()))
|
||||
content = Arguments(**content)
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("{:}".format(content))
|
||||
return content
|
||||
|
||||
|
||||
def configure2str(config, xpath=None):
|
||||
if not isinstance(config, dict):
|
||||
config = config._asdict()
|
||||
|
||||
def cstring(x):
|
||||
return '"{:}"'.format(x)
|
||||
|
||||
def gtype(x):
|
||||
if isinstance(x, list):
|
||||
x = x[0]
|
||||
if isinstance(x, str):
|
||||
return "str"
|
||||
elif isinstance(x, bool):
|
||||
return "bool"
|
||||
elif isinstance(x, int):
|
||||
return "int"
|
||||
elif isinstance(x, float):
|
||||
return "float"
|
||||
elif x is None:
|
||||
return "none"
|
||||
else:
|
||||
raise ValueError("invalid : {:}".format(x))
|
||||
|
||||
def cvalue(x, xtype):
|
||||
if isinstance(x, list):
|
||||
is_list = True
|
||||
else:
|
||||
is_list, x = False, [x]
|
||||
temps = []
|
||||
for temp in x:
|
||||
if xtype == "bool":
|
||||
temp = cstring(int(temp))
|
||||
elif xtype == "none":
|
||||
temp = cstring("None")
|
||||
else:
|
||||
temp = cstring(temp)
|
||||
temps.append(temp)
|
||||
if is_list:
|
||||
return "[{:}]".format(", ".join(temps))
|
||||
else:
|
||||
return temps[0]
|
||||
|
||||
xstrings = []
|
||||
for key, value in config.items():
|
||||
xtype = gtype(value)
|
||||
string = " {:20s} : [{:8s}, {:}]".format(
|
||||
cstring(key), cstring(xtype), cvalue(value, xtype)
|
||||
)
|
||||
xstrings.append(string)
|
||||
Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
|
||||
if xpath is not None:
|
||||
parent = Path(xpath).resolve().parent
|
||||
parent.mkdir(parents=True, exist_ok=True)
|
||||
if osp.isfile(xpath):
|
||||
os.remove(xpath)
|
||||
with open(xpath, "w") as text_file:
|
||||
text_file.write("{:}".format(Fstring))
|
||||
return Fstring
|
||||
|
||||
|
||||
def dict2config(xdict, logger):
|
||||
assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
|
||||
Arguments = namedtuple("Configure", " ".join(xdict.keys()))
|
||||
content = Arguments(**xdict)
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("{:}".format(content))
|
||||
return content
|
Reference in New Issue
Block a user