upload
This commit is contained in:
5
zero-cost-nas/notebooks/README.md
Normal file
5
zero-cost-nas/notebooks/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
## Notebooks
|
||||
|
||||
To run these notebooks, you need to compute the zero-cost metrics for each dataset.
|
||||
Alternatively, you can download precomputed results from [here](https://drive.google.com/drive/folders/1fUBaTd05OHrKIRs-x9Fx8Zsk5QqErks8?usp=sharing) and save to the root folder of this repo.
|
||||
You will also need the [`data` directory](https://drive.google.com/drive/folders/18Eia6YuTE5tn5Lis_43h30HYpnF9Ynqf?usp=sharing), which should also be saved to the root folder of the repo.
|
387
zero-cost-nas/notebooks/nas_examples.ipynb
Normal file
387
zero-cost-nas/notebooks/nas_examples.ipynb
Normal file
File diff suppressed because one or more lines are too long
153
zero-cost-nas/notebooks/nasbench101_correlations.ipynb
Normal file
153
zero-cost-nas/notebooks/nasbench101_correlations.ipynb
Normal file
@@ -0,0 +1,153 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os, pickle, sys\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from scipy import stats\n",
|
||||
"import numpy as np\n",
|
||||
"import glob\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"from prettytable import PrettyTable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 96/96 [00:03<00:00, 30.17it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"d = '../results_release/nasbench1/proxies'\n",
|
||||
"runs = []\n",
|
||||
"processed = set()\n",
|
||||
"\n",
|
||||
"for f in tqdm(os.listdir(d)):\n",
|
||||
" pf = open(os.path.join(d,f),'rb')\n",
|
||||
" while 1:\n",
|
||||
" try:\n",
|
||||
" p = pickle.load(pf)\n",
|
||||
" if p['hash'] in processed:\n",
|
||||
" continue\n",
|
||||
" processed.add(p['hash'])\n",
|
||||
" runs.append(p)\n",
|
||||
" except EOFError:\n",
|
||||
" break\n",
|
||||
" pf.close()\n",
|
||||
"with open('../data/nasbench1_accuracy.p','rb') as f:\n",
|
||||
" all_accur = pickle.load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"423624 423624\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(runs),len(all_accur))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"../results_release/nasbench1/proxies 423624\n",
|
||||
"+---------+-----------+-------+-------+--------+---------+-----------+\n",
|
||||
"| Dataset | grad_norm | snip | grasp | fisher | synflow | jacob_cov |\n",
|
||||
"+---------+-----------+-------+-------+--------+---------+-----------+\n",
|
||||
"| CIFAR10 | 0.198 | 0.164 | 0.448 | 0.257 | 0.372 | 0.378 |\n",
|
||||
"+---------+-----------+-------+-------+--------+---------+-----------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"t=None\n",
|
||||
"\n",
|
||||
"print(d, len(runs))\n",
|
||||
"metrics={}\n",
|
||||
"for k in runs[0]['logmeasures'].keys():\n",
|
||||
" metrics[k] = []\n",
|
||||
"acc = []\n",
|
||||
"hashes = []\n",
|
||||
"\n",
|
||||
"if t is None:\n",
|
||||
" hl=['Dataset']\n",
|
||||
" hl.extend(['grad_norm', 'snip', 'grasp', 'fisher', 'synflow', 'jacob_cov'])\n",
|
||||
" t = PrettyTable(hl)\n",
|
||||
"\n",
|
||||
"for r in runs:\n",
|
||||
" for k,v in r['logmeasures'].items():\n",
|
||||
" metrics[k].append(v)\n",
|
||||
" \n",
|
||||
" acc.append(all_accur[r['hash']][0])\n",
|
||||
" hashes.append(r['hash'])\n",
|
||||
"\n",
|
||||
"res = []\n",
|
||||
"for k in hl:\n",
|
||||
" if k=='Dataset':\n",
|
||||
" continue\n",
|
||||
" v = metrics[k]\n",
|
||||
" cr = abs(stats.spearmanr(acc,v,nan_policy='omit').correlation)\n",
|
||||
" #print(f'{k} = {cr}')\n",
|
||||
" res.append(round(cr,3))\n",
|
||||
"\n",
|
||||
"ds = 'CIFAR10'\n",
|
||||
"t.add_row([ds]+res)\n",
|
||||
"\n",
|
||||
"print(t)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
749
zero-cost-nas/notebooks/nasbench201_correlations.ipynb
Normal file
749
zero-cost-nas/notebooks/nasbench201_correlations.ipynb
Normal file
File diff suppressed because one or more lines are too long
362
zero-cost-nas/notebooks/ptcv_correlations.ipynb
Normal file
362
zero-cost-nas/notebooks/ptcv_correlations.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user