support get_net_config for NAS-Bench-201
This commit is contained in:
@@ -72,7 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
|
||||
api.show(index)
|
||||
```
|
||||
|
||||
5. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
||||
5. Create the network from api:
|
||||
```
|
||||
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
|
||||
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
|
||||
network = get_cell_based_tiny_net(config) # create the network from configurration
|
||||
print(network) # show the structure of this architecture
|
||||
```
|
||||
If you want to load the trained weights of this created network, you need to use `api.get_net_param(123, ...)` to obtain the weights and then load it to the network.
|
||||
|
||||
6. For other usages, please see `lib/nas_201_api/api.py`. We provide some usage information in the comments for the corresponding functions. If what you want is not provided, please feel free to open an issue for discussion, and I am happy to answer any questions regarding NAS-Bench-201.
|
||||
|
||||
|
||||
### Detailed Instruction
|
||||
|
Reference in New Issue
Block a user