Compare commits


10 Commits

Author | SHA1 | Message | Date
(not shown) | bb33ca9a68 | run the specific model | 2024-07-11 11:48:51 +02:00
D-X-Y | f46486e21b | Update README.md | 2022-04-24 15:18:16 -07:00
D-X-Y | 5908a1edef | Merge pull request #123 from Yulv-git/main: Update some links in README_CN.md and fix some typos. | 2022-04-24 15:16:21 -07:00
Yulv-git | ed34024a88 | Update some links in README_CN.md and fix some typos. | 2022-04-23 10:59:49 +08:00
D-X-Y | 5bf036a763 | Update DKS exploration | 2022-03-28 21:28:50 -07:00
D-X-Y | b557a22928 | Merge pull request #121 from ain-soph/patch-1: remove numpy version requirements | 2022-03-25 00:05:53 -07:00
Ren Pang | f549ed2e61 | fix setup bug | 2022-03-24 21:06:28 -04:00
Local State | 5a5cb82537 | remove numpy version requirements ("Is it possible to remove numpy version requirements? I want to use the benchmark, but my codes are relying on some new bug fixes after `numpy>1.20`.") | 2022-03-24 16:50:19 -04:00
D-X-Y | 676e8e411d | Upgrade black to 22.1.0 and fix the corresponding issues | 2022-03-20 23:18:23 -07:00
D-X-Y | 8d0799dfb1 | To answer issue #119 | 2022-03-20 23:12:12 -07:00

18 changed files with 620 additions and 21 deletions

View File

@@ -41,7 +41,7 @@ jobs:
- name: Install XAutoDL from source
run: |
python setup.py install
pip install .
- name: Test Search Space
run: |

View File

@@ -26,7 +26,7 @@ jobs:
- name: Install XAutoDL from source
run: |
python setup.py install
pip install .
- name: Test Xmisc
run: |

View File

@@ -26,7 +26,7 @@ jobs:
- name: Install XAutoDL from source
run: |
python setup.py install
pip install .
- name: Test Super Model
run: |

View File

@@ -61,13 +61,13 @@ At this moment, this project provides the following algorithms and scripts to ru
<tr> <!-- (6-th row) -->
<td align="center" valign="middle"> NATS-Bench </td>
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
</tr>
<tr> <!-- (7-th row) -->
<td align="center" valign="middle"> ... </td>
<td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
<td align="center" valign="middle"> Please check the original papers </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench/blob/main/README.md">NATS-Bench.md</a> </td>
</tr>
<tr> <!-- (start second block) -->
<td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
@@ -89,7 +89,7 @@ At this moment, this project provides the following algorithms and scripts to ru
## Requirements and Preparation
**First of all**, please use `python setup.py install` to install `xautodl` library.
**First of all**, please use `pip install .` to install `xautodl` library.
Please install `Python>=3.6` and `PyTorch>=1.5.0`. (You could use lower versions of Python and PyTorch, but may have bugs).
Some visualization codes may require `opencv`.
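
The change above (together with the matching workflow edits earlier in this compare) replaces `python setup.py install` with `pip install .`. As a quick sanity check, the installed package should then be importable under the `xautodl` namespace; the snippet below is a sketch assuming the packaged layout, with the `get_search_spaces` call mirroring the notebook later in this compare.

```python
# Post-install sanity check (sketch): assumes `pip install .` was run from the
# repository root and that the packaged library exposes xautodl.models.
from xautodl.models import get_search_spaces

# "cell"/"full" mirrors the call used in test.ipynb below; other keys are repo-specific.
print(get_search_spaces("cell", "full"))
```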

View File

@@ -29,7 +29,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
- [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.
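
As a companion to the file list above, here is a minimal sketch of loading one of the benchmark files and querying it; the path and index are illustrative, and the calls mirror those used in the notebook later in this compare.

```python
# Sketch: load the NAS-Bench-201 API from a downloaded benchmark file and query it.
from nas_201_api import NASBench201API as API

api = API("./NAS-Bench-201-v1_1-096897.pth")   # point this at the downloaded file
results = api.query_by_index(1234, "cifar10")  # per-trial results for one architecture
config = api.get_net_config(1234, "cifar10")   # network config dict for that architecture
```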

View File

@@ -27,7 +27,7 @@ You can simply type `pip install nas-bench-201` to install our api. Please see s
You can move it to anywhere you want and send its path to our API for initialization.
- [2020.02.25] APIv1.0/FILEv1.0: [`NAS-Bench-201-v1_0-e61699.pth`](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) (2.2G), where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
- [2020.02.25] APIv1.0/FILEv1.0: The full data of each architecture can be download from [
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
NAS-BENCH-201-4-v1.0-archive.tar](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the trained weights.
- [2020.02.25] APIv1.0/FILEv1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
- [2020.03.09] APIv1.2/FILEv1.0: More robust API with more functions and descriptions
- [2020.03.16] APIv1.3/FILEv1.1: [`NAS-Bench-201-v1_1-096897.pth`](https://drive.google.com/open?id=16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_) (4.7G), where `096897` is the last six digits for this file. It contains information of more trials compared to `NAS-Bench-201-v1_0-e61699.pth`, especially all models trained by 12 epochs on all datasets are avaliable.

View File

@@ -3,7 +3,7 @@
</p>
---------
[![MIT licensed](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.md)
[![MIT licensed](https://img.shields.io/badge/license-MIT-brightgreen.svg)](../LICENSE.md)
自动深度学习库 (AutoDL-Projects) 是一个开源的,轻量级的,功能强大的项目。
该项目实现了多种网络结构搜索(NAS)和超参数优化(HPO)算法。
@@ -142,8 +142,8 @@
# 其他
如果你想要给这份代码库做贡献,请看[CONTRIBUTING.md](.github/CONTRIBUTING.md)。
此外,使用规范请参考[CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md)。
如果你想要给这份代码库做贡献,请看[CONTRIBUTING.md](../.github/CONTRIBUTING.md)。
此外,使用规范请参考[CODE-OF-CONDUCT.md](../.github/CODE-OF-CONDUCT.md)。
# 许可证
The entire codebase is under [MIT license](LICENSE.md)
The entire codebase is under [MIT license](../LICENSE.md)

View File

@@ -24,6 +24,9 @@
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
####
# The following scripts are added in 20 Mar 2022
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
@@ -166,6 +169,8 @@ def search_func(
network.set_cal_mode("dynamic", sampled_arch)
elif algo == "gdas":
network.set_cal_mode("gdas", None)
elif algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif algo == "random":
@@ -196,6 +201,8 @@ def search_func(
network.set_cal_mode("joint")
elif algo == "gdas":
network.set_cal_mode("gdas", None)
elif algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif algo == "random":
@@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo):
archs, valid_accs = network.return_topK(n_samples, True), []
elif algo == "setn":
archs, valid_accs = network.return_topK(n_samples, False), []
elif algo.startswith("darts") or algo == "gdas":
elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
arch = network.genotype
archs, valid_accs = [arch], []
elif algo == "enas":
@@ -568,7 +575,7 @@ def main(xargs):
)
network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
if xargs.algo == "gdas":
if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
network.set_tau(
xargs.tau_max
- (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
@@ -632,6 +639,8 @@ def main(xargs):
network.set_cal_mode("dynamic", genotype)
elif xargs.algo == "gdas":
network.set_cal_mode("gdas", None)
elif xargs.algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif xargs.algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif xargs.algo == "random":
@@ -699,6 +708,8 @@ def main(xargs):
network.set_cal_mode("dynamic", genotype)
elif xargs.algo == "gdas":
network.set_cal_mode("gdas", None)
elif xargs.algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif xargs.algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif xargs.algo == "random":
@@ -747,7 +758,7 @@ if __name__ == "__main__":
parser.add_argument(
"--algo",
type=str,
choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"],
choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
help="The search space name.",
)
parser.add_argument(

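The new `gdas_v1` branch above shares the temperature schedule already used for `gdas`: `xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)`. Below is a standalone sketch of that linear decay, with illustrative values rather than the script's actual defaults.

```python
# Standalone sketch of the linear Gumbel-softmax temperature decay applied when
# --algo is "gdas" or "gdas_v1"; the tau_max/tau_min values here are illustrative only.
def tau_at(epoch: int, total_epoch: int, tau_max: float = 10.0, tau_min: float = 0.1) -> float:
    """Temperature decays linearly from tau_max (epoch 0) to tau_min (last epoch)."""
    return tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)

# With 5 epochs this yields roughly 10.0, 7.525, 5.05, 2.575, 0.1
print([tau_at(e, 5) for e in range(5)])
```
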
View File

@@ -0,0 +1,57 @@
from dks.base.activation_getter import (
get_activation_function as _get_numpy_activation_function,
)
from dks.base.activation_transform import _get_activations_params
def subnet_max_func(x, r_fn):
depth = 7
res_x = r_fn(x)
x = r_fn(x)
for _ in range(depth):
x = r_fn(r_fn(x)) + x
return max(x, res_x)
def subnet_max_func_v2(x, r_fn):
depth = 2
res_x = r_fn(x)
x = r_fn(x)
for _ in range(depth):
x = 0.8 * r_fn(r_fn(x)) + 0.2 * x
return max(x, res_x)
def get_transformed_activations(
activation_names,
method="TAT",
dks_params=None,
tat_params=None,
max_slope_func=None,
max_curv_func=None,
subnet_max_func=None,
activation_getter=_get_numpy_activation_function,
):
params = _get_activations_params(
activation_names,
method=method,
dks_params=dks_params,
tat_params=tat_params,
max_slope_func=max_slope_func,
max_curv_func=max_curv_func,
subnet_max_func=subnet_max_func,
)
return params
params = get_transformed_activations(
["swish"], method="TAT", subnet_max_func=subnet_max_func
)
print(params)
params = get_transformed_activations(
["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2
)
print(params)

View File

@@ -37,7 +37,7 @@ def read(fname="README.md"):
# What packages are required for this module to be executed?
REQUIRED = ["numpy>=1.16.5,<=1.19.5", "pyyaml>=5.0.0", "fvcore"]
REQUIRED = ["numpy>=1.16.5", "pyyaml>=5.0.0", "fvcore"]
packages = find_packages(
exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")

test.ipynb (new file, 502 additions)
View File

@@ -0,0 +1,502 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nats_bench import create\n",
"\n",
"# Create the API for size search space\n",
"api = create(None, 'sss', fast_mode=True, verbose=True)\n",
"\n",
"# Create the API for tologoy search space\n",
"api = create(None, 'tss', fast_mode=True, verbose=True)\n",
"\n",
"# Query the loss / accuracy / time for 1234-th candidate architecture on CIFAR-10\n",
"# info is a dict, where you can easily figure out the meaning by key\n",
"info = api.get_more_info(1234, 'cifar10')\n",
"\n",
"# Query the flops, params, latency. info is a dict.\n",
"info = api.get_cost_info(12, 'cifar10')\n",
"\n",
"# Simulate the training of the 1224-th candidate:\n",
"validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(1224, dataset='cifar10', hp='12')\n",
"\n",
"# Clear the parameters of the 12-th candidate.\n",
"api.clear_params(12)\n",
"\n",
"# Reload all information of the 12-th candidate.\n",
"api.reload(index=12)\n",
"\n",
"# Create the instance of th 12-th candidate for CIFAR-10.\n",
"from models import get_cell_based_tiny_net\n",
"config = api.get_net_config(12, 'cifar10')\n",
"network = get_cell_based_tiny_net(config)\n",
"\n",
"# Load the pre-trained weights: params is a dict, where the key is the seed and value is the weights.\n",
"params = api.get_net_param(12, 'cifar10', None)\n",
"network.load_state_dict(next(iter(params.values())))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nas_201_api import NASBench201API as API\n",
"import os\n",
"# api = API('./NAS-Bench-201-v1_1_096897.pth')\n",
"# get the current path\n",
"print(os.path.abspath(os.path.curdir))\n",
"cur_path = os.path.abspath(os.path.curdir)\n",
"data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')\n",
"api = API(data_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get the best performance on CIFAR-10\n",
"len = 15625\n",
"accs = []\n",
"for i in range(1, len):\n",
" results = api.query_by_index(i, 'cifar10')\n",
" dict_items = list(results.items())\n",
" train_info = dict_items[0][1].get_train()\n",
" acc = train_info['accuracy']\n",
" accs.append((i, acc))\n",
"print(max(accs, key=lambda x: x[1]))\n",
"best_index, best_acc = max(accs, key=lambda x: x[1])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_best_index(dataset):\n",
" len = 15625\n",
" accs = []\n",
" for i in range(1, len):\n",
" results = api.query_by_index(i, dataset)\n",
" dict_items = list(results.items())\n",
" train_info = dict_items[0][1].get_train()\n",
" acc = train_info['accuracy']\n",
" accs.append((i, acc))\n",
" return max(accs, key=lambda x: x[1])\n",
"best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')\n",
"best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')\n",
"best_ImageNet16_index, best_ImageNet16_acc= find_best_index('ImageNet16-120')\n",
"print(best_cifar_10_index, best_cifar_10_acc)\n",
"print(best_cifar_100_index, best_cifar_100_acc)\n",
"print(best_ImageNet16_index, best_ImageNet16_acc)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api.show(5374)\n",
"config = api.get_net_config(best_index, 'cifar10')\n",
"from models import get_cell_based_tiny_net\n",
"network = get_cell_based_tiny_net(config)\n",
"print(network)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api.get_net_param(5374, 'cifar10', None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, sys, time, torch, random, argparse\n",
"from PIL import ImageFile\n",
"\n",
"ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
"from copy import deepcopy\n",
"from pathlib import Path\n",
"\n",
"from config_utils import load_config\n",
"from procedures.starts import get_machine_info\n",
"from datasets.get_dataset_with_transform import get_datasets\n",
"from log_utils import Logger, AverageMeter, time_string, convert_secs2time\n",
"from models import CellStructure, CellArchitectures, get_search_spaces"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate_all_datasets(\n",
" arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger\n",
"):\n",
" machine_info, arch_config = get_machine_info(), deepcopy(arch_config)\n",
" all_infos = {\"info\": machine_info}\n",
" all_dataset_keys = []\n",
" # look all the datasets\n",
" for dataset, xpath, split in zip(datasets, xpaths, splits):\n",
" # train valid data\n",
" train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)\n",
" # load the configuration\n",
" if dataset == \"cifar10\" or dataset == \"cifar100\":\n",
" if use_less:\n",
" config_path = \"configs/nas-benchmark/LESS.config\"\n",
" else:\n",
" config_path = \"configs/nas-benchmark/CIFAR.config\"\n",
" split_info = load_config(\n",
" \"configs/nas-benchmark/cifar-split.txt\", None, None\n",
" )\n",
" elif dataset.startswith(\"ImageNet16\"):\n",
" if use_less:\n",
" config_path = \"configs/nas-benchmark/LESS.config\"\n",
" else:\n",
" config_path = \"configs/nas-benchmark/ImageNet-16.config\"\n",
" split_info = load_config(\n",
" \"configs/nas-benchmark/{:}-split.txt\".format(dataset), None, None\n",
" )\n",
" else:\n",
" raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
" config = load_config(\n",
" config_path, {\"class_num\": class_num, \"xshape\": xshape}, logger\n",
" )\n",
" # check whether use splited validation set\n",
" if bool(split):\n",
" assert dataset == \"cifar10\"\n",
" ValLoaders = {\n",
" \"ori-test\": torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" shuffle=False,\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" )\n",
" }\n",
" assert len(train_data) == len(split_info.train) + len(\n",
" split_info.valid\n",
" ), \"invalid length : {:} vs {:} + {:}\".format(\n",
" len(train_data), len(split_info.train), len(split_info.valid)\n",
" )\n",
" train_data_v2 = deepcopy(train_data)\n",
" train_data_v2.transform = valid_data.transform\n",
" valid_data = train_data_v2\n",
" # data loader\n",
" train_loader = torch.utils.data.DataLoader(\n",
" train_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" )\n",
" valid_loader = torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" )\n",
" ValLoaders[\"x-valid\"] = valid_loader\n",
" else:\n",
" # data loader\n",
" train_loader = torch.utils.data.DataLoader(\n",
" train_data,\n",
" batch_size=config.batch_size,\n",
" shuffle=True,\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" )\n",
" valid_loader = torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" shuffle=False,\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" )\n",
" if dataset == \"cifar10\":\n",
" ValLoaders = {\"ori-test\": valid_loader}\n",
" elif dataset == \"cifar100\":\n",
" cifar100_splits = load_config(\n",
" \"configs/nas-benchmark/cifar100-test-split.txt\", None, None\n",
" )\n",
" ValLoaders = {\n",
" \"ori-test\": valid_loader,\n",
" \"x-valid\": torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
" cifar100_splits.xvalid\n",
" ),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" ),\n",
" \"x-test\": torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
" cifar100_splits.xtest\n",
" ),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" ),\n",
" }\n",
" elif dataset == \"ImageNet16-120\":\n",
" imagenet16_splits = load_config(\n",
" \"configs/nas-benchmark/imagenet-16-120-test-split.txt\", None, None\n",
" )\n",
" ValLoaders = {\n",
" \"ori-test\": valid_loader,\n",
" \"x-valid\": torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
" imagenet16_splits.xvalid\n",
" ),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" ),\n",
" \"x-test\": torch.utils.data.DataLoader(\n",
" valid_data,\n",
" batch_size=config.batch_size,\n",
" sampler=torch.utils.data.sampler.SubsetRandomSampler(\n",
" imagenet16_splits.xtest\n",
" ),\n",
" num_workers=workers,\n",
" pin_memory=True,\n",
" ),\n",
" }\n",
" else:\n",
" raise ValueError(\"invalid dataset : {:}\".format(dataset))\n",
"\n",
" dataset_key = \"{:}\".format(dataset)\n",
" if bool(split):\n",
" dataset_key = dataset_key + \"-valid\"\n",
" logger.log(\n",
" \"Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}\".format(\n",
" dataset_key,\n",
" len(train_data),\n",
" len(valid_data),\n",
" len(train_loader),\n",
" len(valid_loader),\n",
" config.batch_size,\n",
" )\n",
" )\n",
" logger.log(\n",
" \"Evaluate ||||||| {:10s} ||||||| Config={:}\".format(dataset_key, config)\n",
" )\n",
" for key, value in ValLoaders.items():\n",
" logger.log(\n",
" \"Evaluate ---->>>> {:10s} with {:} batchs\".format(key, len(value))\n",
" )\n",
" results = evaluate_for_seed(\n",
" arch_config, config, arch, train_loader, ValLoaders, seed, logger\n",
" )\n",
" all_infos[dataset_key] = results\n",
" all_dataset_keys.append(dataset_key)\n",
" all_infos[\"all_dataset_keys\"] = all_dataset_keys\n",
" return all_infos\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_single_model(\n",
" save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config\n",
"):\n",
" assert torch.cuda.is_available(), \"CUDA is not available.\"\n",
" torch.backends.cudnn.enabled = True\n",
" torch.backends.cudnn.deterministic = True\n",
" # torch.backends.cudnn.benchmark = True\n",
" torch.set_num_threads(workers)\n",
"\n",
" save_dir = (\n",
" Path(save_dir)\n",
" / \"specifics\"\n",
" / \"{:}-{:}-{:}-{:}\".format(\n",
" \"LESS\" if use_less else \"FULL\",\n",
" model_str,\n",
" arch_config[\"channel\"],\n",
" arch_config[\"num_cells\"],\n",
" )\n",
" )\n",
" logger = Logger(str(save_dir), 0, False)\n",
" if model_str in CellArchitectures:\n",
" arch = CellArchitectures[model_str]\n",
" logger.log(\n",
" \"The model string is found in pre-defined architecture dict : {:}\".format(\n",
" model_str\n",
" )\n",
" )\n",
" else:\n",
" try:\n",
" arch = CellStructure.str2structure(model_str)\n",
" except:\n",
" raise ValueError(\n",
" \"Invalid model string : {:}. It can not be found or parsed.\".format(\n",
" model_str\n",
" )\n",
" )\n",
" assert arch.check_valid_op(\n",
" get_search_spaces(\"cell\", \"full\")\n",
" ), \"{:} has the invalid op.\".format(arch)\n",
" logger.log(\"Start train-evaluate {:}\".format(arch.tostr()))\n",
" logger.log(\"arch_config : {:}\".format(arch_config))\n",
"\n",
" start_time, seed_time = time.time(), AverageMeter()\n",
" for _is, seed in enumerate(seeds):\n",
" logger.log(\n",
" \"\\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------\".format(\n",
" _is, len(seeds), seed\n",
" )\n",
" )\n",
" to_save_name = save_dir / \"seed-{:04d}.pth\".format(seed)\n",
" if to_save_name.exists():\n",
" logger.log(\n",
" \"Find the existing file {:}, directly load!\".format(to_save_name)\n",
" )\n",
" checkpoint = torch.load(to_save_name)\n",
" else:\n",
" logger.log(\n",
" \"Does not find the existing file {:}, train and evaluate!\".format(\n",
" to_save_name\n",
" )\n",
" )\n",
" checkpoint = evaluate_all_datasets(\n",
" arch,\n",
" datasets,\n",
" xpaths,\n",
" splits,\n",
" use_less,\n",
" seed,\n",
" arch_config,\n",
" workers,\n",
" logger,\n",
" )\n",
" torch.save(checkpoint, to_save_name)\n",
" # log information\n",
" logger.log(\"{:}\".format(checkpoint[\"info\"]))\n",
" all_dataset_keys = checkpoint[\"all_dataset_keys\"]\n",
" for dataset_key in all_dataset_keys:\n",
" logger.log(\n",
" \"\\n{:} dataset : {:} {:}\".format(\"-\" * 15, dataset_key, \"-\" * 15)\n",
" )\n",
" dataset_info = checkpoint[dataset_key]\n",
" # logger.log('Network ==>\\n{:}'.format( dataset_info['net_string'] ))\n",
" logger.log(\n",
" \"Flops = {:} MB, Params = {:} MB\".format(\n",
" dataset_info[\"flop\"], dataset_info[\"param\"]\n",
" )\n",
" )\n",
" logger.log(\"config : {:}\".format(dataset_info[\"config\"]))\n",
" logger.log(\n",
" \"Training State (finish) = {:}\".format(dataset_info[\"finish-train\"])\n",
" )\n",
" last_epoch = dataset_info[\"total_epoch\"] - 1\n",
" train_acc1es, train_acc5es = (\n",
" dataset_info[\"train_acc1es\"],\n",
" dataset_info[\"train_acc5es\"],\n",
" )\n",
" valid_acc1es, valid_acc5es = (\n",
" dataset_info[\"valid_acc1es\"],\n",
" dataset_info[\"valid_acc5es\"],\n",
" )\n",
" logger.log(\n",
" \"Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%\".format(\n",
" train_acc1es[last_epoch],\n",
" train_acc5es[last_epoch],\n",
" 100 - train_acc1es[last_epoch],\n",
" valid_acc1es[last_epoch],\n",
" valid_acc5es[last_epoch],\n",
" 100 - valid_acc1es[last_epoch],\n",
" )\n",
" )\n",
" # measure elapsed time\n",
" seed_time.update(time.time() - start_time)\n",
" start_time = time.time()\n",
" need_time = \"Time Left: {:}\".format(\n",
" convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)\n",
" )\n",
" logger.log(\n",
" \"\\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}\".format(\n",
" _is, len(seeds), seed, need_time\n",
" )\n",
" )\n",
" logger.close()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_single_model(\n",
" save_dir=\"./outputs\",\n",
" workers=8,\n",
" datasets=\"cifar10\", \n",
" xpaths=\"/root/cifardata/cifar-10-batches-py\",\n",
" splits=[0, 0, 0],\n",
" use_less=False,\n",
" seeds=[777],\n",
" model_str=\"|nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|\",\n",
" arch_config={\"channel\": 16, \"num_cells\": 8},)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "natsbench",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -156,7 +156,7 @@ class Logger(object):
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values ** 2))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]

View File

@@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
elif self.mode == "gdas_v1":
feature = cell.forward_gdas_v1(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas_v1"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:

View File

@@ -213,6 +213,13 @@ AllConv3x3_CODE = Structure(
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
] # node-3
)
Number_5374 = Structure(
[
(("nor_conv_3x3", 0),), # node-1
(("nor_conv_1x1", 0), ("nor_conv_3x3", 1)), # node-2
(("skip_connect", 0), ("none", 1), ("nor_conv_3x3", 2)), # node-3
]
)
AllFull_CODE = Structure(
[
@@ -271,4 +278,5 @@ architectures = {
"all_c1x1": AllConv1x1_CODE,
"all_idnt": AllIdentity_CODE,
"all_full": AllFull_CODE,
"5374": Number_5374,
}
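
The `Number_5374` entry registered above matches the `model_str` trained in `test.ipynb` later in this compare. A sketch of looking it up follows; the notebook imports this dict as `CellArchitectures` via `from models import CellArchitectures`, and under the packaged install the path would presumably be `from xautodl.models import CellArchitectures` (assumed relationship, treat as illustrative).

```python
# Sketch: look up the newly registered architecture by its key; the import path
# follows the notebook shown below and is assumed to resolve to the dict above.
from models import CellArchitectures

arch = CellArchitectures["5374"]
print(arch.tostr())
# Expected to equal the notebook's model_str:
# |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
```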

View File

@@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
def forward_gdas_v1(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
@@ -152,6 +166,9 @@
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
@@ -167,7 +184,6 @@ class MixedOp(nn.Module):
return sum(w * op(x) for w, op in zip(weights, self._ops))
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(
self,

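The `forward_gdas_v1` method added above (written for issue #119) changes how an edge's output is assembled compared with the existing `gdas` mode. The toy sketch below is not repo code; it assumes `gdas` follows the usual GDAS trick of keeping every hard weight in the computation graph while executing only the argmax op, whereas `gdas_v1` scales just the argmax op's output by its weight.

```python
# Toy, self-contained sketch (not repo code) contrasting the two edge-forward modes.
import torch

ops = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(3))  # candidate ops on one edge
x = torch.randn(2, 4)
hardwts = torch.tensor([0.1, 0.7, 0.2])   # hard (Gumbel-softmax) edge weights
argmax = int(hardwts.argmax())

# "gdas"-style: only the argmax op is evaluated on x, but every weight stays in the
# sum (as a scalar), so all architecture parameters keep receiving gradients.
out_gdas = sum(
    hardwts[k] * ops[k](x) if k == argmax else hardwts[k]
    for k in range(len(ops))
)

# "gdas_v1"-style: only the argmax weight scales the argmax op's output.
out_gdas_v1 = hardwts[argmax] * ops[argmax](x)
print(out_gdas.shape, out_gdas_v1.shape)
```
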
View File

@@ -155,7 +155,7 @@ class ExponentialLR(_LRScheduler):
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
lr = base_lr * (self.gamma**last_epoch)
else:
lr = (
self.current_epoch / self.warmup_epochs

View File

@@ -12,6 +12,7 @@ def obtain_accuracy(output, target, topk=(1,)):
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
# correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
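
The switch from `view(-1)` to `reshape(-1)` above is likely motivated by `.view()` requiring a compatible (typically contiguous) memory layout, which the sliced `correct` tensor may not have under newer PyTorch versions; `.reshape()` falls back to a copy when needed. A small sketch of the difference:

```python
# Sketch of why reshape(-1) is the safer call on a non-contiguous tensor.
import torch

x = torch.arange(6).reshape(2, 3).t()   # transposing makes the tensor non-contiguous
flat = x.reshape(-1)                    # works: copies when a zero-copy view is impossible
try:
    x.view(-1)                          # raises RuntimeError for this layout
except RuntimeError as err:
    print("view failed:", err)
print("reshape ok:", flat.tolist())
```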

View File

@@ -122,7 +122,7 @@ class ExponentialParamScheduler(ParamScheduler):
self._decay = decay
def __call__(self, where: float) -> float:
return self._start_value * (self._decay ** where)
return self._start_value * (self._decay**where)
class LinearParamScheduler(ParamScheduler):