add autodl
This commit is contained in:
118
AutoDL-Projects/notebooks/NATS-Bench/BayesOpt.ipynb
Normal file
118
AutoDL-Projects/notebooks/NATS-Bench/BayesOpt.ipynb
Normal file
@@ -0,0 +1,118 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "german-madonna",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Implementation for \"A Tutorial on Bayesian Optimization\"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"def get_data():\n",
|
||||
" return np.random.random(2) * 10\n",
|
||||
"\n",
|
||||
"def f(x):\n",
|
||||
" return float(np.power((x[0] * 3 - x[1]), 3) - np.exp(x[1]) + np.power(x[0], 2))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "broke-citizenship",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Kernels typically have the property that points closer in the input space are more strongly correlated\n",
|
||||
"# i.e., if |x1 - x2| < |x1 - x3|, then sigma(x1, x2) > sigma(x1, x3).\n",
|
||||
"# the commonly used and simple kernel is the power exponential or Gaussian kernel:\n",
|
||||
"def sigma0(x1, x2, alpha0=1, alpha=[1,1]):\n",
|
||||
" \"\"\"alpha could be a vector\"\"\"\n",
|
||||
" power = np.array(alpha, dtype=np.float32) * np.power(np.array(x1)-np.array(x2), 2)\n",
|
||||
" return alpha0 * np.exp( -np.sum(power) )\n",
|
||||
"\n",
|
||||
"# the most common choice for the mean function is a constant value\n",
|
||||
"def mu0(x, mu):\n",
|
||||
" return mu"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "aerial-carnival",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"K = 5\n",
|
||||
"X = np.array([get_data() for i in range(K)])\n",
|
||||
"mu = np.mean(X, axis=0)\n",
|
||||
"mu0_over_K = [mu0(x, mu) for x in X]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "polished-discussion",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sigma0_over_KK = []\n",
|
||||
"for i in range(K):\n",
|
||||
" sigma0_over_KK.append(np.array([sigma0(X[i], X[j]) for j in range(K)]))\n",
|
||||
"sigma0_over_KK = np.array(sigma0_over_KK)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "comic-jesus",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(20, 20)\n",
|
||||
"1.1038803861344952e-06\n",
|
||||
"1.1038803861344952e-06\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sigma0_over_KK.shape)\n",
|
||||
"print(sigma0_over_KK[1][2])\n",
|
||||
"print(sigma0_over_KK[2][1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "statistical-wrist",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
88
AutoDL-Projects/notebooks/NATS-Bench/find-largest.ipynb
Normal file
88
AutoDL-Projects/notebooks/NATS-Bench/find-largest.ipynb
Normal file
@@ -0,0 +1,88 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2021-03-27 06:46:38] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from nats_bench import create\n",
|
||||
"from pprint import pprint\n",
|
||||
"# Create the API for tologoy search space\n",
|
||||
"api = create(None, 'tss', fast_mode=True, verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'test-accuracy': 22.39999992879232,\n",
|
||||
" 'test-all-time': 7.7054752962929856,\n",
|
||||
" 'test-loss': 3.1626377182006835,\n",
|
||||
" 'test-per-time': 0.6421229413577488,\n",
|
||||
" 'train-accuracy': 21.68885959195242,\n",
|
||||
" 'train-all-time': 1260.0195466594694,\n",
|
||||
" 'train-loss': 3.1863493608815463,\n",
|
||||
" 'train-per-time': 105.00162888828912,\n",
|
||||
" 'valid-accuracy': 23.266666631062826,\n",
|
||||
" 'valid-all-time': 7.7054752962929856,\n",
|
||||
" 'valid-loss': 3.1219845104217527,\n",
|
||||
" 'valid-per-time': 0.6421229413577488,\n",
|
||||
" 'valtest-accuracy': 22.833333323160808,\n",
|
||||
" 'valtest-all-time': 15.410950592585971,\n",
|
||||
" 'valtest-loss': 3.142311067581177,\n",
|
||||
" 'valtest-per-time': 1.2842458827154977}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'\n",
|
||||
"\n",
|
||||
"arch_index = api.query_index_by_arch(largest_candidate_tss)\n",
|
||||
"info = api.get_more_info(arch_index, 'ImageNet16-120', hp='12', is_random=False)\n",
|
||||
"pprint(info)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
91
AutoDL-Projects/notebooks/NATS-Bench/issue-96.ipynb
Normal file
91
AutoDL-Projects/notebooks/NATS-Bench/issue-96.ipynb
Normal file
@@ -0,0 +1,91 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2021-03-01 12:28:12] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from nats_bench import create\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"def get_correlation(A, B):\n",
|
||||
" return float(np.corrcoef(A, B)[0,1])\n",
|
||||
"\n",
|
||||
"# Create the API for tologoy search space\n",
|
||||
"api = create(None, 'tss', fast_mode=True, verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"There are 15625 architectures on the topology search space\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print('There are {:} architectures on the topology search space'.format(len(api)))\n",
|
||||
"accuracies_12, accuracies_200 = [], []\n",
|
||||
"for i, arch in enumerate(api):\n",
|
||||
" info_a = api.get_more_info(i, dataset='cifar10-valid', hp='12', is_random=False)\n",
|
||||
" accuracies_12.append(info_a['valid-accuracy'])\n",
|
||||
"\n",
|
||||
" info_b = api.get_more_info(i, dataset='cifar10-valid', hp='200', is_random=False)\n",
|
||||
" accuracies_200.append(info_b['test-accuracy'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: 91.18%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"correlation = get_correlation(accuracies_12, accuracies_200)\n",
|
||||
"print('[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: {:.2f}%'.format(correlation * 100))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
86
AutoDL-Projects/notebooks/NATS-Bench/issue-97.ipynb
Normal file
86
AutoDL-Projects/notebooks/NATS-Bench/issue-97.ipynb
Normal file
@@ -0,0 +1,86 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2021-03-09 08:44:19] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from nats_bench import create\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"# Create the API for size search space\n",
|
||||
"api = create(None, 'sss', fast_mode=True, verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"There are 32768 architectures on the size search space\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print('There are {:} architectures on the size search space'.format(len(api)))\n",
|
||||
"\n",
|
||||
"c2acc = dict()\n",
|
||||
"for index in range(len(api)):\n",
|
||||
" info = api.get_more_info(index, 'cifar10', hp='90')\n",
|
||||
" config = api.get_net_config(index, 'cifar10')\n",
|
||||
" c2acc[config['channels']] = info['test-accuracy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"91.08546417236329\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(np.mean(list(c2acc.values())))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
274
AutoDL-Projects/notebooks/Q/qlib-data-play.ipynb
Normal file
274
AutoDL-Projects/notebooks/Q/qlib-data-play.ipynb
Normal file
@@ -0,0 +1,274 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[82189:MainThread](2021-03-02 21:02:54,241) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||
"[82189:MainThread](2021-03-02 21:02:54,255) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[82189:MainThread](2021-03-02 21:02:54,828) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||
"[82189:MainThread](2021-03-02 21:02:54,829) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import qlib\n",
|
||||
"import pprint\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data')\n",
|
||||
"\n",
|
||||
"from qlib.config import C\n",
|
||||
"from qlib.data import D\n",
|
||||
"from qlib.data.data import DatasetD, ExpressionD, Inst, Cal, FeatureD\n",
|
||||
"from qlib.data.cache import H\n",
|
||||
"from qlib.data.filter import NameDFilter\n",
|
||||
"from qlib.utils import code_to_fname, read_bin"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')\n",
|
||||
"instruments_config = D.instruments(market='csi300', filter_pipe=[nameDFilter])\n",
|
||||
"instruments = D.list_instruments(instruments=instruments_config,\n",
|
||||
" start_time='2015-01-01',\n",
|
||||
" end_time='2016-02-15',\n",
|
||||
" as_list=True)\n",
|
||||
"\n",
|
||||
"fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n",
|
||||
"features = D.features(instruments_config, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day')\n",
|
||||
"print(type(features))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" $close $volume Ref($close, 1) Mean($close, 3) \\\n",
|
||||
"instrument datetime \n",
|
||||
"SH600655 2010-01-04 8.934296 47799352.0 8.667867 8.691138 \n",
|
||||
" 2010-01-05 8.889880 29791234.0 8.934296 8.830681 \n",
|
||||
" 2010-01-06 8.845468 29002874.0 8.889880 8.889881 \n",
|
||||
" 2010-01-07 8.553690 38189440.0 8.845468 8.763013 \n",
|
||||
" 2010-01-08 8.645658 23417642.0 8.553690 8.681605 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"SH601555 2017-12-25 1.393481 80615584.0 1.406559 1.408012 \n",
|
||||
" 2017-12-26 1.406559 64259856.0 1.393481 1.402200 \n",
|
||||
" 2017-12-27 1.400747 58551256.0 1.406559 1.400262 \n",
|
||||
" 2017-12-28 1.412371 96204872.0 1.400747 1.406559 \n",
|
||||
" 2017-12-29 1.412371 52801024.0 1.412371 1.408496 \n",
|
||||
"\n",
|
||||
" $high-$low \n",
|
||||
"instrument datetime \n",
|
||||
"SH600655 2010-01-04 0.412291 \n",
|
||||
" 2010-01-05 0.203006 \n",
|
||||
" 2010-01-06 0.250560 \n",
|
||||
" 2010-01-07 0.412291 \n",
|
||||
" 2010-01-08 0.275964 \n",
|
||||
"... ... \n",
|
||||
"SH601555 2017-12-25 0.020343 \n",
|
||||
" 2017-12-26 0.018890 \n",
|
||||
" 2017-12-27 0.017437 \n",
|
||||
" 2017-12-28 0.045045 \n",
|
||||
" 2017-12-29 0.013078 \n",
|
||||
"\n",
|
||||
"[2867 rows x 5 columns]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(features)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'qlib.data.data.LocalProvider'>\n",
|
||||
"<class 'qlib.config.QlibConfig'>\n",
|
||||
"LocalProvider\n",
|
||||
"Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n",
|
||||
"Wrapper(provider=<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>)\n",
|
||||
"<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>\n",
|
||||
"LocalDatasetProvider\n",
|
||||
"--\n",
|
||||
"Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n",
|
||||
"<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>\n",
|
||||
"default_disk_cache: 1\n",
|
||||
"ExpressionD: Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n",
|
||||
"FeatureD : <qlib.data.data.LocalFeatureProvider object at 0x7ff55fb84430>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Provider:\n",
|
||||
"print(type(D._provider))\n",
|
||||
"print(type(C))\n",
|
||||
"print(C.provider)\n",
|
||||
"print(D)\n",
|
||||
"\n",
|
||||
"# DatasetD Provider\n",
|
||||
"print(DatasetD)\n",
|
||||
"print(DatasetD._provider)\n",
|
||||
"print(C.dataset_provider)\n",
|
||||
"\n",
|
||||
"print('--')\n",
|
||||
"print(Inst)\n",
|
||||
"print(Inst._provider)\n",
|
||||
"\n",
|
||||
"# Default Disk Cache\n",
|
||||
"print('default_disk_cache: {:}'.format(C.default_disk_cache))\n",
|
||||
"print('ExpressionD: {:}'.format(ExpressionD))\n",
|
||||
"print('FeatureD : {:}'.format(FeatureD._provider))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'pprint' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-4-76544a8bb578>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0minstruments_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDatasetD\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_provider\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_instruments_d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfreq\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'day'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'pprint' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pprint.pprint(instruments_config)\n",
|
||||
"instruments_d = DatasetD._provider.get_instruments_d(instruments_config, freq='day')\n",
|
||||
"pprint.pprint(instruments_d)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2012-12-31 00:00:00 -> 2019-01-18 00:00:00\n",
|
||||
"<PandasArray>\n",
|
||||
"[1.1059314, 1.0935822, 1.1059314, 1.0922102, 1.0839773, 1.0839773, 1.0181155,\n",
|
||||
" 1.0730004, 1.0867218, 1.068884,\n",
|
||||
" ...\n",
|
||||
" 1.1163876, 1.1208236, 1.1119517, 1.0986437, 1.1075157, 1.0971651, 1.1149089,\n",
|
||||
" 1.083857, 1.083857, 1.0956864]\n",
|
||||
"Length: 1439, dtype: float32\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"instrument, field, freq = 'SH601555', '$close', 'day'\n",
|
||||
"all_dates = D.calendar(start_time='2011-12-31', end_time='2019-02-10', freq=freq)\n",
|
||||
"start_time, end_time = all_dates[0], all_dates[-11]\n",
|
||||
"print(str(start_time) + ' -> ' + str(end_time))\n",
|
||||
"obj = ExpressionD.expression(instrument, field, start_time, end_time, freq)\n",
|
||||
"print(obj.array)\n",
|
||||
"\n",
|
||||
"# expression = ExpressionD.get_expression_instance(field)\n",
|
||||
"# start_time = pd.Timestamp(start_time)\n",
|
||||
"# end_time = pd.Timestamp(end_time)\n",
|
||||
"# _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq='day', future=False)\n",
|
||||
"# print(start_index)\n",
|
||||
"# print(end_index)\n",
|
||||
"\n",
|
||||
"# fname = code_to_fname(instrument)\n",
|
||||
"# uri_data = FeatureD._uri_data.format(instrument.lower(), field[1:], freq)\n",
|
||||
"# print(uri_data)\n",
|
||||
"# # series = read_bin(uri_data, start_index, end_index)\n",
|
||||
"# series = read_bin(uri_data, 2850, 2870)\n",
|
||||
"# print(series)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n",
|
||||
"Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n",
|
||||
"Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from qlib.data import D\n",
|
||||
"from qlib.data.data import ExpressionD, Inst\n",
|
||||
"print(D)\n",
|
||||
"print(Inst)\n",
|
||||
"print(ExpressionD)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
162
AutoDL-Projects/notebooks/Q/workflow-test.ipynb
Normal file
162
AutoDL-Projects/notebooks/Q/workflow-test.ipynb
Normal file
@@ -0,0 +1,162 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[61704:MainThread](2021-03-22 13:56:38,104) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||
"[61704:MainThread](2021-03-22 13:56:38,106) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[61704:MainThread](2021-03-22 13:56:38,680) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||
"[61704:MainThread](2021-03-22 13:56:38,681) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'class': 'DatasetH',\n",
|
||||
" 'kwargs': {'handler': {'class': 'Alpha158',\n",
|
||||
" 'kwargs': {'end_time': '2020-08-01',\n",
|
||||
" 'fit_end_time': '2014-12-31',\n",
|
||||
" 'fit_start_time': '2008-01-01',\n",
|
||||
" 'instruments': 'csi100',\n",
|
||||
" 'start_time': '2008-01-01'},\n",
|
||||
" 'module_path': 'qlib.contrib.data.handler'},\n",
|
||||
" 'segments': {'test': ('2017-01-01', '2020-08-01'),\n",
|
||||
" 'train': ('2008-01-01', '2014-12-31'),\n",
|
||||
" 'valid': ('2015-01-01', '2016-12-31')}},\n",
|
||||
" 'module_path': 'qlib.data.dataset'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import qlib\n",
|
||||
"import pprint\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"\n",
|
||||
"lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n",
|
||||
"print(\"library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))\n",
|
||||
"\n",
|
||||
"from qlib import config as qconfig\n",
|
||||
"from qlib.utils import init_instance_by_config\n",
|
||||
"\n",
|
||||
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)\n",
|
||||
"\n",
|
||||
"dataset_config = {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": \"csi100\",\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"pprint.pprint(dataset_config)\n",
|
||||
"dataset = init_instance_by_config(dataset_config)\n",
|
||||
"\n",
|
||||
"df_train, df_valid, df_test = dataset.prepare(\n",
|
||||
" [\"train\", \"valid\", \"test\"],\n",
|
||||
" col_set=[\"feature\", \"label\"],\n",
|
||||
" data_key=DataHandlerLP.DK_L,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'class': 'DatasetH',\n",
|
||||
" 'kwargs': {'handler': {'class': 'Alpha158',\n",
|
||||
" 'kwargs': {'end_time': '2020-08-01',\n",
|
||||
" 'fit_end_time': '2014-12-31',\n",
|
||||
" 'fit_start_time': '2008-01-01',\n",
|
||||
" 'instruments': 'csi300',\n",
|
||||
" 'start_time': '2008-01-01'},\n",
|
||||
" 'module_path': 'qlib.contrib.data.handler'},\n",
|
||||
" 'segments': {'test': ('2017-01-01', '2020-08-01'),\n",
|
||||
" 'train': ('2008-01-01', '2014-12-31'),\n",
|
||||
" 'valid': ('2015-01-01', '2016-12-31')}},\n",
|
||||
" 'module_path': 'qlib.data.dataset'}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[95290:MainThread](2021-03-03 12:18:43,481) INFO - qlib.timer - [log.py:81] - Time cost: 237.911s | Loading data Done\n",
|
||||
"[95290:MainThread](2021-03-03 12:18:45,080) INFO - qlib.timer - [log.py:81] - Time cost: 0.465s | DropnaLabel Done\n",
|
||||
"[95290:MainThread](2021-03-03 12:18:51,572) INFO - qlib.timer - [log.py:81] - Time cost: 6.491s | CSZScoreNorm Done\n",
|
||||
"[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 8.090s | fit & process data Done\n",
|
||||
"[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 246.003s | Init data Done\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from trade_models.transformations import get_transformer\n",
|
||||
"\n",
|
||||
"model = get_transformer(None)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
311
AutoDL-Projects/notebooks/TOT/ES-Model-DC.ipynb
Normal file
311
AutoDL-Projects/notebooks/TOT/ES-Model-DC.ipynb
Normal file
@@ -0,0 +1,311 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "afraid-minutes",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||
"[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||
"[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#\n",
|
||||
"# Exhaustive Search Results\n",
|
||||
"#\n",
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import sys\n",
|
||||
"import qlib\n",
|
||||
"import pprint\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"from qlib import config as qconfig\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "hidden-exemption",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils.qlib_utils import QResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "continental-drain",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def filter_finished(recorders):\n",
|
||||
" returned_recorders = dict()\n",
|
||||
" not_finished = 0\n",
|
||||
" for key, recorder in recorders.items():\n",
|
||||
" if recorder.status == \"FINISHED\":\n",
|
||||
" returned_recorders[key] = recorder\n",
|
||||
" else:\n",
|
||||
" not_finished += 1\n",
|
||||
" return returned_recorders, not_finished\n",
|
||||
"\n",
|
||||
"def query_info(save_dir, verbose, name_filter, key_map):\n",
|
||||
" if isinstance(save_dir, list):\n",
|
||||
" results = []\n",
|
||||
" for x in save_dir:\n",
|
||||
" x = query_info(x, verbose, name_filter, key_map)\n",
|
||||
" results.extend(x)\n",
|
||||
" return results\n",
|
||||
" # Here, the save_dir must be a string\n",
|
||||
" R.set_uri(str(save_dir))\n",
|
||||
" experiments = R.list_experiments()\n",
|
||||
"\n",
|
||||
" if verbose:\n",
|
||||
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
|
||||
" qresults = []\n",
|
||||
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
|
||||
" if experiment.id == \"0\":\n",
|
||||
" continue\n",
|
||||
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
|
||||
" continue\n",
|
||||
" recorders = experiment.list_recorders()\n",
|
||||
" recorders, not_finished = filter_finished(recorders)\n",
|
||||
" if verbose:\n",
|
||||
" print(\n",
|
||||
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
|
||||
" idx + 1,\n",
|
||||
" len(experiments),\n",
|
||||
" experiment.name,\n",
|
||||
" len(recorders),\n",
|
||||
" len(recorders) + not_finished,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" result = QResult(experiment.name)\n",
|
||||
" for recorder_id, recorder in recorders.items():\n",
|
||||
" result.update(recorder.list_metrics(), key_map)\n",
|
||||
" result.append_path(\n",
|
||||
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
|
||||
" )\n",
|
||||
" if not len(result):\n",
|
||||
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
|
||||
" continue\n",
|
||||
" else:\n",
|
||||
" if verbose:\n",
|
||||
" print(\n",
|
||||
" \"There are {:} valid recorders for {:}\".format(\n",
|
||||
" len(recorders), experiment.name\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" qresults.append(result)\n",
|
||||
" return qresults"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "filled-multiple",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
|
||||
"paths = [path.resolve() for path in paths]\n",
|
||||
"print(paths)\n",
|
||||
"\n",
|
||||
"key_map = dict()\n",
|
||||
"for xset in (\"train\", \"valid\", \"test\"):\n",
|
||||
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
|
||||
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
|
||||
"qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "intimate-approval",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib\n",
|
||||
"from matplotlib import cm\n",
|
||||
"matplotlib.use(\"agg\")\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as ticker"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "supreme-basis",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def vis_depth_channel(qresults, save_path):\n",
|
||||
" save_dir = (save_path / '..').resolve()\n",
|
||||
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print('There are {:} qlib-results'.format(len(qresults)))\n",
|
||||
" \n",
|
||||
" dpi, width, height = 200, 4000, 2000\n",
|
||||
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||
" LabelSize, LegendFontsize = 22, 12\n",
|
||||
" font_gap = 5\n",
|
||||
" \n",
|
||||
" fig = plt.figure(figsize=figsize)\n",
|
||||
" # fig, axs = plt.subplots(1, 2, figsize=figsize, projection='3d')\n",
|
||||
" \n",
|
||||
" def plot_ax(cur_ax, train_or_test):\n",
|
||||
" depths, channels = [], []\n",
|
||||
" ic_values, xmaps = [], dict()\n",
|
||||
" for qresult in qresults:\n",
|
||||
" name = qresult.name.split('-')[1]\n",
|
||||
" depths.append(float(name.split('x')[0]))\n",
|
||||
" channels.append(float(name.split('x')[1]))\n",
|
||||
" if train_or_test:\n",
|
||||
" ic_values.append(qresult['IC (train)'])\n",
|
||||
" else:\n",
|
||||
" ic_values.append(qresult['IC (valid)'])\n",
|
||||
" xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n",
|
||||
" # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n",
|
||||
" raw_depths = np.arange(1, 9, dtype=np.int32)\n",
|
||||
" raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n",
|
||||
" depths, channels = np.meshgrid(raw_depths, raw_channels)\n",
|
||||
" ic_values = np.sin(depths) # initialize\n",
|
||||
" # print(ic_values.shape)\n",
|
||||
" num_x, num_y = ic_values.shape\n",
|
||||
" for i in range(num_x):\n",
|
||||
" for j in range(num_y):\n",
|
||||
" xkey = (int(depths[i][j]), int(channels[i][j]))\n",
|
||||
" if xkey not in xmaps:\n",
|
||||
" raise ValueError(\"Did not find {:}\".format(xkey))\n",
|
||||
" ic_values[i][j] = xmaps[xkey]\n",
|
||||
" #print(sorted(list(xmaps.keys())))\n",
|
||||
" #surf = cur_ax.plot_surface(\n",
|
||||
" # np.array(depths), np.array(channels), np.array(ic_values),\n",
|
||||
" # cmap=cm.coolwarm, linewidth=0, antialiased=False)\n",
|
||||
" surf = cur_ax.plot_surface(\n",
|
||||
" depths, channels, ic_values,\n",
|
||||
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
|
||||
" cur_ax.set_xticks(raw_depths)\n",
|
||||
" cur_ax.set_yticks(raw_channels)\n",
|
||||
" cur_ax.set_zticks(np.arange(4, 11, 2))\n",
|
||||
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
|
||||
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
|
||||
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
||||
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" for tick in cur_ax.zaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" # Add a color bar which maps values to colors.\n",
|
||||
"# cax = fig.add_axes([cur_ax.get_position().x1 + 0.01,\n",
|
||||
"# cur_ax.get_position().y0,\n",
|
||||
"# 0.01,\n",
|
||||
"# cur_ax.get_position().height * 0.9])\n",
|
||||
" # fig.colorbar(surf, cax=cax)\n",
|
||||
" # fig.colorbar(surf, shrink=0.5, aspect=5)\n",
|
||||
" # import pdb; pdb.set_trace()\n",
|
||||
" # ax1.legend(loc=4, fontsize=LegendFontsize)\n",
|
||||
" ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
|
||||
" plot_ax(ax, True)\n",
|
||||
" ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
|
||||
" plot_ax(ax, False)\n",
|
||||
" # fig.tight_layout()\n",
|
||||
" plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
|
||||
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||
" plt.close(\"all\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "shared-envelope",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||
"There are 48 qlib-results\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Visualization\n",
|
||||
"home_dir = Path.home()\n",
|
||||
"desktop_dir = home_dir / 'Desktop'\n",
|
||||
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||
"\n",
|
||||
"vis_depth_channel(qresults, desktop_dir / 'es_csi300_d_vs_c.pdf')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
312
AutoDL-Projects/notebooks/TOT/ES-Model-Drop.ipynb
Normal file
312
AutoDL-Projects/notebooks/TOT/ES-Model-Drop.ipynb
Normal file
@@ -0,0 +1,312 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "afraid-minutes",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||
"[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||
"[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#\n",
|
||||
"# Exhaustive Search Results\n",
|
||||
"#\n",
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import sys\n",
|
||||
"import qlib\n",
|
||||
"import pprint\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"from qlib import config as qconfig\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "hidden-exemption",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils.qlib_utils import QResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "continental-drain",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def filter_finished(recorders):\n",
|
||||
" returned_recorders = dict()\n",
|
||||
" not_finished = 0\n",
|
||||
" for key, recorder in recorders.items():\n",
|
||||
" if recorder.status == \"FINISHED\":\n",
|
||||
" returned_recorders[key] = recorder\n",
|
||||
" else:\n",
|
||||
" not_finished += 1\n",
|
||||
" return returned_recorders, not_finished\n",
|
||||
"\n",
|
||||
"def query_info(save_dir, verbose, name_filter, key_map):\n",
|
||||
" if isinstance(save_dir, list):\n",
|
||||
" results = []\n",
|
||||
" for x in save_dir:\n",
|
||||
" x = query_info(x, verbose, name_filter, key_map)\n",
|
||||
" results.extend(x)\n",
|
||||
" return results\n",
|
||||
" # Here, the save_dir must be a string\n",
|
||||
" R.set_uri(str(save_dir))\n",
|
||||
" experiments = R.list_experiments()\n",
|
||||
"\n",
|
||||
" if verbose:\n",
|
||||
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
|
||||
" qresults = []\n",
|
||||
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
|
||||
" if experiment.id == \"0\":\n",
|
||||
" continue\n",
|
||||
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
|
||||
" continue\n",
|
||||
" recorders = experiment.list_recorders()\n",
|
||||
" recorders, not_finished = filter_finished(recorders)\n",
|
||||
" if verbose:\n",
|
||||
" print(\n",
|
||||
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
|
||||
" idx + 1,\n",
|
||||
" len(experiments),\n",
|
||||
" experiment.name,\n",
|
||||
" len(recorders),\n",
|
||||
" len(recorders) + not_finished,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" result = QResult(experiment.name)\n",
|
||||
" for recorder_id, recorder in recorders.items():\n",
|
||||
" result.update(recorder.list_metrics(), key_map)\n",
|
||||
" result.append_path(\n",
|
||||
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
|
||||
" )\n",
|
||||
" if not len(result):\n",
|
||||
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
|
||||
" continue\n",
|
||||
" else:\n",
|
||||
" if verbose:\n",
|
||||
" print(\n",
|
||||
" \"There are {:} valid recorders for {:}\".format(\n",
|
||||
" len(recorders), experiment.name\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" qresults.append(result)\n",
|
||||
" return qresults"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "filled-multiple",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
|
||||
"paths = [path.resolve() for path in paths]\n",
|
||||
"print(paths)\n",
|
||||
"\n",
|
||||
"key_map = dict()\n",
|
||||
"for xset in (\"train\", \"valid\", \"test\"):\n",
|
||||
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
|
||||
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
|
||||
"\n",
|
||||
"qresults = query_info(paths, False, 'TSF-.*', key_map)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "intimate-approval",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib\n",
|
||||
"from matplotlib import cm\n",
|
||||
"matplotlib.use(\"agg\")\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as ticker"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "supreme-basis",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def vis_dropouts(qresults, basenames, name2suffix, save_path):\n",
|
||||
" save_dir = (save_path / '..').resolve()\n",
|
||||
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print('There are {:} qlib-results'.format(len(qresults)))\n",
|
||||
" \n",
|
||||
" name2qresult = dict()\n",
|
||||
" for qresult in qresults:\n",
|
||||
" name2qresult[qresult.name] = qresult\n",
|
||||
" # sort architectures\n",
|
||||
" accuracies = []\n",
|
||||
" for basename in basenames:\n",
|
||||
" qresult = name2qresult[basename + '-drop0_0']\n",
|
||||
" accuracies.append(qresult['ICIR (train)'])\n",
|
||||
" sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n",
|
||||
" \n",
|
||||
" dpi, width, height = 200, 4000, 2000\n",
|
||||
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||
" LabelSize, LegendFontsize = 22, 22\n",
|
||||
" font_gap = 5\n",
|
||||
" colors = ['k', 'r']\n",
|
||||
" markers = ['*', 'o']\n",
|
||||
" \n",
|
||||
" fig = plt.figure(figsize=figsize)\n",
|
||||
" \n",
|
||||
" def plot_ax(cur_ax, train_or_test):\n",
|
||||
" for idx, (legend, suffix) in enumerate(name2suffix.items()):\n",
|
||||
" x_values = list(range(len(sorted_basenames)))\n",
|
||||
" y_values = []\n",
|
||||
" for i, name in enumerate(sorted_basenames):\n",
|
||||
" name = '{:}{:}'.format(name, suffix)\n",
|
||||
" qresult = name2qresult[name]\n",
|
||||
" if train_or_test:\n",
|
||||
" value = qresult['IC (train)']\n",
|
||||
" else:\n",
|
||||
" value = qresult['IC (valid)']\n",
|
||||
" y_values.append(value)\n",
|
||||
" cur_ax.plot(x_values, y_values, c=colors[idx])\n",
|
||||
" cur_ax.scatter(x_values, y_values,\n",
|
||||
" marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
|
||||
" label=legend)\n",
|
||||
" cur_ax.set_yticks(np.arange(4, 11, 2))\n",
|
||||
" cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
|
||||
" cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
||||
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" cur_ax.legend(loc=4, fontsize=LegendFontsize)\n",
|
||||
" ax = fig.add_subplot(1, 2, 1)\n",
|
||||
" plot_ax(ax, True)\n",
|
||||
" ax = fig.add_subplot(1, 2, 2)\n",
|
||||
" plot_ax(ax, False)\n",
|
||||
" # fig.tight_layout()\n",
|
||||
" # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
|
||||
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||
" plt.close(\"all\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "shared-envelope",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n",
|
||||
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||
"There are 104 qlib-results\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Visualization\n",
|
||||
"names = [qresult.name for qresult in qresults]\n",
|
||||
"base_names = set()\n",
|
||||
"for name in names:\n",
|
||||
" base_name = name.split('-drop')[0]\n",
|
||||
" base_names.add(base_name)\n",
|
||||
"print(base_names)\n",
|
||||
"# filter\n",
|
||||
"filtered_base_names = set()\n",
|
||||
"for base_name in base_names:\n",
|
||||
" if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n",
|
||||
" filtered_base_names.add(base_name)\n",
|
||||
" else:\n",
|
||||
" print('Cannot find all names for {:}'.format(base_name))\n",
|
||||
"# print(filtered_base_names)\n",
|
||||
"home_dir = Path.home()\n",
|
||||
"desktop_dir = home_dir / 'Desktop'\n",
|
||||
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||
"\n",
|
||||
"vis_dropouts(qresults, list(filtered_base_names),\n",
|
||||
" {'No-dropout': '-drop0_0',\n",
|
||||
" 'Ratio=0.1' : '-drop0.1_0'},\n",
|
||||
" desktop_dir / 'es_csi300_drop.pdf')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
208
AutoDL-Projects/notebooks/TOT/Time-Curve.ipynb
Normal file
208
AutoDL-Projects/notebooks/TOT/Time-Curve.ipynb
Normal file
@@ -0,0 +1,208 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "afraid-minutes",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"import pprint\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from pathlib import Path\n",
|
||||
"from scipy.interpolate import make_interp_spline\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))\n",
|
||||
"from utils.qlib_utils import QResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "continental-drain",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TSF-2x24-drop0_0s2013-01-01\n",
|
||||
"TSF-2x24-drop0_0s2012-01-01\n",
|
||||
"TSF-2x24-drop0_0s2008-01-01\n",
|
||||
"TSF-2x24-drop0_0s2009-01-01\n",
|
||||
"TSF-2x24-drop0_0s2010-01-01\n",
|
||||
"TSF-2x24-drop0_0s2011-01-01\n",
|
||||
"TSF-2x24-drop0_0s2008-07-01\n",
|
||||
"TSF-2x24-drop0_0s2009-07-01\n",
|
||||
"There are 3011 dates\n",
|
||||
"Dates: 2008-01-02 2008-01-03\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n",
|
||||
"for qresult in qresults:\n",
|
||||
" print(qresult.name)\n",
|
||||
"all_dates = set()\n",
|
||||
"for qresult in qresults:\n",
|
||||
" dates = qresult.find_all_dates()\n",
|
||||
" for date in dates:\n",
|
||||
" all_dates.add(date)\n",
|
||||
"all_dates = sorted(list(all_dates))\n",
|
||||
"print('There are {:} dates'.format(len(all_dates)))\n",
|
||||
"print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "intimate-approval",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib\n",
|
||||
"from matplotlib import cm\n",
|
||||
"matplotlib.use(\"agg\")\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as ticker"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "supreme-basis",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def vis_time_curve(qresults, dates, use_original, save_path):\n",
|
||||
" save_dir = (save_path / '..').resolve()\n",
|
||||
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print('There are {:} qlib-results'.format(len(qresults)))\n",
|
||||
" \n",
|
||||
" dpi, width, height = 200, 5000, 2000\n",
|
||||
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||
" LabelSize, LegendFontsize = 22, 12\n",
|
||||
" font_gap = 5\n",
|
||||
" linestyles = ['-', '--']\n",
|
||||
" colors = ['k', 'r']\n",
|
||||
" \n",
|
||||
" fig = plt.figure(figsize=figsize)\n",
|
||||
" cur_ax = fig.add_subplot(1, 1, 1)\n",
|
||||
" for idx, qresult in enumerate(qresults):\n",
|
||||
" print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n",
|
||||
" x_axis, y_axis = [], []\n",
|
||||
" for idate, date in enumerate(dates):\n",
|
||||
" if date in qresult._date2ICs[-1]:\n",
|
||||
" mean, std = qresult.get_IC_by_date(date, 100)\n",
|
||||
" if not np.isnan(mean):\n",
|
||||
" x_axis.append(idate)\n",
|
||||
" y_axis.append(mean)\n",
|
||||
" x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n",
|
||||
" if use_original:\n",
|
||||
" cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n",
|
||||
" else:\n",
|
||||
" xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n",
|
||||
" spl = make_interp_spline(x_axis, y_axis, k=5)\n",
|
||||
" ynew = spl(xnew)\n",
|
||||
" cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n",
|
||||
" \n",
|
||||
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n",
|
||||
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||
" plt.close(\"all\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "shared-envelope",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||
"There are 2 qlib-results\n",
|
||||
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
||||
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n",
|
||||
"There are 2 qlib-results\n",
|
||||
"Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n",
|
||||
"Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Visualization\n",
|
||||
"home_dir = Path.home()\n",
|
||||
"desktop_dir = home_dir / 'Desktop'\n",
|
||||
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||
"\n",
|
||||
"vis_time_curve(\n",
|
||||
" (qresults[2], qresults[-1]),\n",
|
||||
" all_dates,\n",
|
||||
" True,\n",
|
||||
" desktop_dir / 'es_csi300_time_curve.pdf')\n",
|
||||
"\n",
|
||||
"vis_time_curve(\n",
|
||||
" (qresults[2], qresults[-1]),\n",
|
||||
" all_dates,\n",
|
||||
" False,\n",
|
||||
" desktop_dir / 'es_csi300_time_curve-inter.pdf')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "exempt-stable",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
129
AutoDL-Projects/notebooks/TOT/time-curve.py
Normal file
129
AutoDL-Projects/notebooks/TOT/time-curve.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import torch
|
||||
import qlib
|
||||
import pprint
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# __file__ = os.path.dirname(os.path.realpath("__file__"))
|
||||
note_dir = Path(__file__).parent.resolve()
|
||||
root_dir = (Path(__file__).parent / ".." / "..").resolve()
|
||||
lib_dir = (root_dir / "lib").resolve()
|
||||
print("The root path: {:}".format(root_dir))
|
||||
print("The library path: {:}".format(lib_dir))
|
||||
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
import qlib
|
||||
from qlib import config as qconfig
|
||||
from qlib.workflow import R
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
|
||||
|
||||
from utils.qlib_utils import QResult
|
||||
|
||||
|
||||
def filter_finished(recorders):
|
||||
returned_recorders = dict()
|
||||
not_finished = 0
|
||||
for key, recorder in recorders.items():
|
||||
if recorder.status == "FINISHED":
|
||||
returned_recorders[key] = recorder
|
||||
else:
|
||||
not_finished += 1
|
||||
return returned_recorders, not_finished
|
||||
|
||||
|
||||
def add_to_dict(xdict, timestamp, value):
|
||||
date = timestamp.date().strftime("%Y-%m-%d")
|
||||
if date in xdict:
|
||||
raise ValueError("This date [{:}] is already in the dict".format(date))
|
||||
xdict[date] = value
|
||||
|
||||
|
||||
def query_info(save_dir, verbose, name_filter, key_map):
|
||||
if isinstance(save_dir, list):
|
||||
results = []
|
||||
for x in save_dir:
|
||||
x = query_info(x, verbose, name_filter, key_map)
|
||||
results.extend(x)
|
||||
return results
|
||||
# Here, the save_dir must be a string
|
||||
R.set_uri(str(save_dir))
|
||||
experiments = R.list_experiments()
|
||||
|
||||
if verbose:
|
||||
print("There are {:} experiments.".format(len(experiments)))
|
||||
qresults = []
|
||||
for idx, (key, experiment) in enumerate(experiments.items()):
|
||||
if experiment.id == "0":
|
||||
continue
|
||||
if (
|
||||
name_filter is not None
|
||||
and re.fullmatch(name_filter, experiment.name) is None
|
||||
):
|
||||
continue
|
||||
recorders = experiment.list_recorders()
|
||||
recorders, not_finished = filter_finished(recorders)
|
||||
if verbose:
|
||||
print(
|
||||
"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
|
||||
idx + 1,
|
||||
len(experiments),
|
||||
experiment.name,
|
||||
len(recorders),
|
||||
len(recorders) + not_finished,
|
||||
)
|
||||
)
|
||||
result = QResult(experiment.name)
|
||||
for recorder_id, recorder in recorders.items():
|
||||
file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"]
|
||||
date2IC = OrderedDict()
|
||||
for file_name in file_names:
|
||||
xtemp = recorder.load_object(file_name)["all-IC"]
|
||||
timestamps, values = xtemp.index.tolist(), xtemp.tolist()
|
||||
for timestamp, value in zip(timestamps, values):
|
||||
add_to_dict(date2IC, timestamp, value)
|
||||
result.update(recorder.list_metrics(), key_map)
|
||||
result.append_path(
|
||||
os.path.join(recorder.uri, recorder.experiment_id, recorder.id)
|
||||
)
|
||||
result.append_date2ICs(date2IC)
|
||||
if not len(result):
|
||||
print("There are no valid recorders for {:}".format(experiment))
|
||||
continue
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"There are {:} valid recorders for {:}".format(
|
||||
len(recorders), experiment.name
|
||||
)
|
||||
)
|
||||
qresults.append(result)
|
||||
return qresults
|
||||
|
||||
|
||||
##
|
||||
paths = [root_dir / "outputs" / "qlib-baselines-csi300"]
|
||||
paths = [path.resolve() for path in paths]
|
||||
print(paths)
|
||||
|
||||
key_map = dict()
|
||||
for xset in ("train", "valid", "test"):
|
||||
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
|
||||
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
|
||||
qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
|
||||
print("Find {:} results".format(len(qresults)))
|
||||
times = []
|
||||
for qresult in qresults:
|
||||
times.append(qresult.name.split("0_0s")[-1])
|
||||
print(times)
|
||||
save_path = os.path.join(note_dir, "temp-time-x.pth")
|
||||
torch.save(qresults, save_path)
|
||||
print(save_path)
|
@@ -0,0 +1,102 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#####################################################\n",
|
||||
"# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n",
|
||||
"#####################################################\n",
|
||||
"import abc, os, sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"\n",
|
||||
"lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n",
|
||||
"print(\"library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.7.0\n",
|
||||
"True\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n",
|
||||
"set()\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n",
|
||||
"OrderedDict()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py:551: UserWarning: Setting attributes on ParameterDict is not supported.\n",
|
||||
" warnings.warn(\"Setting attributes on ParameterDict is not supported.\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Test the Linear layer\n",
|
||||
"import spaces\n",
|
||||
"import torch\n",
|
||||
"from xlayers import super_core\n",
|
||||
"\n",
|
||||
"print(torch.__version__)\n",
|
||||
"mlp = super_core.SuperMLPv2(10, 12, 32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
119
AutoDL-Projects/notebooks/spaces-xmisc/scheduler.ipynb
Normal file
119
AutoDL-Projects/notebooks/spaces-xmisc/scheduler.ipynb
Normal file
File diff suppressed because one or more lines are too long
110
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-data.ipynb
Normal file
110
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-data.ipynb
Normal file
File diff suppressed because one or more lines are too long
129
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-env.ipynb
Normal file
129
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-env.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,152 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "filled-multiple",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os, sys\n",
|
||||
"import torch\n",
|
||||
"from pathlib import Path\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib\n",
|
||||
"from matplotlib import cm\n",
|
||||
"matplotlib.use(\"agg\")\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as ticker\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||
"if str(lib_dir) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(lib_dir))\n",
|
||||
"\n",
|
||||
"from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n",
|
||||
"from datasets import DynamicQuadraticFunc\n",
|
||||
"from datasets.synthetic_example import create_example_v1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "detected-second",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def draw_fig(save_dir, timestamp, xaxis, yaxis):\n",
|
||||
" save_path = save_dir / '{:04d}'.format(timestamp)\n",
|
||||
" # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n",
|
||||
" dpi, width, height = 40, 1500, 1500\n",
|
||||
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||
" LabelSize, LegendFontsize, font_gap = 80, 80, 5\n",
|
||||
"\n",
|
||||
" fig = plt.figure(figsize=figsize)\n",
|
||||
" \n",
|
||||
" cur_ax = fig.add_subplot(1, 1, 1)\n",
|
||||
" cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n",
|
||||
" cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n",
|
||||
" cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n",
|
||||
" cur_ax.set_xlim(-6, 6)\n",
|
||||
" cur_ax.set_ylim(-40, 40)\n",
|
||||
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" tick.label.set_rotation(10)\n",
|
||||
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||
" \n",
|
||||
" plt.legend(loc=1, fontsize=LegendFontsize)\n",
|
||||
" fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||
" fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n",
|
||||
" plt.close(\"all\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def visualize_env(save_dir):\n",
|
||||
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" dynamic_env, function = create_example_v1(100, num_per_task=500)\n",
|
||||
" \n",
|
||||
" additional_xaxis = np.arange(-6, 6, 0.1)\n",
|
||||
" for timestamp, dataset in dynamic_env:\n",
|
||||
" num = dataset.shape[0]\n",
|
||||
" # timeaxis = (torch.zeros(num) + timestamp).numpy()\n",
|
||||
" xaxis = dataset[:,0].numpy()\n",
|
||||
" xaxis = np.concatenate((additional_xaxis, xaxis))\n",
|
||||
" # compute the ground truth\n",
|
||||
" function.set_timestamp(timestamp)\n",
|
||||
" yaxis = function(xaxis)\n",
|
||||
" draw_fig(save_dir, timestamp, xaxis, yaxis)\n",
|
||||
"\n",
|
||||
"home_dir = Path.home()\n",
|
||||
"desktop_dir = home_dir / 'Desktop'\n",
|
||||
"vis_save_dir = desktop_dir / 'vis-synthetic'\n",
|
||||
"visualize_env(vis_save_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "rapid-uruguay",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ffmpeg -y -i /Users/xuanyidong/Desktop/vis-synthetic/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k /Users/xuanyidong/Desktop/vis-synthetic/vis.mp4\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Plot the data\n",
|
||||
"cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n",
|
||||
"print(cmd)\n",
|
||||
"os.system(cmd)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@@ -0,0 +1,277 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "3f754c96",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from xautodl import spaces\n",
|
||||
"from xautodl.xlayers import super_core\n",
|
||||
"\n",
|
||||
"def _create_stel(input_dim, output_dim, order):\n",
|
||||
" return super_core.SuperSequential(\n",
|
||||
" super_core.SuperLinear(input_dim, output_dim),\n",
|
||||
" super_core.SuperTransformerEncoderLayer(\n",
|
||||
" output_dim,\n",
|
||||
" num_heads=spaces.Categorical(2, 4, 6),\n",
|
||||
" mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),\n",
|
||||
" order=order,\n",
|
||||
" ),\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "81d42f4b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch, seq_dim, input_dim = 1, 4, 6\n",
|
||||
"order = super_core.LayerOrder.PreNorm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "8056b37c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"SuperSequential(\n",
|
||||
" (0): SuperSequential(\n",
|
||||
" (0): SuperLinear(in_features=6, out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=True)\n",
|
||||
" (1): SuperTransformerEncoderLayer(\n",
|
||||
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mha): SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[12, 24, 36], default_index=None), proj_dim=Categorical(candidates=[12, 24, 36], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
" )\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mlp): SuperMLPv2(\n",
|
||||
" in_features=Categorical(candidates=[12, 24, 36], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
|
||||
" (_params): ParameterDict(\n",
|
||||
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 144x36]\n",
|
||||
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 144]\n",
|
||||
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 36x144]\n",
|
||||
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 36]\n",
|
||||
" )\n",
|
||||
" (act): GELU()\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): SuperSequential(\n",
|
||||
" (0): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=True)\n",
|
||||
" (1): SuperTransformerEncoderLayer(\n",
|
||||
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mha): SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[24, 36, 48], default_index=None), proj_dim=Categorical(candidates=[24, 36, 48], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
" )\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mlp): SuperMLPv2(\n",
|
||||
" in_features=Categorical(candidates=[24, 36, 48], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
|
||||
" (_params): ParameterDict(\n",
|
||||
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 192x48]\n",
|
||||
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 192]\n",
|
||||
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 48x192]\n",
|
||||
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 48]\n",
|
||||
" )\n",
|
||||
" (act): GELU()\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (2): SuperSequential(\n",
|
||||
" (0): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=True)\n",
|
||||
" (1): SuperTransformerEncoderLayer(\n",
|
||||
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mha): SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
" )\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mlp): SuperMLPv2(\n",
|
||||
" in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
|
||||
" (_params): ParameterDict(\n",
|
||||
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n",
|
||||
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n",
|
||||
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n",
|
||||
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n",
|
||||
" )\n",
|
||||
" (act): GELU()\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"out1_dim = spaces.Categorical(12, 24, 36)\n",
|
||||
"out2_dim = spaces.Categorical(24, 36, 48)\n",
|
||||
"out3_dim = spaces.Categorical(36, 72, 100)\n",
|
||||
"layer1 = _create_stel(input_dim, out1_dim, order)\n",
|
||||
"layer2 = _create_stel(out1_dim, out2_dim, order)\n",
|
||||
"layer3 = _create_stel(out2_dim, out3_dim, order)\n",
|
||||
"model = super_core.SuperSequential(layer1, layer2, layer3)\n",
|
||||
"print(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4fd53a7c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> \u001b[0;32m/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/xautodl-0.9.9-py3.8.egg/xautodl/xlayers/super_transformer.py\u001b[0m(116)\u001b[0;36mforward_raw\u001b[0;34m()\u001b[0m\n",
|
||||
"\u001b[0;32m 114 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 115 \u001b[0;31m \u001b[0;31m# feed-forward layer -- MLP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m--> 116 \u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 117 \u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\u001b[0;32m 118 \u001b[0;31m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_order\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mLayerOrder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPostNorm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0m\n",
|
||||
"ipdb> print(self)\n",
|
||||
"SuperTransformerEncoderLayer(\n",
|
||||
" (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mha): SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
" )\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n",
|
||||
" (mlp): SuperMLPv2(\n",
|
||||
" in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n",
|
||||
" (_params): ParameterDict(\n",
|
||||
" (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n",
|
||||
" (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n",
|
||||
" (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n",
|
||||
" (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n",
|
||||
" )\n",
|
||||
" (act): GELU()\n",
|
||||
" (drop): Dropout(p=0.0, inplace=False)\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"ipdb> print(inputs.shape)\n",
|
||||
"torch.Size([1, 4, 100])\n",
|
||||
"ipdb> print(x.shape)\n",
|
||||
"torch.Size([1, 4, 96])\n",
|
||||
"ipdb> print(self.mha)\n",
|
||||
"SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
")\n",
|
||||
"ipdb> print(self.mha.candidate)\n",
|
||||
"*** AttributeError: 'SuperSelfAttention' object has no attribute 'candidate'\n",
|
||||
"ipdb> print(self.mha.abstract_candidate)\n",
|
||||
"*** AttributeError: 'SuperSelfAttention' object has no attribute 'abstract_candidate'\n",
|
||||
"ipdb> print(self.mha._abstract_child)\n",
|
||||
"None\n",
|
||||
"ipdb> print(self.abstract_child)\n",
|
||||
"None\n",
|
||||
"ipdb> print(self.abstract_child.abstract_child)\n",
|
||||
"*** AttributeError: 'NoneType' object has no attribute 'abstract_child'\n",
|
||||
"ipdb> print(self.mha)\n",
|
||||
"SuperSelfAttention(\n",
|
||||
" input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n",
|
||||
" (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n",
|
||||
" (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n",
|
||||
")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"inputs = torch.rand(batch, seq_dim, input_dim)\n",
|
||||
"outputs = model(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "05332b98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"abstract_space = model.abstract_search_space\n",
|
||||
"abstract_space.clean_last()\n",
|
||||
"abstract_child = abstract_space.random(reuse_last=True)\n",
|
||||
"# print(\"The abstract child program is:\\n{:}\".format(abstract_child))\n",
|
||||
"model.enable_candidate()\n",
|
||||
"model.set_super_run_type(super_core.SuperRunMode.Candidate)\n",
|
||||
"model.apply_candidate(abstract_child)\n",
|
||||
"outputs = model(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3289f938",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(outputs.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "36951cdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Reference in New Issue
Block a user