update README

This commit is contained in:
D-X-Y
2019-12-28 15:42:36 +11:00
parent d791622b63
commit 4c144b7437
6 changed files with 59 additions and 28 deletions

View File

@@ -26,9 +26,10 @@ It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default).
1. Creating an API instance from a file:
```
from nas_102_api import NASBench102API
api = NASBench102API('$path_to_meta_nas_bench_file')
api = NASBench102API('NAS-Bench-102-v1_0-e61699.pth')
from nas_102_api import NASBench102API as API
api = API('$path_to_meta_nas_bench_file')
api = API('NAS-Bench-102-v1_0-e61699.pth')
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
```
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
@@ -45,12 +46,12 @@ api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1)
res_metrics = info.get_metrics('cifar10', 'train')
cost_metrics = info.get_comput_costs('cifar100')
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100')
results = api.query_by_index(1, 'cifar100') # a list of all trials on cifar100
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
print ('Latency : {:}'.format(results[0].get_latency()))
print ('Train Info : {:}'.format(results[0].get_train()))