To answer issue #119
This commit is contained in:
@@ -24,6 +24,9 @@
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
####
|
||||
# The following scripts are added in 20 Mar 2022
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
|
||||
######################################################################################
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
@@ -166,6 +169,8 @@ def search_func(
|
||||
network.set_cal_mode("dynamic", sampled_arch)
|
||||
elif algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif algo == "random":
|
||||
@@ -196,6 +201,8 @@ def search_func(
|
||||
network.set_cal_mode("joint")
|
||||
elif algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif algo == "random":
|
||||
@@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo):
|
||||
archs, valid_accs = network.return_topK(n_samples, True), []
|
||||
elif algo == "setn":
|
||||
archs, valid_accs = network.return_topK(n_samples, False), []
|
||||
elif algo.startswith("darts") or algo == "gdas":
|
||||
elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
|
||||
arch = network.genotype
|
||||
archs, valid_accs = [arch], []
|
||||
elif algo == "enas":
|
||||
@@ -568,7 +575,7 @@ def main(xargs):
|
||||
)
|
||||
|
||||
network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
|
||||
if xargs.algo == "gdas":
|
||||
if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
|
||||
network.set_tau(
|
||||
xargs.tau_max
|
||||
- (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
|
||||
@@ -632,6 +639,8 @@ def main(xargs):
|
||||
network.set_cal_mode("dynamic", genotype)
|
||||
elif xargs.algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif xargs.algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif xargs.algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif xargs.algo == "random":
|
||||
@@ -699,6 +708,8 @@ def main(xargs):
|
||||
network.set_cal_mode("dynamic", genotype)
|
||||
elif xargs.algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif xargs.algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif xargs.algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif xargs.algo == "random":
|
||||
@@ -747,7 +758,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--algo",
|
||||
type=str,
|
||||
choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"],
|
||||
choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
|
||||
help="The search space name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
Reference in New Issue
Block a user