Update the sync data v1
This commit is contained in:
@@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
|
||||
|
||||
|
||||
def main(args):
|
||||
logger, env_info, model_kwargs = lfna_setup(args)
|
||||
logger, model_kwargs = lfna_setup(args)
|
||||
train_env = get_synthetic_env(mode="train", version=args.env_version)
|
||||
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
|
||||
all_env = get_synthetic_env(mode=None, version=args.env_version)
|
||||
|
@@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
def lfna_setup(args):
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
cache_path = (
|
||||
logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
|
||||
).resolve()
|
||||
if cache_path.exists():
|
||||
env_info = torch.load(cache_path)
|
||||
else:
|
||||
env_info = dict()
|
||||
dynamic_env = get_synthetic_env(version=args.env_version)
|
||||
env_info["total"] = len(dynamic_env)
|
||||
for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
|
||||
env_info["{:}-timestamp".format(idx)] = timestamp
|
||||
env_info["{:}-x".format(idx)] = _allx
|
||||
env_info["{:}-y".format(idx)] = _ally
|
||||
env_info["dynamic_env"] = dynamic_env
|
||||
torch.save(env_info, cache_path)
|
||||
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
config=dict(model_type="simple_mlp"),
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
hidden_dim=args.hidden_dim,
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="identity",
|
||||
)
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
config=dict(model_type="norm_mlp"),
|
||||
input_dim=1,
|
||||
@@ -46,7 +19,7 @@ def lfna_setup(args):
|
||||
act_cls="gelu",
|
||||
norm_cls="layer_norm_1d",
|
||||
)
|
||||
return logger, env_info, model_kwargs
|
||||
return logger, model_kwargs
|
||||
|
||||
|
||||
def train_model(model, dataset, lr, epochs):
|
||||
|
@@ -20,14 +20,13 @@ matplotlib.use("agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from models.xcore import get_model
|
||||
from datasets.synthetic_core import get_synthetic_env
|
||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||
from procedures.metric_utils import MSEMetric
|
||||
from xautodl.models.xcore import get_model
|
||||
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
from xautodl.procedures.metric_utils import MSEMetric
|
||||
|
||||
|
||||
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
|
||||
@@ -181,10 +180,17 @@ def compare_cl(save_dir):
|
||||
|
||||
def visualize_env(save_dir, version):
|
||||
save_dir = Path(str(save_dir))
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for substr in ("pdf", "png"):
|
||||
sub_save_dir = save_dir / substr
|
||||
sub_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dynamic_env = get_synthetic_env(version=version)
|
||||
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
|
||||
# min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
|
||||
allxs, allys = [], []
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
allxs.append(allx)
|
||||
allys.append(ally)
|
||||
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
|
||||
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
dpi, width, height = 30, 1800, 1400
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
@@ -201,21 +207,18 @@ def visualize_env(save_dir, version):
|
||||
tick.label.set_rotation(10)
|
||||
for tick in cur_ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(LabelSize - font_gap)
|
||||
if version == "v1":
|
||||
cur_ax.set_xlim(-2, 2)
|
||||
cur_ax.set_ylim(-8, 8)
|
||||
elif version == "v2":
|
||||
cur_ax.set_xlim(-10, 10)
|
||||
cur_ax.set_ylim(-60, 60)
|
||||
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
|
||||
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
|
||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||
|
||||
save_path = save_dir / "v{:}-{:05d}".format(version, idx)
|
||||
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
||||
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
|
||||
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
|
||||
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
|
||||
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
|
||||
plt.close("all")
|
||||
save_dir = save_dir.resolve()
|
||||
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
|
||||
xdir=save_dir, version=version
|
||||
xdir=save_dir / "png", version=version
|
||||
)
|
||||
print(base_cmd)
|
||||
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
|
||||
@@ -371,7 +374,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
|
||||
visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
|
||||
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
|
||||
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
|
||||
# compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
|
||||
# compare_cl(os.path.join(args.save_dir, "compare-cl"))
|
||||
|
@@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config
|
||||
|
||||
# NAS-Bench-201 related module or function
|
||||
from xautodl.models import CellStructure, get_cell_based_tiny_net
|
||||
from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
|
||||
from xautodl.procedures import (
|
||||
bench_pure_evaluate as pure_evaluate,
|
||||
get_nas_bench_loaders,
|
||||
)
|
||||
from nas_201_api import NASBench201API, ArchResults, ResultsCount
|
||||
|
||||
api = NASBench201API(
|
||||
|
21
exps/experimental/test-dynamic.py
Normal file
21
exps/experimental/test-dynamic.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
# python test-dynamic.py
|
||||
#####################################################
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
print("LIB-DIR: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from xautodl.datasets.math_core import ConstantFunc
|
||||
from xautodl.datasets.math_core import GaussianDGenerator
|
||||
|
||||
mean_generator = ConstantFunc(0)
|
||||
cov_generator = ConstantFunc(1)
|
||||
|
||||
generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1))
|
||||
generator(0, 10)
|
@@ -19,9 +19,11 @@ import seaborn as sns
|
||||
matplotlib.use("agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
print("LIB-DIR: {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from log_utils import time_string
|
||||
from nats_bench import create
|
||||
from models import get_cell_based_tiny_net
|
||||
|
@@ -3,11 +3,7 @@ from copy import deepcopy
|
||||
import torchvision.models as models
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from utils import weight_watcher
|
||||
from xautodl.utils import weight_watcher
|
||||
|
||||
|
||||
def main():
|
||||
|
Reference in New Issue
Block a user