Add more algorithms
This commit is contained in:
101
README.md
101
README.md
@@ -1,65 +1,84 @@
|
||||
## [Searching for A Robust Neural Architecture in Four GPU Hours](http://xuanyidong.com/publication/gradient-based-diff-sampler/)
|
||||
# Nueral Architecture Search
|
||||
|
||||
We propose A Gradient-based neural architecture search approach using Differentiable Architecture Sampler (GDAS). Please find details in [our paper](https://github.com/D-X-Y/GDAS/blob/master/data/GDAS.pdf).
|
||||
This project contains the following neural architecture search algorithms, implemented in PyTorch.
|
||||
|
||||
<img src="data/GDAS.png" width="520">
|
||||
Figure-1. We utilize a DAG to represent the search space of a neural cell. Different operations (colored arrows) transform one node (square) to its intermediate features (little circles). Meanwhile, each node is the sum of the intermediate features transformed from the previous nodes. As indicated by the solid connections, the neural cell in the proposed GDAS is a sampled sub-graph of this DAG. Specifically, among the intermediate features between every two nodes, GDAS samples one feature in a differentiable way.
|
||||
- Network Pruning via Transformable Architecture Search
|
||||
- One-Shot Neural Architecture Search via Self-Evaluated Template Network
|
||||
- Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
|
||||
### Requirements
|
||||
- PyTorch 1.0.1
|
||||
- Python 3.6
|
||||
- opencv
|
||||
|
||||
## Requirements and Preparation
|
||||
|
||||
Please install `PyTorch>=1.0.1`, `Python>=3.6`, and `opencv`.
|
||||
|
||||
The CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
|
||||
Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Driver](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`.
|
||||
|
||||
|
||||
## Network Pruning via Transformable Architecture Search
|
||||
|
||||
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
||||
If you do not have `ILSVRC2012` data, pleasee comment L12 in `./scripts/prepare.sh`.
|
||||
|
||||
Search the depth configuration of ResNet:
|
||||
```
|
||||
conda install pytorch torchvision cuda100 -c pytorch
|
||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-depth-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
||||
```
|
||||
|
||||
### Usages
|
||||
|
||||
Train the searched CNN on CIFAR
|
||||
Search the width configuration of ResNet:
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_FG cifar10 cut
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_F1 cifar10 cut
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-cnn/train-cifar.sh GDAS_V1 cifar100 cut
|
||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-width-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
||||
```
|
||||
|
||||
Train the searched CNN on ImageNet
|
||||
Search for both depth and width configuration of ResNet:
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts-cnn/train-imagenet.sh GDAS_F1 52 14 B128 -1
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts-cnn/train-imagenet.sh GDAS_V1 50 14 B256 -1
|
||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-cifar.sh cifar10 ResNet56 CIFARX 0.47 -1
|
||||
```
|
||||
|
||||
Evaluate a trained CNN model
|
||||
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
||||
|
||||
|
||||
## One-Shot Neural Architecture Search via Self-Evaluated Template Network
|
||||
|
||||
Train the searched SETN-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/cifar.python --checkpoint ${checkpoint-path}
|
||||
CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/ILSVRC2012 --checkpoint ${checkpoint-path}
|
||||
CUDA_VISIBLE_DEVICES=0 python ./exps-cnn/evaluate.py --data_path $TORCH_HOME/ILSVRC2012 --checkpoint GDAS-V1-C50-N14-ImageNet.pth
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
||||
```
|
||||
|
||||
Train the searched RNN
|
||||
Searching codes come soon!
|
||||
|
||||
|
||||
## Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
|
||||
The old version is located in `others/GDAS`.
|
||||
|
||||
Train the searched GDAS-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh DARTS_V1
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh DARTS_V2
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-PTB.sh GDAS
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V1
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh DARTS_V2
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-rnn/train-WT2.sh GDAS
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 GDAS_V1 96 -1
|
||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
||||
```
|
||||
|
||||
### Training Logs
|
||||
You can find some training logs in [`./data/logs/`](https://github.com/D-X-Y/GDAS/tree/master/data/logs).
|
||||
You can also find some pre-trained models in [Google Driver](https://drive.google.com/open?id=1Ofhc49xC1PLIX4O708gJZ1ugzz4td_RJ).
|
||||
Searching codes come soon!
|
||||
|
||||
### Experimental Results
|
||||
<img src="data/imagenet-results.png" width="700">
|
||||
Figure-2. Top-1 and top-5 errors on ImageNet.
|
||||
|
||||
### Correction
|
||||
|
||||
The Gumbel-softmax tempurature during searching should decrease from 10 to 0.1.
|
||||
|
||||
### Citation
|
||||
If you find that this project (GDAS) helps your research, please cite the paper:
|
||||
# Citation
|
||||
If you find that this project helps your research, please consider citing some of the following papers:
|
||||
```
|
||||
@inproceedings{dong2019tas,
|
||||
title = {Network Pruning via Transformable Architecture Search},
|
||||
author = {Dong, Xuanyi and Yang, Yi},
|
||||
booktitle = {Neural Information Processing Systems (NeurIPS)},
|
||||
year = {2019}
|
||||
}
|
||||
@inproceedings{dong2019one,
|
||||
title = {One-Shot Neural Architecture Search via Self-Evaluated Template Network},
|
||||
author = {Dong, Xuanyi and Yang, Yi},
|
||||
booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
|
||||
year = {2019}
|
||||
}
|
||||
@inproceedings{dong2019search,
|
||||
title={Searching for A Robust Neural Architecture in Four GPU Hours},
|
||||
author={Dong, Xuanyi and Yang, Yi},
|
||||
|
Reference in New Issue
Block a user