Create NATS
This commit is contained in:
@@ -21,9 +21,9 @@ import matplotlib.ticker as ticker
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from nas_201_api import NASBench201API, NASBench301API
|
||||
from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
def visualize_info(api, vis_save_dir, indicator):
|
||||
@@ -391,11 +391,11 @@ if __name__ == '__main__':
|
||||
to_save_dir = Path(args.save_dir)
|
||||
|
||||
datasets = ['cifar10', 'cifar100', 'ImageNet16-120']
|
||||
api201 = NASBench201API(None, verbose=True)
|
||||
api201 = create(None, 'tss', verbose=True)
|
||||
for xdata in datasets:
|
||||
visualize_tss_info(api201, xdata, to_save_dir)
|
||||
|
||||
api301 = NASBench301API(None, verbose=True)
|
||||
api301 = create(None, 'size', verbose=True)
|
||||
for xdata in datasets:
|
||||
visualize_sss_info(api301, xdata, to_save_dir)
|
||||
|
||||
|
Reference in New Issue
Block a user