update NAS-Bench-102
This commit is contained in:
@@ -16,7 +16,8 @@ Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||
|
||||
The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
||||
You can move it to anywhere you want and send its path to our API for initialization.
|
||||
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file.
|
||||
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
||||
- v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
||||
|
||||
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
||||
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or training models by yourself, you need these data.
|
||||
@@ -108,8 +109,12 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss
|
||||
`NASBench102API` is the topest level api. Please see the following usages:
|
||||
```
|
||||
from nas_102_api import NASBench102API as API
|
||||
api = API('NAS-Bench-102-v1_0-e61699.pth')
|
||||
api = API('NAS-Bench-102-v1_0-e61699.pth') # This will load all the information of NAS-Bench-102 except the trained weights
|
||||
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth')) # The same as the above line while I usually save NAS-Bench-102-v1_0-e61699.pth in ~/.torch/.
|
||||
api.show(-1) # show info of all architectures
|
||||
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-102-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
|
||||
|
||||
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
|
||||
```
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user