update NAS-Bench-102

This commit is contained in:
D-X-Y
2019-12-21 11:13:08 +11:00
parent 69ca0860aa
commit 95ec4d328e
3 changed files with 76 additions and 39 deletions

View File

@@ -6,11 +6,16 @@ Each edge here is associated with an operation selected from a predefined operat
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
In this Markdown file, we provide:
- Detailed instruction to reproduce NAS-Bench-102.
- 10 NAS algorithms evaluated in our paper.
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
The data file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan].
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan].
## How to Use NAS-Bench-102
1. Creating an API instance from a file:
@@ -35,8 +40,8 @@ api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1)
loss, accuracy = info.get_metrics('cifar10', 'train')
flops, params, latency = info.get_comput_costs('cifar100')
res_metrics = info.get_metrics('cifar10', 'train')
cost_metrics = info.get_comput_costs('cifar100')
# get the detailed information
results = api.query_by_index(1, 'cifar100')
@@ -55,7 +60,8 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
api.show(index)
```
5. For other usages, please see `lib/aa_nas_api/api.py`
5. For other usages, please see `lib/nas_102_api/api.py`
### Detailed Instruction
@@ -98,8 +104,10 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss
```
from nas_102_api import NASBench102API as API
api = API('NAS-Bench-102-v1_0.pth')
api.show(-1) # show info of all architectures
```
## Instruction to Re-Generate NAS-Bench-102
1. generate the meta file for NAS-Bench-102 using the following script, where `NAS-BENCH-102` indicates the name and `4` indicates the maximum number of nodes in a cell.
@@ -139,6 +147,7 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
```
## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-102
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.