add naswot
graph_dit/naswot/pycls/core/logging.py | 138 | Normal file
@@ -0,0 +1,138 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "

# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"


def _suppress_print():
    """Suppresses printing from the current process."""

    def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
        pass

    builtins.print = ignore


def setup_logging():
    """Sets up the logging."""
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)
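
# A minimal usage sketch (assumed from the functions above; the call site
# itself is hypothetical): configure logging once at startup, which is a
# no-op outside the master process, then fetch per-module loggers.
#
#   setup_logging()
#   logger = get_logger(__name__)
#   logger.info("training started")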


def dump_log_data(data, data_type, prec=4):
    """Convert data (a dictionary) into a tagged json string for logging."""
    data[_TYPE] = data_type
    data = float_to_decimal(data, prec)
    data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)
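
# Example with hypothetical values: prec=4 renders floats at fixed width, so
#   dump_log_data({"loss": 0.123456}, "train_epoch")
# returns
#   'json_stats: {"_type": "train_epoch", "loss": 0.1235}'
# Note that the input dict is mutated in place (the _TYPE key is added).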


def float_to_decimal(data, prec=4):
    """Convert floats to decimals which allows for fixed width json."""
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, float):
        return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
    else:
        return data
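
# Example: float_to_decimal({"err": 0.25, "lr": 0.1}) gives
# {"err": Decimal("0.2500"), "lr": Decimal("0.1000")}. Only dict values are
# recursed into; floats inside lists pass through unconverted.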


def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Get all log files in a directory containing subdirs of trained models."""
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names
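
# Assumed layout (paths hypothetical): log_dir/model_a/stdout.log,
# log_dir/model_b/stdout.log, ... get_log_files(log_dir) then returns one
# (file, name) entry per subdirectory whose log file actually exists.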


def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and l[_TYPE] not in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data
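
# Sketch of the result for a log holding two "train_epoch" records
# (values hypothetical):
#   data["train_epoch"]["loss"]  -> [0.9, 0.7]
#   data["train_epoch"]["epoch"] -> ["1/2", "2/2"]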


def sort_log_data(data):
    """Sort each data[data_type][metric] by epoch or keep only the first instance."""
    for t in data:
        if "epoch" in data[t]:
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
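
# End-to-end sketch (hypothetical paths): load and sort every model's log
# under a sweep directory.
#
#   files, names = get_log_files("/tmp/sweep")
#   logs = [sort_log_data(load_log_data(f)) for f in files]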