Update sync codes

This commit is contained in:
D-X-Y
2021-04-14 01:04:46 +08:00
parent c82c7e9f3f
commit cd253112ee
5 changed files with 220 additions and 44 deletions

View File

@@ -5,17 +5,39 @@
"execution_count": 1,
"id": "filled-multiple",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"#\n",
"# %matplotlib notebook\n",
"import os, sys\n",
"import torch\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
"import matplotlib.ticker as ticker\n",
"\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"from datasets import SynAdaptiveEnv\n",
"from xlayers.super_core import SuperMLPv1"
]
},
{
@@ -25,49 +47,97 @@
"metadata": {},
"outputs": [],
"source": [
"def optimize_fn(xs, ys, test_sets):\n",
" xs = torch.FloatTensor(xs).view(-1, 1)\n",
" ys = torch.FloatTensor(ys).view(-1, 1)\n",
" \n",
" model = SuperMLPv1(1, 10, 1, torch.nn.ReLU)\n",
" optimizer = torch.optim.Adam(\n",
" model.parameters(),\n",
" lr=0.01, weight_decay=1e-4, amsgrad=True\n",
" )\n",
" for _iter in range(100):\n",
" preds = model(ys)\n",
"\n",
" optimizer.zero_grad()\n",
" loss = torch.nn.functional.mse_loss(preds, ys)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" with torch.no_grad():\n",
" answers = []\n",
" for test_set in test_sets:\n",
" test_set = torch.FloatTensor(test_set).view(-1, 1)\n",
" preds = model(test_set).view(-1).numpy()\n",
" answers.append(preds.tolist())\n",
" return answers\n",
"\n",
"def f(x):\n",
" return np.cos( 0.5 * x + 0.)\n",
"\n",
"def get_data(mode):\n",
" dataset = SynAdaptiveEnv(mode=mode)\n",
" times, xs, ys = [], [], []\n",
" for i, (_, t, x) in enumerate(dataset):\n",
" times.append(t)\n",
" xs.append(x)\n",
" dataset.set_transform(f)\n",
" for i, (_, _, y) in enumerate(dataset):\n",
" ys.append(y)\n",
" return times, xs, ys\n",
"\n",
"def visualize_syn(save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" \n",
" dpi, width, height = 50, 2000, 1000\n",
" dpi, width, height = 40, 2000, 900\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, font_gap = 30, 4\n",
" LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" times = np.arange(0, np.pi * 100, 0.1)\n",
" num = len(times)\n",
" x = []\n",
" for i in range(num):\n",
" scale = (i + 1.) / num * 4\n",
" value = times[i] * scale\n",
" x.append(np.sin(value) * (1.3 - scale))\n",
" x = np.array(x)\n",
" y = np.cos( x * x - 0.3 * x )\n",
" times, xs, ys = get_data(None)\n",
" \n",
" def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n",
" alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n",
" if legend is not None:\n",
" cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)\n",
" cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n",
" if not plot_only:\n",
" cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n",
" cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(10)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" cur_ax = fig.add_subplot(2, 1, 1)\n",
" cur_ax.plot(times, x)\n",
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(30)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" \n",
" draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n",
"\n",
" cur_ax = fig.add_subplot(2, 1, 2)\n",
" cur_ax.plot(times, y)\n",
" cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(30)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" # fig.tight_layout()\n",
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n",
" \n",
" train_times, train_xs, train_ys = get_data(\"train\")\n",
" draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)\n",
" \n",
" valid_times, valid_xs, valid_ys = get_data(\"valid\")\n",
" draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)\n",
" \n",
" test_times, test_xs, test_ys = get_data(\"test\")\n",
" draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)\n",
" \n",
" # optimize MLP models\n",
" [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])\n",
" draw_ax(cur_ax, train_times, train_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='r', legend=\"MLP\", plot_only=True)\n",
" draw_ax(cur_ax, valid_times, valid_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)\n",
" draw_ax(cur_ax, test_times, test_preds, None, None,\n",
" alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)\n",
"\n",
" plt.legend(loc=1, fontsize=LegendFontsize)\n",
"\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")\n",
" # plt.show()"
@@ -94,14 +164,6 @@
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "romantic-ordinance",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {