Update the sync data v1

This commit is contained in:
D-X-Y
2021-05-24 13:06:10 +08:00
parent da2575cc6c
commit 3ee0d348af
17 changed files with 228 additions and 274 deletions

View File

@@ -20,14 +20,13 @@ matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from models.xcore import get_model
from datasets.synthetic_core import get_synthetic_env
from utils.temp_sync import optimize_fn, evaluate_fn
from procedures.metric_utils import MSEMetric
from xautodl.models.xcore import get_model
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
@@ -181,10 +180,17 @@ def compare_cl(save_dir):
def visualize_env(save_dir, version):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version)
min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
# min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
allxs, allys = [], []
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
@@ -201,21 +207,18 @@ def visualize_env(save_dir, version):
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
if version == "v1":
cur_ax.set_xlim(-2, 2)
cur_ax.set_ylim(-8, 8)
elif version == "v2":
cur_ax.set_xlim(-10, 10)
cur_ax.set_ylim(-60, 60)
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.legend(loc=1, fontsize=LegendFontsize)
save_path = save_dir / "v{:}-{:05d}".format(version, idx)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir, version=version
xdir=save_dir / "png", version=version
)
print(base_cmd)
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@@ -371,7 +374,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_cl(os.path.join(args.save_dir, "compare-cl"))