Compare commits: d2cef525f3 ... bb33ca9a68 (10 commits)

Commits:
- bb33ca9a68
- f46486e21b
- 5908a1edef
- ed34024a88
- 5bf036a763
- b557a22928
- f549ed2e61
- 5a5cb82537
- 676e8e411d
- 8d0799dfb1
.github/workflows/test-basic.yaml (vendored, 2 changed lines)

@@ -41,7 +41,7 @@ jobs:
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Search Space
         run: |
.github/workflows/test-misc.yaml (vendored, 2 changed lines)

@@ -26,7 +26,7 @@ jobs:
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Xmisc
         run: |
@@ -26,7 +26,7 @@ jobs:
       - name: Install XAutoDL from source
         run: |
-          python setup.py install
+          pip install .
 
       - name: Test Super Model
         run: |
@@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru
   <tr> <!-- (6-th row) -->
   <td align="center" valign="middle"> NATS-Bench </td>
   <td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
-  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
   </tr>
   <tr> <!-- (7-th row) -->
   <td align="center" valign="middle"> ... </td>
   <td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
   <td align="center" valign="middle"> Please check the original papers </td>
-  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
+  <td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
   </tr>
   <tr> <!-- (start second block) -->
   <td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
@@ -89,7 +89,7 @@ At this moment, this project provides the following algorithms and scripts to ru
 ## Requirements and Preparation
 
 
-**First of all**, please use `python setup.py install` to install `xautodl` library.
+**First of all**, please use `pip install .` to install `xautodl` library.
 
 Please install `Python>=3.6` and `PyTorch>=1.5.0`. (You could use lower versions of Python and PyTorch, but may have bugs).
 Some visualization codes may require `opencv`.
@@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
@@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
 You can move it to anywhere you want and send its path to our API for initialization.
 - [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
 - [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
-NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
+NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
 - [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
 - [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
 - [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
@@ -3,7 +3,7 @@
 </p>
 
 ---------
-[](LICENSE.md)
+[](../LICENSE.md)
 
 The Automatic Deep Learning library (AutoDL-Projects) is an open-source, lightweight, yet powerful project.
 It implements a variety of Neural Architecture Search (NAS) and Hyperparameter Optimization (HPO) algorithms.
@@ -142,8 +142,8 @@
 
 # Others
 
-If you want to contribute to this codebase, please read [CONTRIBUTING.md](.github/CONTRIBUTING.md).
-In addition, please follow the code of conduct in [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md).
+If you want to contribute to this codebase, please read [CONTRIBUTING.md](../.github/CONTRIBUTING.md).
+In addition, please follow the code of conduct in [CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md).
 
 # License
-The entire codebase is under [MIT license](LICENSE.md)
+The entire codebase is under [MIT license](../LICENSE.md)
@@ -24,6 +24,9 @@
 # python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
 # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
 # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
+####
+# The following scripts are added in 20 Mar 2022
+# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
 ######################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -166,6 +169,8 @@ def search_func(
             network.set_cal_mode("dynamic", sampled_arch)
         elif algo == "gdas":
             network.set_cal_mode("gdas", None)
+        elif algo == "gdas_v1":
+            network.set_cal_mode("gdas_v1", None)
         elif algo.startswith("darts"):
             network.set_cal_mode("joint", None)
         elif algo == "random":
@@ -196,6 +201,8 @@ def search_func(
             network.set_cal_mode("joint")
         elif algo == "gdas":
             network.set_cal_mode("gdas", None)
+        elif algo == "gdas_v1":
+            network.set_cal_mode("gdas_v1", None)
         elif algo.startswith("darts"):
             network.set_cal_mode("joint", None)
         elif algo == "random":
@@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo):
         archs, valid_accs = network.return_topK(n_samples, True), []
     elif algo == "setn":
         archs, valid_accs = network.return_topK(n_samples, False), []
-    elif algo.startswith("darts") or algo == "gdas":
+    elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
         arch = network.genotype
         archs, valid_accs = [arch], []
     elif algo == "enas":
@@ -568,7 +575,7 @@ def main(xargs):
         )
 
         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
-        if xargs.algo == "gdas":
+        if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
             network.set_tau(
                 xargs.tau_max
                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
@@ -632,6 +639,8 @@ def main(xargs):
         network.set_cal_mode("dynamic", genotype)
     elif xargs.algo == "gdas":
         network.set_cal_mode("gdas", None)
+    elif xargs.algo == "gdas_v1":
+        network.set_cal_mode("gdas_v1", None)
     elif xargs.algo.startswith("darts"):
         network.set_cal_mode("joint", None)
     elif xargs.algo == "random":
@@ -699,6 +708,8 @@ def main(xargs):
         network.set_cal_mode("dynamic", genotype)
     elif xargs.algo == "gdas":
         network.set_cal_mode("gdas", None)
+    elif xargs.algo == "gdas_v1":
+        network.set_cal_mode("gdas_v1", None)
     elif xargs.algo.startswith("darts"):
         network.set_cal_mode("joint", None)
     elif xargs.algo == "random":
@@ -747,7 +758,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--algo",
         type=str,
-        choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"],
+        choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
         help="The search space name.",
     )
     parser.add_argument(
exps/experimental/test-dks.py (new file, 57 lines)

@@ -0,0 +1,57 @@
from dks.base.activation_getter import (
    get_activation_function as _get_numpy_activation_function,
)
from dks.base.activation_transform import _get_activations_params


def subnet_max_func(x, r_fn):
    depth = 7
    res_x = r_fn(x)
    x = r_fn(x)
    for _ in range(depth):
        x = r_fn(r_fn(x)) + x
    return max(x, res_x)


def subnet_max_func_v2(x, r_fn):
    depth = 2
    res_x = r_fn(x)

    x = r_fn(x)
    for _ in range(depth):
        x = 0.8 * r_fn(r_fn(x)) + 0.2 * x

    return max(x, res_x)


def get_transformed_activations(
    activation_names,
    method="TAT",
    dks_params=None,
    tat_params=None,
    max_slope_func=None,
    max_curv_func=None,
    subnet_max_func=None,
    activation_getter=_get_numpy_activation_function,
):
    params = _get_activations_params(
        activation_names,
        method=method,
        dks_params=dks_params,
        tat_params=tat_params,
        max_slope_func=max_slope_func,
        max_curv_func=max_curv_func,
        subnet_max_func=subnet_max_func,
    )
    return params


params = get_transformed_activations(
    ["swish"], method="TAT", subnet_max_func=subnet_max_func
)
print(params)

params = get_transformed_activations(
    ["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2
)
print(params)
setup.py (2 changed lines)

@@ -37,7 +37,7 @@ def read(fname="README.md"):
 
 
 # What packages are required for this module to be executed?
-REQUIRED = ["numpy>=1.16.5,<=1.19.5", "pyyaml>=5.0.0", "fvcore"]
+REQUIRED = ["numpy>=1.16.5", "pyyaml>=5.0.0", "fvcore"]
 
 packages = find_packages(
     exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")
test.ipynb (new file, 502 lines)

A new Jupyter notebook (nbformat 4, nbformat_minor 2; kernel "natsbench" / python3, Python 3.9.19). Its cells are listed below.

Cell 1:
from nats_bench import create

# Create the API for size search space
api = create(None, 'sss', fast_mode=True, verbose=True)

# Create the API for tologoy search space
api = create(None, 'tss', fast_mode=True, verbose=True)

# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10
# info is a dict, where you can easily figure out the meaning by key
info = api.get_more_info(1234, 'cifar10')

# Query the flops, params, latency. info is a dict.
info = api.get_cost_info(12, 'cifar10')

# Simulate the training of the 1224-th candidate:
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')

# Clear the parameters of the 12-th candidate.
api.clear_params(12)

# Reload all information of the 12-th candidate.
api.reload(index=12)

# Create the instance of th 12-th candidate for CIFAR-10.
from models import get_cell_based_tiny_net
config = api.get_net_config(12, 'cifar10')
network = get_cell_based_tiny_net(config)

# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.
params = api.get_net_param(12, 'cifar10', None)
network.load_state_dict(next(iter(params.values())))

Cell 2:
from nas_201_api import NASBench201API as API
import os
# api = API('./NAS-Bench-201-v1_1_096897.pth')
# get the current path
print(os.path.abspath(os.path.curdir))
cur_path = os.path.abspath(os.path.curdir)
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')
api = API(data_path)

Cell 3:
# get the best performance on CIFAR-10
len = 15625
accs = []
for i in range(1, len):
    results = api.query_by_index(i, 'cifar10')
    dict_items = list(results.items())
    train_info = dict_items[0][1].get_train()
    acc = train_info['accuracy']
    accs.append((i, acc))
print(max(accs, key=lambda x: x[1]))
best_index, best_acc = max(accs, key=lambda x: x[1])

Cell 4:
def find_best_index(dataset):
    len = 15625
    accs = []
    for i in range(1, len):
        results = api.query_by_index(i, dataset)
        dict_items = list(results.items())
        train_info = dict_items[0][1].get_train()
        acc = train_info['accuracy']
        accs.append((i, acc))
    return max(accs, key=lambda x: x[1])
best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')
print(best_cifar_10_index, best_cifar_10_acc)
print(best_cifar_100_index, best_cifar_100_acc)
print(best_ImageNet16_index, best_ImageNet16_acc)

Cell 5: (empty code cell)

Cell 6:
api.show(5374)
config = api.get_net_config(best_index, 'cifar10')
from models import get_cell_based_tiny_net
network = get_cell_based_tiny_net(config)
print(network)

Cell 7:
api.get_net_param(5374, 'cifar10', None)

Cell 8:
import os, sys, time, torch, random, argparse
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path

from config_utils import load_config
from procedures.starts import get_machine_info
from datasets.get_dataset_with_transform import get_datasets
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from models import CellStructure, CellArchitectures, get_search_spaces

Cell 9:
def evaluate_all_datasets(
    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {"info": machine_info}
    all_dataset_keys = []
    # look all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        # load the configuration
        if dataset == "cifar10" or dataset == "cifar100":
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/CIFAR.config"
            split_info = load_config(
                "configs/nas-benchmark/cifar-split.txt", None, None
            )
        elif dataset.startswith("ImageNet16"):
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/ImageNet-16.config"
            split_info = load_config(
                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
            )
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        # check whether use splited validation set
        if bool(split):
            assert dataset == "cifar10"
            ValLoaders = {
                "ori-test": torch.utils.data.DataLoader(
                    valid_data,
                    batch_size=config.batch_size,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=True,
                )
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loader
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
                num_workers=workers,
                pin_memory=True,
            )
            ValLoaders["x-valid"] = valid_loader
        else:
            # data loader
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=True,
            )
            if dataset == "cifar10":
                ValLoaders = {"ori-test": valid_loader}
            elif dataset == "cifar100":
                cifar100_splits = load_config(
                    "configs/nas-benchmark/cifar100-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            elif dataset == "ImageNet16-120":
                imagenet16_splits = load_config(
                    "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for key, value in ValLoaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value))
            )
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos

Cell 10:
def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)

    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            "The model string is found in pre-defined architecture dict : {:}".format(
                model_str
            )
        )
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except:
            raise ValueError(
                "Invalid model string : {:}. It can not be found or parsed.".format(
                    model_str
                )
            )
    assert arch.check_valid_op(
        get_search_spaces("cell", "full")
    ), "{:} has the invalid op.".format(arch)
    logger.log("Start train-evaluate {:}".format(arch.tostr()))
    logger.log("arch_config : {:}".format(arch_config))

    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                _is, len(seeds), seed
            )
        )
        to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
        if to_save_name.exists():
            logger.log(
                "Find the existing file {:}, directly load!".format(to_save_name)
            )
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                "Does not find the existing file {:}, train and evaluate!".format(
                    to_save_name
                )
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                use_less,
                seed,
                arch_config,
                workers,
                logger,
            )
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log("{:}".format(checkpoint["info"]))
        all_dataset_keys = checkpoint["all_dataset_keys"]
        for dataset_key in all_dataset_keys:
            logger.log(
                "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
            )
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log(
                "Flops = {:} MB, Params = {:} MB".format(
                    dataset_info["flop"], dataset_info["param"]
                )
            )
            logger.log("config : {:}".format(dataset_info["config"]))
            logger.log(
                "Training State (finish) = {:}".format(dataset_info["finish-train"])
            )
            last_epoch = dataset_info["total_epoch"] - 1
            train_acc1es, train_acc5es = (
                dataset_info["train_acc1es"],
                dataset_info["train_acc5es"],
            )
            valid_acc1es, valid_acc5es = (
                dataset_info["valid_acc1es"],
                dataset_info["valid_acc5es"],
            )
            logger.log(
                "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                    train_acc1es[last_epoch],
                    train_acc5es[last_epoch],
                    100 - train_acc1es[last_epoch],
                    valid_acc1es[last_epoch],
                    valid_acc5es[last_epoch],
                    100 - valid_acc1es[last_epoch],
                )
            )
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
        )
        logger.log(
            "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
                _is, len(seeds), seed, need_time
            )
        )
    logger.close()

Cell 11:
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets="cifar10",
    xpaths="/root/cifardata/cifar-10-batches-py",
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str="|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|",
    arch_config={"channel": 16, "num_cells": 8},)

Cell 12: (empty markdown cell)
@@ -156,7 +156,7 @@ class Logger(object):
         hist.max = float(np.max(values))
         hist.num = int(np.prod(values.shape))
         hist.sum = float(np.sum(values))
-        hist.sum_squares = float(np.sum(values ** 2))
+        hist.sum_squares = float(np.sum(values**2))
 
         # Drop the start of the first bin
         bin_edges = bin_edges[1:]
@@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
                     feature = cell.forward_gdas(feature, alphas, index)
                     if self.verbose:
                         verbose_str += "-forward_gdas"
+                elif self.mode == "gdas_v1":
+                    feature = cell.forward_gdas_v1(feature, alphas, index)
+                    if self.verbose:
+                        verbose_str += "-forward_gdas_v1"
                 else:
                     raise ValueError("invalid mode={:}".format(self.mode))
             else:
@@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure(
         (("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
     ] # node-3
 )
+Number_5374 = Structure(
+    [
+        (("nor_conv_3x3", 0),), # node-1
+        (("nor_conv_1x1", 0), ("nor_conv_3x3", 1)), # node-2
+        (("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)), # node-3
+    ]
+)
 
 AllFull_CODE = Structure(
     [
@@ -271,4 +278,5 @@ architectures = {
     "all_c1x1": AllConv1x1_CODE,
     "all_idnt": AllIdentity_CODE,
     "all_full": AllFull_CODE,
+    "5374": Number_5374,
 }
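A hedged usage sketch of the new "5374" entry: assuming this `architectures` dict is the one exported as `CellArchitectures` (which the notebook added in this compare imports from `models`), the 5374-th cell can now be fetched by key instead of pasting its full architecture string.

# Hypothetical usage; the import path mirrors the notebook in this compare and is
# an assumption, not something this hunk itself establishes.
from models import CellArchitectures

arch = CellArchitectures["5374"]
print(arch.tostr())  # prints the full cell string for architecture index 5374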
@@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
             nodes.append(sum(inter_nodes))
         return nodes[-1]
 
+    # GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
+    def forward_gdas_v1(self, inputs, hardwts, index):
+        nodes = [inputs]
+        for i in range(1, self.max_nodes):
+            inter_nodes = []
+            for j in range(i):
+                node_str = "{:}<-{:}".format(i, j)
+                weights = hardwts[self.edge2index[node_str]]
+                argmaxs = index[self.edge2index[node_str]].item()
+                weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
+                inter_nodes.append(weigsum)
+            nodes.append(sum(inter_nodes))
+        return nodes[-1]
+
     # joint
     def forward_joint(self, inputs, weightss):
         nodes = [inputs]
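For readers unfamiliar with the `hardwts`/`index` pair that `forward_gdas_v1` consumes, here is a minimal sketch of how GDAS-style search typically produces them via straight-through Gumbel-softmax sampling; the helper name `sample_gumbel_hardwts` and the `[num_edges, num_ops]` logits shape are assumptions for illustration, not code from this diff.

import torch
import torch.nn.functional as F


def sample_gumbel_hardwts(arch_params: torch.Tensor, tau: float):
    # arch_params: assumed [num_edges, num_ops] architecture logits.
    # Returns one-hot `hardwts` (with soft gradients via the straight-through
    # trick) and the per-edge argmax `index`, the two inputs forward_gdas_v1 expects.
    gumbels = -torch.empty_like(arch_params).exponential_().log()
    logits = (arch_params.log_softmax(dim=1) + gumbels) / tau
    probs = F.softmax(logits, dim=1)
    index = probs.max(dim=1, keepdim=True)[1]
    one_hot = torch.zeros_like(logits).scatter_(1, index, 1.0)
    hardwts = one_hot - probs.detach() + probs  # straight-through estimator
    return hardwts, index

With such hard weights, `forward_gdas_v1` propagates only the argmax operation's output on each edge, scaled by its differentiable weight; the issue linked in the comment above (#119) discusses how this differs from the original `forward_gdas`.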
@@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module):
         return nodes[-1]
 
 
+# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
+
+
 class MixedOp(nn.Module):
     def __init__(self, space, C, stride, affine, track_running_stats):
         super(MixedOp, self).__init__()
@@ -167,7 +184,6 @@ class MixedOp(nn.Module):
         return sum(w * op(x) for w, op in zip(weights, self._ops))
 
 
-# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
 class NASNetSearchCell(nn.Module):
     def __init__(
         self,
@@ -155,7 +155,7 @@ class ExponentialLR(_LRScheduler):
             if self.current_epoch >= self.warmup_epochs:
                 last_epoch = self.current_epoch - self.warmup_epochs
                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
-                lr = base_lr * (self.gamma ** last_epoch)
+                lr = base_lr * (self.gamma**last_epoch)
             else:
                 lr = (
                     self.current_epoch / self.warmup_epochs
@@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
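The `.view(-1)` to `.reshape(-1)` switch matters because the sliced `correct[:k]` tensor is typically non-contiguous (it usually comes from a transposed top-k comparison), and `.view()` raises on non-contiguous inputs while `.reshape()` copies when needed. A small, self-contained illustration (not code from the repository):

import torch

x = torch.arange(6).reshape(2, 3).t()  # transposing makes the tensor non-contiguous
try:
    x.view(-1)  # raises RuntimeError on non-contiguous tensors
except RuntimeError as err:
    print("view failed:", err)
print(x.reshape(-1))  # reshape copies if necessary and succeeds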
@@ -122,7 +122,7 @@ class ExponentialParamScheduler(ParamScheduler):
         self._decay = decay
 
     def __call__(self, where: float) -> float:
-        return self._start_value * (self._decay ** where)
+        return self._start_value * (self._decay**where)
 
 
 class LinearParamScheduler(ParamScheduler):