add autodl

This commit is contained in:
mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

View File

@@ -0,0 +1,118 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "german-madonna",
"metadata": {},
"outputs": [],
"source": [
"# Implementation for \"A Tutorial on Bayesian Optimization\"\n",
"import numpy as np\n",
"\n",
"def get_data():\n",
" return np.random.random(2) * 10\n",
"\n",
"def f(x):\n",
" return float(np.power((x[0] * 3 - x[1]), 3) - np.exp(x[1]) + np.power(x[0], 2))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "broke-citizenship",
"metadata": {},
"outputs": [],
"source": [
"# Kernels typically have the property that points closer in the input space are more strongly correlated\n",
"# i.e., if |x1 - x2| < |x1 - x3|, then sigma(x1, x2) > sigma(x1, x3).\n",
"# the commonly used and simple kernel is the power exponential or Gaussian kernel:\n",
"def sigma0(x1, x2, alpha0=1, alpha=[1,1]):\n",
" \"\"\"alpha could be a vector\"\"\"\n",
" power = np.array(alpha, dtype=np.float32) * np.power(np.array(x1)-np.array(x2), 2)\n",
" return alpha0 * np.exp( -np.sum(power) )\n",
"\n",
"# the most common choice for the mean function is a constant value\n",
"def mu0(x, mu):\n",
" return mu"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "aerial-carnival",
"metadata": {},
"outputs": [],
"source": [
"K = 5\n",
"X = np.array([get_data() for i in range(K)])\n",
"mu = np.mean(X, axis=0)\n",
"mu0_over_K = [mu0(x, mu) for x in X]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "polished-discussion",
"metadata": {},
"outputs": [],
"source": [
"sigma0_over_KK = []\n",
"for i in range(K):\n",
" sigma0_over_KK.append(np.array([sigma0(X[i], X[j]) for j in range(K)]))\n",
"sigma0_over_KK = np.array(sigma0_over_KK)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "comic-jesus",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20, 20)\n",
"1.1038803861344952e-06\n",
"1.1038803861344952e-06\n"
]
}
],
"source": [
"print(sigma0_over_KK.shape)\n",
"print(sigma0_over_KK[1][2])\n",
"print(sigma0_over_KK[2][1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "statistical-wrist",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021-03-27 06:46:38] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
]
}
],
"source": [
"from nats_bench import create\n",
"from pprint import pprint\n",
"# Create the API for tologoy search space\n",
"api = create(None, 'tss', fast_mode=True, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'test-accuracy': 22.39999992879232,\n",
" 'test-all-time': 7.7054752962929856,\n",
" 'test-loss': 3.1626377182006835,\n",
" 'test-per-time': 0.6421229413577488,\n",
" 'train-accuracy': 21.68885959195242,\n",
" 'train-all-time': 1260.0195466594694,\n",
" 'train-loss': 3.1863493608815463,\n",
" 'train-per-time': 105.00162888828912,\n",
" 'valid-accuracy': 23.266666631062826,\n",
" 'valid-all-time': 7.7054752962929856,\n",
" 'valid-loss': 3.1219845104217527,\n",
" 'valid-per-time': 0.6421229413577488,\n",
" 'valtest-accuracy': 22.833333323160808,\n",
" 'valtest-all-time': 15.410950592585971,\n",
" 'valtest-loss': 3.142311067581177,\n",
" 'valtest-per-time': 1.2842458827154977}\n"
]
}
],
"source": [
"largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'\n",
"\n",
"arch_index = api.query_index_by_arch(largest_candidate_tss)\n",
"info = api.get_more_info(arch_index, 'ImageNet16-120', hp='12', is_random=False)\n",
"pprint(info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021-03-01 12:28:12] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
]
}
],
"source": [
"from nats_bench import create\n",
"import numpy as np\n",
"\n",
"def get_correlation(A, B):\n",
" return float(np.corrcoef(A, B)[0,1])\n",
"\n",
"# Create the API for tologoy search space\n",
"api = create(None, 'tss', fast_mode=True, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 15625 architectures on the topology search space\n"
]
}
],
"source": [
"print('There are {:} architectures on the topology search space'.format(len(api)))\n",
"accuracies_12, accuracies_200 = [], []\n",
"for i, arch in enumerate(api):\n",
" info_a = api.get_more_info(i, dataset='cifar10-valid', hp='12', is_random=False)\n",
" accuracies_12.append(info_a['valid-accuracy'])\n",
"\n",
" info_b = api.get_more_info(i, dataset='cifar10-valid', hp='200', is_random=False)\n",
" accuracies_200.append(info_b['test-accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: 91.18%\n"
]
}
],
"source": [
"correlation = get_correlation(accuracies_12, accuracies_200)\n",
"print('[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: {:.2f}%'.format(correlation * 100))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021-03-09 08:44:19] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n"
]
}
],
"source": [
"from nats_bench import create\n",
"import numpy as np\n",
"\n",
"# Create the API for size search space\n",
"api = create(None, 'sss', fast_mode=True, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 32768 architectures on the size search space\n"
]
}
],
"source": [
"print('There are {:} architectures on the size search space'.format(len(api)))\n",
"\n",
"c2acc = dict()\n",
"for index in range(len(api)):\n",
" info = api.get_more_info(index, 'cifar10', hp='90')\n",
" config = api.get_net_config(index, 'cifar10')\n",
" c2acc[config['channels']] = info['test-accuracy']"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"91.08546417236329\n"
]
}
],
"source": [
"print(np.mean(list(c2acc.values())))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,274 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[82189:MainThread](2021-03-02 21:02:54,241) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[82189:MainThread](2021-03-02 21:02:54,255) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[82189:MainThread](2021-03-02 21:02:54,828) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[82189:MainThread](2021-03-02 21:02:54,829) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data')\n",
"\n",
"from qlib.config import C\n",
"from qlib.data import D\n",
"from qlib.data.data import DatasetD, ExpressionD, Inst, Cal, FeatureD\n",
"from qlib.data.cache import H\n",
"from qlib.data.filter import NameDFilter\n",
"from qlib.utils import code_to_fname, read_bin"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n"
]
}
],
"source": [
"nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')\n",
"instruments_config = D.instruments(market='csi300', filter_pipe=[nameDFilter])\n",
"instruments = D.list_instruments(instruments=instruments_config,\n",
" start_time='2015-01-01',\n",
" end_time='2016-02-15',\n",
" as_list=True)\n",
"\n",
"fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n",
"features = D.features(instruments_config, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day')\n",
"print(type(features))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" $close $volume Ref($close, 1) Mean($close, 3) \\\n",
"instrument datetime \n",
"SH600655 2010-01-04 8.934296 47799352.0 8.667867 8.691138 \n",
" 2010-01-05 8.889880 29791234.0 8.934296 8.830681 \n",
" 2010-01-06 8.845468 29002874.0 8.889880 8.889881 \n",
" 2010-01-07 8.553690 38189440.0 8.845468 8.763013 \n",
" 2010-01-08 8.645658 23417642.0 8.553690 8.681605 \n",
"... ... ... ... ... \n",
"SH601555 2017-12-25 1.393481 80615584.0 1.406559 1.408012 \n",
" 2017-12-26 1.406559 64259856.0 1.393481 1.402200 \n",
" 2017-12-27 1.400747 58551256.0 1.406559 1.400262 \n",
" 2017-12-28 1.412371 96204872.0 1.400747 1.406559 \n",
" 2017-12-29 1.412371 52801024.0 1.412371 1.408496 \n",
"\n",
" $high-$low \n",
"instrument datetime \n",
"SH600655 2010-01-04 0.412291 \n",
" 2010-01-05 0.203006 \n",
" 2010-01-06 0.250560 \n",
" 2010-01-07 0.412291 \n",
" 2010-01-08 0.275964 \n",
"... ... \n",
"SH601555 2017-12-25 0.020343 \n",
" 2017-12-26 0.018890 \n",
" 2017-12-27 0.017437 \n",
" 2017-12-28 0.045045 \n",
" 2017-12-29 0.013078 \n",
"\n",
"[2867 rows x 5 columns]\n"
]
}
],
"source": [
"print(features)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'qlib.data.data.LocalProvider'>\n",
"<class 'qlib.config.QlibConfig'>\n",
"LocalProvider\n",
"Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n",
"Wrapper(provider=<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>)\n",
"<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>\n",
"LocalDatasetProvider\n",
"--\n",
"Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n",
"<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>\n",
"default_disk_cache: 1\n",
"ExpressionD: Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n",
"FeatureD : <qlib.data.data.LocalFeatureProvider object at 0x7ff55fb84430>\n"
]
}
],
"source": [
"# Provider:\n",
"print(type(D._provider))\n",
"print(type(C))\n",
"print(C.provider)\n",
"print(D)\n",
"\n",
"# DatasetD Provider\n",
"print(DatasetD)\n",
"print(DatasetD._provider)\n",
"print(C.dataset_provider)\n",
"\n",
"print('--')\n",
"print(Inst)\n",
"print(Inst._provider)\n",
"\n",
"# Default Disk Cache\n",
"print('default_disk_cache: {:}'.format(C.default_disk_cache))\n",
"print('ExpressionD: {:}'.format(ExpressionD))\n",
"print('FeatureD : {:}'.format(FeatureD._provider))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'pprint' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-76544a8bb578>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0minstruments_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDatasetD\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_provider\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_instruments_d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfreq\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'day'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'pprint' is not defined"
]
}
],
"source": [
"pprint.pprint(instruments_config)\n",
"instruments_d = DatasetD._provider.get_instruments_d(instruments_config, freq='day')\n",
"pprint.pprint(instruments_d)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2012-12-31 00:00:00 -> 2019-01-18 00:00:00\n",
"<PandasArray>\n",
"[1.1059314, 1.0935822, 1.1059314, 1.0922102, 1.0839773, 1.0839773, 1.0181155,\n",
" 1.0730004, 1.0867218, 1.068884,\n",
" ...\n",
" 1.1163876, 1.1208236, 1.1119517, 1.0986437, 1.1075157, 1.0971651, 1.1149089,\n",
" 1.083857, 1.083857, 1.0956864]\n",
"Length: 1439, dtype: float32\n"
]
}
],
"source": [
"instrument, field, freq = 'SH601555', '$close', 'day'\n",
"all_dates = D.calendar(start_time='2011-12-31', end_time='2019-02-10', freq=freq)\n",
"start_time, end_time = all_dates[0], all_dates[-11]\n",
"print(str(start_time) + ' -> ' + str(end_time))\n",
"obj = ExpressionD.expression(instrument, field, start_time, end_time, freq)\n",
"print(obj.array)\n",
"\n",
"# expression = ExpressionD.get_expression_instance(field)\n",
"# start_time = pd.Timestamp(start_time)\n",
"# end_time = pd.Timestamp(end_time)\n",
"# _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq='day', future=False)\n",
"# print(start_index)\n",
"# print(end_index)\n",
"\n",
"# fname = code_to_fname(instrument)\n",
"# uri_data = FeatureD._uri_data.format(instrument.lower(), field[1:], freq)\n",
"# print(uri_data)\n",
"# # series = read_bin(uri_data, start_index, end_index)\n",
"# series = read_bin(uri_data, 2850, 2870)\n",
"# print(series)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n",
"Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n",
"Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n"
]
}
],
"source": [
"from qlib.data import D\n",
"from qlib.data.data import ExpressionD, Inst\n",
"print(D)\n",
"print(Inst)\n",
"print(ExpressionD)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[61704:MainThread](2021-03-22 13:56:38,104) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[61704:MainThread](2021-03-22 13:56:38,106) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[61704:MainThread](2021-03-22 13:56:38,680) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[61704:MainThread](2021-03-22 13:56:38,681) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'class': 'DatasetH',\n",
" 'kwargs': {'handler': {'class': 'Alpha158',\n",
" 'kwargs': {'end_time': '2020-08-01',\n",
" 'fit_end_time': '2014-12-31',\n",
" 'fit_start_time': '2008-01-01',\n",
" 'instruments': 'csi100',\n",
" 'start_time': '2008-01-01'},\n",
" 'module_path': 'qlib.contrib.data.handler'},\n",
" 'segments': {'test': ('2017-01-01', '2020-08-01'),\n",
" 'train': ('2008-01-01', '2014-12-31'),\n",
" 'valid': ('2015-01-01', '2016-12-31')}},\n",
" 'module_path': 'qlib.data.dataset'}\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"\n",
"lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n",
"print(\"library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"from qlib import config as qconfig\n",
"from qlib.utils import init_instance_by_config\n",
"\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)\n",
"\n",
"dataset_config = {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": {\n",
" \"class\": \"Alpha158\",\n",
" \"module_path\": \"qlib.contrib.data.handler\",\n",
" \"kwargs\": {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": \"csi100\",\n",
" },\n",
" },\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
" }\n",
"pprint.pprint(dataset_config)\n",
"dataset = init_instance_by_config(dataset_config)\n",
"\n",
"df_train, df_valid, df_test = dataset.prepare(\n",
" [\"train\", \"valid\", \"test\"],\n",
" col_set=[\"feature\", \"label\"],\n",
" data_key=DataHandlerLP.DK_L,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'class': 'DatasetH',\n",
" 'kwargs': {'handler': {'class': 'Alpha158',\n",
" 'kwargs': {'end_time': '2020-08-01',\n",
" 'fit_end_time': '2014-12-31',\n",
" 'fit_start_time': '2008-01-01',\n",
" 'instruments': 'csi300',\n",
" 'start_time': '2008-01-01'},\n",
" 'module_path': 'qlib.contrib.data.handler'},\n",
" 'segments': {'test': ('2017-01-01', '2020-08-01'),\n",
" 'train': ('2008-01-01', '2014-12-31'),\n",
" 'valid': ('2015-01-01', '2016-12-31')}},\n",
" 'module_path': 'qlib.data.dataset'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[95290:MainThread](2021-03-03 12:18:43,481) INFO - qlib.timer - [log.py:81] - Time cost: 237.911s | Loading data Done\n",
"[95290:MainThread](2021-03-03 12:18:45,080) INFO - qlib.timer - [log.py:81] - Time cost: 0.465s | DropnaLabel Done\n",
"[95290:MainThread](2021-03-03 12:18:51,572) INFO - qlib.timer - [log.py:81] - Time cost: 6.491s | CSZScoreNorm Done\n",
"[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 8.090s | fit & process data Done\n",
"[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 246.003s | Init data Done\n"
]
}
],
"source": [
"from trade_models.transformations import get_transformer\n",
"\n",
"model = get_transformer(None)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,311 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
}
],
"source": [
"#\n",
"# Exhaustive Search Results\n",
"#\n",
"import os\n",
"import re\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"import qlib\n",
"from qlib import config as qconfig\n",
"from qlib.workflow import R\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "hidden-exemption",
"metadata": {},
"outputs": [],
"source": [
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "continental-drain",
"metadata": {},
"outputs": [],
"source": [
"def filter_finished(recorders):\n",
" returned_recorders = dict()\n",
" not_finished = 0\n",
" for key, recorder in recorders.items():\n",
" if recorder.status == \"FINISHED\":\n",
" returned_recorders[key] = recorder\n",
" else:\n",
" not_finished += 1\n",
" return returned_recorders, not_finished\n",
"\n",
"def query_info(save_dir, verbose, name_filter, key_map):\n",
" if isinstance(save_dir, list):\n",
" results = []\n",
" for x in save_dir:\n",
" x = query_info(x, verbose, name_filter, key_map)\n",
" results.extend(x)\n",
" return results\n",
" # Here, the save_dir must be a string\n",
" R.set_uri(str(save_dir))\n",
" experiments = R.list_experiments()\n",
"\n",
" if verbose:\n",
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
" qresults = []\n",
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
" if experiment.id == \"0\":\n",
" continue\n",
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
" continue\n",
" recorders = experiment.list_recorders()\n",
" recorders, not_finished = filter_finished(recorders)\n",
" if verbose:\n",
" print(\n",
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
" idx + 1,\n",
" len(experiments),\n",
" experiment.name,\n",
" len(recorders),\n",
" len(recorders) + not_finished,\n",
" )\n",
" )\n",
" result = QResult(experiment.name)\n",
" for recorder_id, recorder in recorders.items():\n",
" result.update(recorder.list_metrics(), key_map)\n",
" result.append_path(\n",
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
" )\n",
" if not len(result):\n",
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
" continue\n",
" else:\n",
" if verbose:\n",
" print(\n",
" \"There are {:} valid recorders for {:}\".format(\n",
" len(recorders), experiment.name\n",
" )\n",
" )\n",
" qresults.append(result)\n",
" return qresults"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
]
}
],
"source": [
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
"paths = [path.resolve() for path in paths]\n",
"print(paths)\n",
"\n",
"key_map = dict()\n",
"for xset in (\"train\", \"valid\", \"test\"):\n",
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
"qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_depth_channel(qresults, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" dpi, width, height = 200, 4000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 12\n",
" font_gap = 5\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" # fig, axs = plt.subplots(1, 2, figsize=figsize, projection='3d')\n",
" \n",
" def plot_ax(cur_ax, train_or_test):\n",
" depths, channels = [], []\n",
" ic_values, xmaps = [], dict()\n",
" for qresult in qresults:\n",
" name = qresult.name.split('-')[1]\n",
" depths.append(float(name.split('x')[0]))\n",
" channels.append(float(name.split('x')[1]))\n",
" if train_or_test:\n",
" ic_values.append(qresult['IC (train)'])\n",
" else:\n",
" ic_values.append(qresult['IC (valid)'])\n",
" xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n",
" # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n",
" raw_depths = np.arange(1, 9, dtype=np.int32)\n",
" raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n",
" depths, channels = np.meshgrid(raw_depths, raw_channels)\n",
" ic_values = np.sin(depths) # initialize\n",
" # print(ic_values.shape)\n",
" num_x, num_y = ic_values.shape\n",
" for i in range(num_x):\n",
" for j in range(num_y):\n",
" xkey = (int(depths[i][j]), int(channels[i][j]))\n",
" if xkey not in xmaps:\n",
" raise ValueError(\"Did not find {:}\".format(xkey))\n",
" ic_values[i][j] = xmaps[xkey]\n",
" #print(sorted(list(xmaps.keys())))\n",
" #surf = cur_ax.plot_surface(\n",
" # np.array(depths), np.array(channels), np.array(ic_values),\n",
" # cmap=cm.coolwarm, linewidth=0, antialiased=False)\n",
" surf = cur_ax.plot_surface(\n",
" depths, channels, ic_values,\n",
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
" cur_ax.set_xticks(raw_depths)\n",
" cur_ax.set_yticks(raw_channels)\n",
" cur_ax.set_zticks(np.arange(4, 11, 2))\n",
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.zaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" # Add a color bar which maps values to colors.\n",
"# cax = fig.add_axes([cur_ax.get_position().x1 + 0.01,\n",
"# cur_ax.get_position().y0,\n",
"# 0.01,\n",
"# cur_ax.get_position().height * 0.9])\n",
" # fig.colorbar(surf, cax=cax)\n",
" # fig.colorbar(surf, shrink=0.5, aspect=5)\n",
" # import pdb; pdb.set_trace()\n",
" # ax1.legend(loc=4, fontsize=LegendFontsize)\n",
" ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
" plot_ax(ax, True)\n",
" ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
" plot_ax(ax, False)\n",
" # fig.tight_layout()\n",
" plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 48 qlib-results\n"
]
}
],
"source": [
"# Visualization\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_depth_channel(qresults, desktop_dir / 'es_csi300_d_vs_c.pdf')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
"[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
"[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
"[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
]
}
],
"source": [
"#\n",
"# Exhaustive Search Results\n",
"#\n",
"import os\n",
"import re\n",
"import sys\n",
"import qlib\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"import qlib\n",
"from qlib import config as qconfig\n",
"from qlib.workflow import R\n",
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "hidden-exemption",
"metadata": {},
"outputs": [],
"source": [
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "continental-drain",
"metadata": {},
"outputs": [],
"source": [
"def filter_finished(recorders):\n",
" returned_recorders = dict()\n",
" not_finished = 0\n",
" for key, recorder in recorders.items():\n",
" if recorder.status == \"FINISHED\":\n",
" returned_recorders[key] = recorder\n",
" else:\n",
" not_finished += 1\n",
" return returned_recorders, not_finished\n",
"\n",
"def query_info(save_dir, verbose, name_filter, key_map):\n",
" if isinstance(save_dir, list):\n",
" results = []\n",
" for x in save_dir:\n",
" x = query_info(x, verbose, name_filter, key_map)\n",
" results.extend(x)\n",
" return results\n",
" # Here, the save_dir must be a string\n",
" R.set_uri(str(save_dir))\n",
" experiments = R.list_experiments()\n",
"\n",
" if verbose:\n",
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
" qresults = []\n",
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
" if experiment.id == \"0\":\n",
" continue\n",
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
" continue\n",
" recorders = experiment.list_recorders()\n",
" recorders, not_finished = filter_finished(recorders)\n",
" if verbose:\n",
" print(\n",
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
" idx + 1,\n",
" len(experiments),\n",
" experiment.name,\n",
" len(recorders),\n",
" len(recorders) + not_finished,\n",
" )\n",
" )\n",
" result = QResult(experiment.name)\n",
" for recorder_id, recorder in recorders.items():\n",
" result.update(recorder.list_metrics(), key_map)\n",
" result.append_path(\n",
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
" )\n",
" if not len(result):\n",
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
" continue\n",
" else:\n",
" if verbose:\n",
" print(\n",
" \"There are {:} valid recorders for {:}\".format(\n",
" len(recorders), experiment.name\n",
" )\n",
" )\n",
" qresults.append(result)\n",
" return qresults"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
]
}
],
"source": [
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
"paths = [path.resolve() for path in paths]\n",
"print(paths)\n",
"\n",
"key_map = dict()\n",
"for xset in (\"train\", \"valid\", \"test\"):\n",
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
"\n",
"qresults = query_info(paths, False, 'TSF-.*', key_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_dropouts(qresults, basenames, name2suffix, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" name2qresult = dict()\n",
" for qresult in qresults:\n",
" name2qresult[qresult.name] = qresult\n",
" # sort architectures\n",
" accuracies = []\n",
" for basename in basenames:\n",
" qresult = name2qresult[basename + '-drop0_0']\n",
" accuracies.append(qresult['ICIR (train)'])\n",
" sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n",
" \n",
" dpi, width, height = 200, 4000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 22\n",
" font_gap = 5\n",
" colors = ['k', 'r']\n",
" markers = ['*', 'o']\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" def plot_ax(cur_ax, train_or_test):\n",
" for idx, (legend, suffix) in enumerate(name2suffix.items()):\n",
" x_values = list(range(len(sorted_basenames)))\n",
" y_values = []\n",
" for i, name in enumerate(sorted_basenames):\n",
" name = '{:}{:}'.format(name, suffix)\n",
" qresult = name2qresult[name]\n",
" if train_or_test:\n",
" value = qresult['IC (train)']\n",
" else:\n",
" value = qresult['IC (valid)']\n",
" y_values.append(value)\n",
" cur_ax.plot(x_values, y_values, c=colors[idx])\n",
" cur_ax.scatter(x_values, y_values,\n",
" marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
" label=legend)\n",
" cur_ax.set_yticks(np.arange(4, 11, 2))\n",
" cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" cur_ax.legend(loc=4, fontsize=LegendFontsize)\n",
" ax = fig.add_subplot(1, 2, 1)\n",
" plot_ax(ax, True)\n",
" ax = fig.add_subplot(1, 2, 2)\n",
" plot_ax(ax, False)\n",
" # fig.tight_layout()\n",
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n",
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 104 qlib-results\n"
]
}
],
"source": [
"# Visualization\n",
"names = [qresult.name for qresult in qresults]\n",
"base_names = set()\n",
"for name in names:\n",
" base_name = name.split('-drop')[0]\n",
" base_names.add(base_name)\n",
"print(base_names)\n",
"# filter\n",
"filtered_base_names = set()\n",
"for base_name in base_names:\n",
" if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n",
" filtered_base_names.add(base_name)\n",
" else:\n",
" print('Cannot find all names for {:}'.format(base_name))\n",
"# print(filtered_base_names)\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_dropouts(qresults, list(filtered_base_names),\n",
" {'No-dropout': '-drop0_0',\n",
" 'Ratio=0.1' : '-drop0.1_0'},\n",
" desktop_dir / 'es_csi300_drop.pdf')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,208 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afraid-minutes",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"import os\n",
"import re\n",
"import sys\n",
"import torch\n",
"import pprint\n",
"import numpy as np\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"from scipy.interpolate import make_interp_spline\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"from utils.qlib_utils import QResult"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "continental-drain",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TSF-2x24-drop0_0s2013-01-01\n",
"TSF-2x24-drop0_0s2012-01-01\n",
"TSF-2x24-drop0_0s2008-01-01\n",
"TSF-2x24-drop0_0s2009-01-01\n",
"TSF-2x24-drop0_0s2010-01-01\n",
"TSF-2x24-drop0_0s2011-01-01\n",
"TSF-2x24-drop0_0s2008-07-01\n",
"TSF-2x24-drop0_0s2009-07-01\n",
"There are 3011 dates\n",
"Dates: 2008-01-02 2008-01-03\n"
]
}
],
"source": [
"qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n",
"for qresult in qresults:\n",
" print(qresult.name)\n",
"all_dates = set()\n",
"for qresult in qresults:\n",
" dates = qresult.find_all_dates()\n",
" for date in dates:\n",
" all_dates.add(date)\n",
"all_dates = sorted(list(all_dates))\n",
"print('There are {:} dates'.format(len(all_dates)))\n",
"print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "intimate-approval",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "supreme-basis",
"metadata": {},
"outputs": [],
"source": [
"def vis_time_curve(qresults, dates, use_original, save_path):\n",
" save_dir = (save_path / '..').resolve()\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" print('There are {:} qlib-results'.format(len(qresults)))\n",
" \n",
" dpi, width, height = 200, 5000, 2000\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize = 22, 12\n",
" font_gap = 5\n",
" linestyles = ['-', '--']\n",
" colors = ['k', 'r']\n",
" \n",
" fig = plt.figure(figsize=figsize)\n",
" cur_ax = fig.add_subplot(1, 1, 1)\n",
" for idx, qresult in enumerate(qresults):\n",
" print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n",
" x_axis, y_axis = [], []\n",
" for idate, date in enumerate(dates):\n",
" if date in qresult._date2ICs[-1]:\n",
" mean, std = qresult.get_IC_by_date(date, 100)\n",
" if not np.isnan(mean):\n",
" x_axis.append(idate)\n",
" y_axis.append(mean)\n",
" x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n",
" if use_original:\n",
" cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n",
" else:\n",
" xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n",
" spl = make_interp_spline(x_axis, y_axis, k=5)\n",
" ynew = spl(xnew)\n",
" cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n",
" \n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n",
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" plt.close(\"all\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "shared-envelope",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Desktop is at: /Users/xuanyidong/Desktop\n",
"There are 2 qlib-results\n",
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n",
"There are 2 qlib-results\n",
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n"
]
}
],
"source": [
"# Visualization\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
"\n",
"vis_time_curve(\n",
" (qresults[2], qresults[-1]),\n",
" all_dates,\n",
" True,\n",
" desktop_dir / 'es_csi300_time_curve.pdf')\n",
"\n",
"vis_time_curve(\n",
" (qresults[2], qresults[-1]),\n",
" all_dates,\n",
" False,\n",
" desktop_dir / 'es_csi300_time_curve-inter.pdf')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "exempt-stable",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,129 @@
import os
import re
import sys
import torch
import qlib
import pprint
from collections import OrderedDict
import numpy as np
import pandas as pd
from pathlib import Path
# __file__ = os.path.dirname(os.path.realpath("__file__"))
note_dir = Path(__file__).parent.resolve()
root_dir = (Path(__file__).parent / ".." / "..").resolve()
lib_dir = (root_dir / "lib").resolve()
print("The root path: {:}".format(root_dir))
print("The library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
import qlib
from qlib import config as qconfig
from qlib.workflow import R
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
from utils.qlib_utils import QResult
def filter_finished(recorders):
returned_recorders = dict()
not_finished = 0
for key, recorder in recorders.items():
if recorder.status == "FINISHED":
returned_recorders[key] = recorder
else:
not_finished += 1
return returned_recorders, not_finished
def add_to_dict(xdict, timestamp, value):
date = timestamp.date().strftime("%Y-%m-%d")
if date in xdict:
raise ValueError("This date [{:}] is already in the dict".format(date))
xdict[date] = value
def query_info(save_dir, verbose, name_filter, key_map):
if isinstance(save_dir, list):
results = []
for x in save_dir:
x = query_info(x, verbose, name_filter, key_map)
results.extend(x)
return results
# Here, the save_dir must be a string
R.set_uri(str(save_dir))
experiments = R.list_experiments()
if verbose:
print("There are {:} experiments.".format(len(experiments)))
qresults = []
for idx, (key, experiment) in enumerate(experiments.items()):
if experiment.id == "0":
continue
if (
name_filter is not None
and re.fullmatch(name_filter, experiment.name) is None
):
continue
recorders = experiment.list_recorders()
recorders, not_finished = filter_finished(recorders)
if verbose:
print(
"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
idx + 1,
len(experiments),
experiment.name,
len(recorders),
len(recorders) + not_finished,
)
)
result = QResult(experiment.name)
for recorder_id, recorder in recorders.items():
file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"]
date2IC = OrderedDict()
for file_name in file_names:
xtemp = recorder.load_object(file_name)["all-IC"]
timestamps, values = xtemp.index.tolist(), xtemp.tolist()
for timestamp, value in zip(timestamps, values):
add_to_dict(date2IC, timestamp, value)
result.update(recorder.list_metrics(), key_map)
result.append_path(
os.path.join(recorder.uri, recorder.experiment_id, recorder.id)
)
result.append_date2ICs(date2IC)
if not len(result):
print("There are no valid recorders for {:}".format(experiment))
continue
else:
if verbose:
print(
"There are {:} valid recorders for {:}".format(
len(recorders), experiment.name
)
)
qresults.append(result)
return qresults
##
paths = [root_dir / "outputs" / "qlib-baselines-csi300"]
paths = [path.resolve() for path in paths]
print(paths)
key_map = dict()
for xset in ("train", "valid", "test"):
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
print("Find {:} results".format(len(qresults)))
times = []
for qresult in qresults:
times.append(qresult.name.split("0_0s")[-1])
print(times)
save_path = os.path.join(note_dir, "temp-time-x.pth")
torch.save(qresults, save_path)
print(save_path)

View File

@@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n"
]
}
],
"source": [
"#####################################################\n",
"# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n",
"#####################################################\n",
"import abc, os, sys\n",
"from pathlib import Path\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"\n",
"lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n",
"print(\"library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.7.0\n",
"True\n",
"OrderedDict()\n",
"OrderedDict()\n",
"set()\n",
"OrderedDict()\n",
"OrderedDict()\n",
"OrderedDict()\n",
"OrderedDict()\n",
"OrderedDict()\n",
"OrderedDict()\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py:551: UserWarning: Setting attributes on ParameterDict is not supported.\n",
" warnings.warn(\"Setting attributes on ParameterDict is not supported.\")\n"
]
}
],
"source": [
"# Test the Linear layer\n",
"import spaces\n",
"import torch\n",
"from xlayers import super_core\n",
"\n",
"print(torch.__version__)\n",
"mlp = super_core.SuperMLPv2(10, 12, 32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "filled-multiple",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
]
}
],
"source": [
"import os, sys\n",
"import torch\n",
"from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib\n",
"from matplotlib import cm\n",
"matplotlib.use(\"agg\")\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"\n",
"\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
"lib_dir = (root_dir / \"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"if str(lib_dir) not in sys.path:\n",
" sys.path.insert(0, str(lib_dir))\n",
"\n",
"from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n",
"from datasets import DynamicQuadraticFunc\n",
"from datasets.synthetic_example import create_example_v1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "detected-second",
"metadata": {},
"outputs": [],
"source": [
"def draw_fig(save_dir, timestamp, xaxis, yaxis):\n",
" save_path = save_dir / '{:04d}'.format(timestamp)\n",
" # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n",
" dpi, width, height = 40, 1500, 1500\n",
" figsize = width / float(dpi), height / float(dpi)\n",
" LabelSize, LegendFontsize, font_gap = 80, 80, 5\n",
"\n",
" fig = plt.figure(figsize=figsize)\n",
" \n",
" cur_ax = fig.add_subplot(1, 1, 1)\n",
" cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n",
" cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n",
" cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n",
" cur_ax.set_xlim(-6, 6)\n",
" cur_ax.set_ylim(-40, 40)\n",
" for tick in cur_ax.xaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" tick.label.set_rotation(10)\n",
" for tick in cur_ax.yaxis.get_major_ticks():\n",
" tick.label.set_fontsize(LabelSize - font_gap)\n",
" \n",
" plt.legend(loc=1, fontsize=LegendFontsize)\n",
" fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
" fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n",
" plt.close(\"all\")\n",
"\n",
"\n",
"def visualize_env(save_dir):\n",
" save_dir.mkdir(parents=True, exist_ok=True)\n",
" dynamic_env, function = create_example_v1(100, num_per_task=500)\n",
" \n",
" additional_xaxis = np.arange(-6, 6, 0.1)\n",
" for timestamp, dataset in dynamic_env:\n",
" num = dataset.shape[0]\n",
" # timeaxis = (torch.zeros(num) + timestamp).numpy()\n",
" xaxis = dataset[:,0].numpy()\n",
" xaxis = np.concatenate((additional_xaxis, xaxis))\n",
" # compute the ground truth\n",
" function.set_timestamp(timestamp)\n",
" yaxis = function(xaxis)\n",
" draw_fig(save_dir, timestamp, xaxis, yaxis)\n",
"\n",
"home_dir = Path.home()\n",
"desktop_dir = home_dir / 'Desktop'\n",
"vis_save_dir = desktop_dir / 'vis-synthetic'\n",
"visualize_env(vis_save_dir)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "rapid-uruguay",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ffmpeg -y -i /Users/xuanyidong/Desktop/vis-synthetic/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k /Users/xuanyidong/Desktop/vis-synthetic/vis.mp4\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Plot the data\n",
"cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n",
"print(cmd)\n",
"os.system(cmd)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,277 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3f754c96",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from xautodl import spaces\n",
"from xautodl.xlayers import super_core\n",
"\n",
"def _create_stel(input_dim, output_dim, order):\n",
" return super_core.SuperSequential(\n",
" super_core.SuperLinear(input_dim, output_dim),\n",
" super_core.SuperTransformerEncoderLayer(\n",
" output_dim,\n",
" num_heads=spaces.Categorical(2, 4, 6),\n",
" mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),\n",
" order=order,\n",
" ),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "81d42f4b",
"metadata": {},
"outputs": [],
"source": [
"batch, seq_dim, input_dim = 1, 4, 6\n",
"order = super_core.LayerOrder.PreNorm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8056b37c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SuperSequential(\n",
" (0): SuperSequential(\n",
" (0): SuperLinear(in_features=6, out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=True)\n",
" (1): SuperTransformerEncoderLayer(\n",
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mha): SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[12, 24, 36], default_index=None), proj_dim=Categorical(candidates=[12, 24, 36], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
" )\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mlp): SuperMLPv2(\n",
" in_features=Categorical(candidates=[12, 24, 36], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
" (_params): ParameterDict(\n",
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 144x36]\n",
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 144]\n",
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 36x144]\n",
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 36]\n",
" )\n",
" (act): GELU()\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): SuperSequential(\n",
" (0): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=True)\n",
" (1): SuperTransformerEncoderLayer(\n",
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mha): SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[24, 36, 48], default_index=None), proj_dim=Categorical(candidates=[24, 36, 48], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
" )\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mlp): SuperMLPv2(\n",
" in_features=Categorical(candidates=[24, 36, 48], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
" (_params): ParameterDict(\n",
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 192x48]\n",
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 192]\n",
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 48x192]\n",
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 48]\n",
" )\n",
" (act): GELU()\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): SuperSequential(\n",
" (0): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=True)\n",
" (1): SuperTransformerEncoderLayer(\n",
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mha): SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
" )\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mlp): SuperMLPv2(\n",
" in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
" (_params): ParameterDict(\n",
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n",
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n",
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n",
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n",
" )\n",
" (act): GELU()\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" )\n",
")\n"
]
}
],
"source": [
"out1_dim = spaces.Categorical(12, 24, 36)\n",
"out2_dim = spaces.Categorical(24, 36, 48)\n",
"out3_dim = spaces.Categorical(36, 72, 100)\n",
"layer1 = _create_stel(input_dim, out1_dim, order)\n",
"layer2 = _create_stel(out1_dim, out2_dim, order)\n",
"layer3 = _create_stel(out2_dim, out3_dim, order)\n",
"model = super_core.SuperSequential(layer1, layer2, layer3)\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fd53a7c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"> \u001b[0;32m/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/xautodl-0.9.9-py3.8.egg/xautodl/xlayers/super_transformer.py\u001b[0m(116)\u001b[0;36mforward_raw\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 114 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 115 \u001b[0;31m \u001b[0;31m# feed-forward layer -- MLP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m--> 116 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 117 \u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 118 \u001b[0;31m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_order\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mLayerOrder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPostNorm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n",
"ipdb> print(self)\n",
"SuperTransformerEncoderLayer(\n",
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mha): SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
" )\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
" (mlp): SuperMLPv2(\n",
" in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
" (_params): ParameterDict(\n",
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n",
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n",
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n",
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n",
" )\n",
" (act): GELU()\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
")\n",
"ipdb> print(inputs.shape)\n",
"torch.Size([1, 4, 100])\n",
"ipdb> print(x.shape)\n",
"torch.Size([1, 4, 96])\n",
"ipdb> print(self.mha)\n",
"SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
")\n",
"ipdb> print(self.mha.candidate)\n",
"*** AttributeError: 'SuperSelfAttention' object has no attribute 'candidate'\n",
"ipdb> print(self.mha.abstract_candidate)\n",
"*** AttributeError: 'SuperSelfAttention' object has no attribute 'abstract_candidate'\n",
"ipdb> print(self.mha._abstract_child)\n",
"None\n",
"ipdb> print(self.abstract_child)\n",
"None\n",
"ipdb> print(self.abstract_child.abstract_child)\n",
"*** AttributeError: 'NoneType' object has no attribute 'abstract_child'\n",
"ipdb> print(self.mha)\n",
"SuperSelfAttention(\n",
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
")\n"
]
}
],
"source": [
"inputs = torch.rand(batch, seq_dim, input_dim)\n",
"outputs = model(inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05332b98",
"metadata": {},
"outputs": [],
"source": [
"abstract_space = model.abstract_search_space\n",
"abstract_space.clean_last()\n",
"abstract_child = abstract_space.random(reuse_last=True)\n",
"# print(\"The abstract child program is:\\n{:}\".format(abstract_child))\n",
"model.enable_candidate()\n",
"model.set_super_run_type(super_core.SuperRunMode.Candidate)\n",
"model.apply_candidate(abstract_child)\n",
"outputs = model(inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3289f938",
"metadata": {},
"outputs": [],
"source": [
"print(outputs.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36951cdf",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}