Add int search space
This commit is contained in:
@@ -24,11 +24,17 @@ from nats_bench import create
|
||||
def show_imagenet_16_120(dataset_dir=None):
|
||||
if dataset_dir is None:
|
||||
torch_home_dir = (
|
||||
os.environ["TORCH_HOME"] if "TORCH_HOME" in os.environ else os.path.join(os.environ["HOME"], ".torch")
|
||||
os.environ["TORCH_HOME"]
|
||||
if "TORCH_HOME" in os.environ
|
||||
else os.path.join(os.environ["HOME"], ".torch")
|
||||
)
|
||||
dataset_dir = os.path.join(torch_home_dir, "cifar.python", "ImageNet16")
|
||||
train_data, valid_data, xshape, class_num = get_datasets("ImageNet16-120", dataset_dir, -1)
|
||||
split_info = load_config("configs/nas-benchmark/ImageNet16-120-split.txt", None, None)
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
"ImageNet16-120", dataset_dir, -1
|
||||
)
|
||||
split_info = load_config(
|
||||
"configs/nas-benchmark/ImageNet16-120-split.txt", None, None
|
||||
)
|
||||
print("=" * 10 + " ImageNet-16-120 " + "=" * 10)
|
||||
print("Training Data: {:}".format(train_data))
|
||||
print("Evaluation Data: {:}".format(valid_data))
|
||||
@@ -45,8 +51,12 @@ if __name__ == "__main__":
|
||||
test_acc_200e = []
|
||||
for index in range(10000):
|
||||
info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12")
|
||||
valid_acc_12e.append(info["valid-accuracy"]) # the validation accuracy after training the model by 12 epochs
|
||||
test_acc_12e.append(info["test-accuracy"]) # the test accuracy after training the model by 12 epochs
|
||||
valid_acc_12e.append(
|
||||
info["valid-accuracy"]
|
||||
) # the validation accuracy after training the model by 12 epochs
|
||||
test_acc_12e.append(
|
||||
info["test-accuracy"]
|
||||
) # the test accuracy after training the model by 12 epochs
|
||||
info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200")
|
||||
test_acc_200e.append(
|
||||
info["test-accuracy"]
|
||||
|
Reference in New Issue
Block a user