Create NATS
This commit is contained in:
@@ -27,7 +27,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
||||
from utils import count_parameters_in_MB, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_201_api import NASBench301API as API
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
# Ad-hoc for TuNAS
|
||||
@@ -176,7 +176,7 @@ def main(xargs):
|
||||
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
||||
logger.log('search-space : {:}'.format(search_space))
|
||||
if bool(xargs.use_api):
|
||||
api = API(verbose=False)
|
||||
api = create(None, 'size', verbose=False)
|
||||
else:
|
||||
api = None
|
||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||
@@ -291,7 +291,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||
dirname = '{:}-affine{:}_BN{:}'.format(args.algo, args.affine, args.track_running_stats)
|
||||
dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
|
||||
if args.overwite_epochs is not None:
|
||||
dirname = dirname + '-E{:}'.format(args.overwite_epochs)
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)
|
||||
|
Reference in New Issue
Block a user