add autodl

This commit is contained in:
mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

View File

@@ -0,0 +1,311 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
}
],
"source": [
"#\n",
"# Exhaustive Search Results\n",
"#\n",
"import os\n",
"import re\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"import qlib\n",
"from qlib import config as qconfig\n",
"from qlib.workflow import R\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "hidden-exemption",
"metadata": {},
"outputs": [],
"source": [
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "continental-drain",
"metadata": {},
"outputs": [],
"source": [
"def filter_finished(recorders):\n",
" returned_recorders = dict()\n",
" not_finished = 0\n",
" for key, recorder in recorders.items():\n",
" if recorder.status == \"FINISHED\":\n",
" returned_recorders[key] = recorder\n",
" else:\n",
" not_finished += 1\n",
" return returned_recorders, not_finished\n",
"\n",
"def query_info(save_dir, verbose, name_filter, key_map):\n",
" if isinstance(save_dir, list):\n",
" results = []\n",
" for x in save_dir:\n",
" x = query_info(x, verbose, name_filter, key_map)\n",
" results.extend(x)\n",
" return results\n",
" # Here, the save_dir must be a string\n",
" R.set_uri(str(save_dir))\n",
" experiments = R.list_experiments()\n",
"\n",
" if verbose:\n",
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
" qresults = []\n",
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
" if experiment.id == \"0\":\n",
" continue\n",
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
" continue\n",
" recorders = experiment.list_recorders()\n",
" recorders, not_finished = filter_finished(recorders)\n",
" if verbose:\n",
" print(\n",
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
" idx + 1,\n",
" len(experiments),\n",
" experiment.name,\n",
" len(recorders),\n",
" len(recorders) + not_finished,\n",
" )\n",
" )\n",
" result = QResult(experiment.name)\n",
" for recorder_id, recorder in recorders.items():\n",
" result.update(recorder.list_metrics(), key_map)\n",
" result.append_path(\n",
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
" )\n",
" if not len(result):\n",
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
" continue\n",
" else:\n",
" if verbose:\n",
" print(\n",
" \"There are {:} valid recorders for {:}\".format(\n",
" len(recorders), experiment.name\n",
" )\n",
" )\n",
" qresults.append(result)\n",
" return qresults"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
]
}
],
"source": [
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
"paths = [path.resolve() for path in paths]\n",
"print(paths)\n",
"\n",
"key_map = dict()\n",
"for xset in (\"train\", \"valid\", \"test\"):\n",
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
"qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_depth_channel(qresults, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" dpi, width, height = 200, 4000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 12\n",
" font_gap = 5\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" # fig, axs = plt.subplots(1, 2, figsize=figsize, projection='3d')\n",
" \n",
" def plot_ax(cur_ax, train_or_test):\n",
" depths, channels = [], []\n",
" ic_values, xmaps = [], dict()\n",
" for qresult in qresults:\n",
" name = qresult.name.split('-')[1]\n",
" depths.append(float(name.split('x')[0]))\n",
" channels.append(float(name.split('x')[1]))\n",
" if train_or_test:\n",
" ic_values.append(qresult['IC (train)'])\n",
" else:\n",
" ic_values.append(qresult['IC (valid)'])\n",
" xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n",
" # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n",
" raw_depths = np.arange(1, 9, dtype=np.int32)\n",
" raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n",
" depths, channels = np.meshgrid(raw_depths, raw_channels)\n",
" ic_values = np.sin(depths) # initialize\n",
" # print(ic_values.shape)\n",
" num_x, num_y = ic_values.shape\n",
" for i in range(num_x):\n",
" for j in range(num_y):\n",
" xkey = (int(depths[i][j]), int(channels[i][j]))\n",
" if xkey not in xmaps:\n",
" raise ValueError(\"Did not find {:}\".format(xkey))\n",
" ic_values[i][j] = xmaps[xkey]\n",
" #print(sorted(list(xmaps.keys())))\n",
" #surf = cur_ax.plot_surface(\n",
" # np.array(depths), np.array(channels), np.array(ic_values),\n",
" # cmap=cm.coolwarm, linewidth=0, antialiased=False)\n",
" surf = cur_ax.plot_surface(\n",
" depths, channels, ic_values,\n",
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
" cur_ax.set_xticks(raw_depths)\n",
" cur_ax.set_yticks(raw_channels)\n",
" cur_ax.set_zticks(np.arange(4, 11, 2))\n",
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.zaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" # Add a color bar which maps values to colors.\n",
"# cax = fig.add_axes([cur_ax.get_position().x1 + 0.01,\n",
"# cur_ax.get_position().y0,\n",
"# 0.01,\n",
"# cur_ax.get_position().height * 0.9])\n",
" # fig.colorbar(surf, cax=cax)\n",
" # fig.colorbar(surf, shrink=0.5, aspect=5)\n",
" # import pdb; pdb.set_trace()\n",
" # ax1.legend(loc=4, fontsize=LegendFontsize)\n",
" ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
" plot_ax(ax, True)\n",
" ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
" plot_ax(ax, False)\n",
" # fig.tight_layout()\n",
" plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 48 qlib-results\n"
]
}
],
"source": [
"# Visualization\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_depth_channel(qresults, desktop_dir / 'es_csi300_d_vs_c.pdf')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
}
],
"source": [
"#\n",
"# Exhaustive Search Results\n",
"#\n",
"import os\n",
"import re\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"import qlib\n",
"from qlib import config as qconfig\n",
"from qlib.workflow import R\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "hidden-exemption",
"metadata": {},
"outputs": [],
"source": [
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "continental-drain",
"metadata": {},
"outputs": [],
"source": [
"def filter_finished(recorders):\n",
" returned_recorders = dict()\n",
" not_finished = 0\n",
" for key, recorder in recorders.items():\n",
" if recorder.status == \"FINISHED\":\n",
" returned_recorders[key] = recorder\n",
" else:\n",
" not_finished += 1\n",
" return returned_recorders, not_finished\n",
"\n",
"def query_info(save_dir, verbose, name_filter, key_map):\n",
" if isinstance(save_dir, list):\n",
" results = []\n",
" for x in save_dir:\n",
" x = query_info(x, verbose, name_filter, key_map)\n",
" results.extend(x)\n",
" return results\n",
" # Here, the save_dir must be a string\n",
" R.set_uri(str(save_dir))\n",
" experiments = R.list_experiments()\n",
"\n",
" if verbose:\n",
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
" qresults = []\n",
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
" if experiment.id == \"0\":\n",
" continue\n",
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
" continue\n",
" recorders = experiment.list_recorders()\n",
" recorders, not_finished = filter_finished(recorders)\n",
" if verbose:\n",
" print(\n",
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
" idx + 1,\n",
" len(experiments),\n",
" experiment.name,\n",
" len(recorders),\n",
" len(recorders) + not_finished,\n",
" )\n",
" )\n",
" result = QResult(experiment.name)\n",
" for recorder_id, recorder in recorders.items():\n",
" result.update(recorder.list_metrics(), key_map)\n",
" result.append_path(\n",
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
" )\n",
" if not len(result):\n",
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
" continue\n",
" else:\n",
" if verbose:\n",
" print(\n",
" \"There are {:} valid recorders for {:}\".format(\n",
" len(recorders), experiment.name\n",
" )\n",
" )\n",
" qresults.append(result)\n",
" return qresults"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
]
}
],
"source": [
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
"paths = [path.resolve() for path in paths]\n",
"print(paths)\n",
"\n",
"key_map = dict()\n",
"for xset in (\"train\", \"valid\", \"test\"):\n",
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
"\n",
"qresults = query_info(paths, False, 'TSF-.*', key_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_dropouts(qresults, basenames, name2suffix, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" name2qresult = dict()\n",
" for qresult in qresults:\n",
" name2qresult[qresult.name] = qresult\n",
" # sort architectures\n",
" accuracies = []\n",
" for basename in basenames:\n",
" qresult = name2qresult[basename + '-drop0_0']\n",
" accuracies.append(qresult['ICIR (train)'])\n",
" sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n",
" \n",
" dpi, width, height = 200, 4000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 22\n",
" font_gap = 5\n",
" colors = ['k', 'r']\n",
" markers = ['*', 'o']\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" def plot_ax(cur_ax, train_or_test):\n",
" for idx, (legend, suffix) in enumerate(name2suffix.items()):\n",
" x_values = list(range(len(sorted_basenames)))\n",
" y_values = []\n",
" for i, name in enumerate(sorted_basenames):\n",
" name = '{:}{:}'.format(name, suffix)\n",
" qresult = name2qresult[name]\n",
" if train_or_test:\n",
" value = qresult['IC (train)']\n",
" else:\n",
" value = qresult['IC (valid)']\n",
" y_values.append(value)\n",
" cur_ax.plot(x_values, y_values, c=colors[idx])\n",
" cur_ax.scatter(x_values, y_values,\n",
" marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
" label=legend)\n",
" cur_ax.set_yticks(np.arange(4, 11, 2))\n",
" cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" cur_ax.legend(loc=4, fontsize=LegendFontsize)\n",
" ax = fig.add_subplot(1, 2, 1)\n",
" plot_ax(ax, True)\n",
" ax = fig.add_subplot(1, 2, 2)\n",
" plot_ax(ax, False)\n",
" # fig.tight_layout()\n",
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n",
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 104 qlib-results\n"
]
}
],
"source": [
"# Visualization\n",
"names = [qresult.name for qresult in qresults]\n",
"base_names = set()\n",
"for name in names:\n",
" base_name = name.split('-drop')[0]\n",
" base_names.add(base_name)\n",
"print(base_names)\n",
"# filter\n",
"filtered_base_names = set()\n",
"for base_name in base_names:\n",
" if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n",
" filtered_base_names.add(base_name)\n",
" else:\n",
" print('Cannot find all names for {:}'.format(base_name))\n",
"# print(filtered_base_names)\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_dropouts(qresults, list(filtered_base_names),\n",
" {'No-dropout': '-drop0_0',\n",
" 'Ratio=0.1' : '-drop0.1_0'},\n",
" desktop_dir / 'es_csi300_drop.pdf')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,208 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"import os\n",
"import re\n",
"import sys\n",
"import torch\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from scipy.interpolate import make_interp_spline\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "continental-drain",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TSF-2x24-drop0_0s2013-01-01\n",
"TSF-2x24-drop0_0s2012-01-01\n",
"TSF-2x24-drop0_0s2008-01-01\n",
"TSF-2x24-drop0_0s2009-01-01\n",
"TSF-2x24-drop0_0s2010-01-01\n",
"TSF-2x24-drop0_0s2011-01-01\n",
"TSF-2x24-drop0_0s2008-07-01\n",
"TSF-2x24-drop0_0s2009-07-01\n",
"There are 3011 dates\n",
"Dates: 2008-01-02 2008-01-03\n"
]
}
],
"source": [
"qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n",
"for qresult in qresults:\n",
" print(qresult.name)\n",
"all_dates = set()\n",
"for qresult in qresults:\n",
" dates = qresult.find_all_dates()\n",
" for date in dates:\n",
" all_dates.add(date)\n",
"all_dates = sorted(list(all_dates))\n",
"print('There are {:} dates'.format(len(all_dates)))\n",
"print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_time_curve(qresults, dates, use_original, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" dpi, width, height = 200, 5000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 12\n",
" font_gap = 5\n",
" linestyles = ['-', '--']\n",
" colors = ['k', 'r']\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" cur_ax = fig.add_subplot(1, 1, 1)\n",
" for idx, qresult in enumerate(qresults):\n",
" print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n",
" x_axis, y_axis = [], []\n",
" for idate, date in enumerate(dates):\n",
" if date in qresult._date2ICs[-1]:\n",
" mean, std = qresult.get_IC_by_date(date, 100)\n",
" if not np.isnan(mean):\n",
" x_axis.append(idate)\n",
" y_axis.append(mean)\n",
" x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n",
" if use_original:\n",
" cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n",
" else:\n",
" xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n",
" spl = make_interp_spline(x_axis, y_axis, k=5)\n",
" ynew = spl(xnew)\n",
" cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n",
" \n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 2 qlib-results\n",
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n",
"There are 2 qlib-results\n",
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n"
]
}
],
"source": [
"# Visualization\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_time_curve(\n",
" (qresults[2], qresults[-1]),\n",
" all_dates,\n",
" True,\n",
" desktop_dir / 'es_csi300_time_curve.pdf')\n",
"\n",
"vis_time_curve(\n",
" (qresults[2], qresults[-1]),\n",
" all_dates,\n",
" False,\n",
" desktop_dir / 'es_csi300_time_curve-inter.pdf')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "exempt-stable",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,129 @@
import os
import re
import sys
import torch
import qlib
import pprint
from collections import OrderedDict
import numpy as np
import pandas as pd
from pathlib import Path
# __file__ = os.path.dirname(os.path.realpath("__file__"))
# Resolve the directory of this script and the AutoDL-Projects root so that
# the bundled "lib" package can be imported without installation.
note_dir = Path(__file__).parent.resolve()
root_dir = (Path(__file__).parent / ".." / "..").resolve()
lib_dir = (root_dir / "lib").resolve()
print("The root path: {:}".format(root_dir))
print("The library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
# Make the project library importable (prepend so it wins over site-packages).
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
import qlib
from qlib import config as qconfig
from qlib.workflow import R

# Initialize qlib in client mode against the local China-market data.
# NOTE(review): assumes ~/.qlib/qlib_data/cn_data has been downloaded — confirm.
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
from utils.qlib_utils import QResult
def filter_finished(recorders):
    """Keep only the recorders whose run has finished.

    Args:
        recorders: mapping of recorder-key -> recorder; each recorder must
            expose a ``status`` attribute.

    Returns:
        A ``(finished, not_finished)`` pair: ``finished`` is a dict holding
        only the recorders with status ``"FINISHED"``; ``not_finished`` is
        the number of recorders that were dropped.
    """
    finished = {
        name: rec for name, rec in recorders.items() if rec.status == "FINISHED"
    }
    return finished, len(recorders) - len(finished)
def add_to_dict(xdict, timestamp, value):
    """Store ``value`` in ``xdict`` keyed by the date of ``timestamp``.

    The key is the ``YYYY-MM-DD`` string of ``timestamp.date()``.  A
    duplicate date is treated as an error rather than silently overwritten.

    Raises:
        ValueError: if the date is already present in ``xdict``.
    """
    day_key = timestamp.date().strftime("%Y-%m-%d")
    if day_key in xdict:
        raise ValueError("This date [{:}] is already in the dict".format(day_key))
    xdict[day_key] = value
def query_info(save_dir, verbose, name_filter, key_map):
    """Collect ``QResult`` objects from the qlib experiments under ``save_dir``.

    Args:
        save_dir: a single experiment uri (str/Path) or a list of them; a
            list is processed recursively and the results concatenated.
        verbose: print progress information when True.
        name_filter: regex that an experiment name must fully match to be
            kept; ``None`` disables name filtering.
        key_map: metric-name mapping forwarded to ``QResult.update``.

    Returns:
        A list of ``QResult``, one per experiment that has at least one
        finished recorder.
    """
    # Fan out over a list of directories and merge the per-directory results.
    if isinstance(save_dir, list):
        merged = []
        for one_dir in save_dir:
            merged.extend(query_info(one_dir, verbose, name_filter, key_map))
        return merged
    # From here on, save_dir is a single uri.
    R.set_uri(str(save_dir))
    experiments = R.list_experiments()
    if verbose:
        print("There are {:} experiments.".format(len(experiments)))
    qresults = []
    for idx, (key, experiment) in enumerate(experiments.items()):
        # Experiment id "0" is qlib's default placeholder; skip it.
        if experiment.id == "0":
            continue
        name_rejected = (
            name_filter is not None
            and re.fullmatch(name_filter, experiment.name) is None
        )
        if name_rejected:
            continue
        recorders, not_finished = filter_finished(experiment.list_recorders())
        if verbose:
            print(
                "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
                    idx + 1,
                    len(experiments),
                    experiment.name,
                    len(recorders),
                    len(recorders) + not_finished,
                )
            )
        result = QResult(experiment.name)
        for recorder_id, recorder in recorders.items():
            # Gather the per-date IC series from the train/valid/test dumps.
            date2IC = OrderedDict()
            for file_name in (
                "results-train.pkl",
                "results-valid.pkl",
                "results-test.pkl",
            ):
                series = recorder.load_object(file_name)["all-IC"]
                for timestamp, value in zip(series.index.tolist(), series.tolist()):
                    add_to_dict(date2IC, timestamp, value)
            result.update(recorder.list_metrics(), key_map)
            result.append_path(
                os.path.join(recorder.uri, recorder.experiment_id, recorder.id)
            )
            result.append_date2ICs(date2IC)
        if not len(result):
            print("There are no valid recorders for {:}".format(experiment))
            continue
        if verbose:
            print(
                "There are {:} valid recorders for {:}".format(
                    len(recorders), experiment.name
                )
            )
        qresults.append(result)
    return qresults
##
# Query the time-split experiments (TSF-2x24 without dropout, varying start
# dates) and cache the collected results for the plotting notebooks.
paths = [(root_dir / "outputs" / "qlib-baselines-csi300").resolve()]
print(paths)

# Map raw metric keys to the display names used by QResult.
key_map = {
    "{:}-mean-{:}".format(xset, metric): "{:} ({:})".format(metric, xset)
    for xset in ("train", "valid", "test")
    for metric in ("IC", "ICIR")
}
qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
print("Find {:} results".format(len(qresults)))

# The experiment-name suffix after "0_0s" is the split start date.
times = [qresult.name.split("0_0s")[-1] for qresult in qresults]
print(times)

save_path = os.path.join(note_dir, "temp-time-x.pth")
torch.save(qresults, save_path)
print(save_path)