Create NATS
This commit is contained in:
@@ -24,8 +24,8 @@ from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_201_api import NASBench201API, NASBench301API
|
||||
from models import CellStructure, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
class PolicyTopology(nn.Module):
|
||||
@@ -192,12 +192,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.search_space == 'tss':
|
||||
api = NASBench201API(verbose=False)
|
||||
elif args.search_space == 'sss':
|
||||
api = NASBench301API(verbose=False)
|
||||
else:
|
||||
raise ValueError('Invalid search space : {:}'.format(args.search_space))
|
||||
api = create(None, args.search_space, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
|
||||
print('save-dir : {:}'.format(args.save_dir))
|
||||
|
Reference in New Issue
Block a user