This commit is contained in:
D-X-Y
2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"import os, sys\n",
"import torch\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n",
"from datasets import DynamicQuadraticFunc\n",
"from datasets.synthetic_example import create_example_v1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "detected-second",
"metadata": {},
"outputs": [],
"source": [
"def draw_fig(save_dir, timestamp, xaxis, yaxis):\n",
" save_path = save_dir / '{:04d}'.format(timestamp)\n",
" # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n",
" dpi, width, height = 40, 1500, 1500\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize, font_gap = 80, 80, 5\n",
"\n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" cur_ax = fig.add_subplot(1, 1, 1)\n",
" cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n",
" cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n",
" cur_ax.set_xlim(-6, 6)\n",
" cur_ax.set_ylim(-40, 40)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(10)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" plt.legend(loc=1, fontsize=LegendFontsize)\n",
" fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n",
" plt.close(\"all\")\n",
"\n",
"\n",
"def visualize_env(save_dir):\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" dynamic_env, function = create_example_v1(100, num_per_task=500)\n",
" \n",
" additional_xaxis = np.arange(-6, 6, 0.1)\n",
" for timestamp, dataset in dynamic_env:\n",
" num = dataset.shape[0]\n",
" # timeaxis = (torch.zeros(num) + timestamp).numpy()\n",
" xaxis = dataset[:,0].numpy()\n",
" xaxis = np.concatenate((additional_xaxis, xaxis))\n",
" # compute the ground truth\n",
" function.set_timestamp(timestamp)\n",
" yaxis = function(xaxis)\n",
" draw_fig(save_dir, timestamp, xaxis, yaxis)\n",
"\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"vis_save_dir = desktop_dir / 'vis-synthetic'\n",
"visualize_env(vis_save_dir)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "rapid-uruguay",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ffmpeg -y -i /Users/xuanyidong/Desktop/vis-synthetic/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k /Users/xuanyidong/Desktop/vis-synthetic/vis.mp4\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Plot the data\n",
"cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n",
"print(cmd)\n",
"os.system(cmd)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}