update README
This commit is contained in:
@@ -26,9 +26,10 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default).
|
||||
|
||||
1. Creating an API instance from a file:
|
||||
```
|
||||
from nas_102_api import NASBench102API
|
||||
api = NASBench102API('$path_to_meta_nas_bench_file')
|
||||
api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')
|
||||
from nas_102_api import NASBench102API as API
|
||||
api = API('$path_to_meta_nas_bench_file')
|
||||
api = API('NAS-Bench-102-v1_0-e61699.pth')
|
||||
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
|
||||
```
|
||||
|
||||
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
|
||||
@@ -45,12 +46,12 @@ api.show(1)
|
||||
api.show(2)
|
||||
|
||||
# show the mean loss and accuracy of an architecture
|
||||
info = api.query_meta_info_by_index(1)
|
||||
res_metrics = info.get_metrics('cifar10', 'train')
|
||||
cost_metrics = info.get_comput_costs('cifar100')
|
||||
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
|
||||
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
|
||||
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
|
||||
|
||||
# get the detailed information
|
||||
results = api.query_by_index(1, 'cifar100')
|
||||
results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
|
||||
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
|
||||
print ('Latency : {:}'.format(results[0].get_latency()))
|
||||
print ('Train Info : {:}'.format(results[0].get_train()))
|
||||
|
||||
Reference in New Issue
Block a user