Re-org GeMOSA codes

This commit is contained in:
D-X-Y
2021-05-27 11:17:57 +08:00
parent a507f8dd94
commit 8961215416
8 changed files with 82 additions and 162 deletions

View File

@@ -1,8 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
# python exps/GMOA/vis-synthetic.py --env_version v1 #
# python exps/GMOA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
############################################################################
import os, sys, copy, random
import torch
@@ -181,7 +182,7 @@ def compare_cl(save_dir):
def visualize_env(save_dir, version):
save_dir = Path(str(save_dir))
for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr
sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version)
@@ -190,6 +191,8 @@ def visualize_env(save_dir, version):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
@@ -210,14 +213,22 @@ def visualize_env(save_dir, version):
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
pdf_save_path = (
save_dir
/ "pdf-{:}".format(version)
/ "v{:}-{:05d}.pdf".format(version, idx)
)
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
png_save_path = (
save_dir
/ "png-{:}".format(version)
/ "v{:}-{:05d}.png".format(version, idx)
)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir / "png", version=version
xdir=save_dir / "png-{:}".format(version), version=version
)
print(base_cmd)
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@@ -367,7 +378,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_cl(os.path.join(args.save_dir, "compare-cl"))