Add get_torch_home func for NATS-Bench
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0
|
||||
##################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np, collections
|
||||
@@ -119,10 +120,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
model.arch = random_arch()
|
||||
if use_proxy:
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
|
||||
else:
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(
|
||||
model.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
|
||||
# Append the info
|
||||
population.append(model)
|
||||
history.append((model.accuracy, model.arch))
|
||||
@@ -146,7 +145,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
|
||||
# Create the child model and store it.
|
||||
child = Model()
|
||||
child.arch = mutate_arch(parent.arch)
|
||||
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
|
||||
child.accuracy, _, _, total_cost = api.simulate_train_eval(
|
||||
child.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
|
||||
# Append the info
|
||||
population.append(child)
|
||||
history.append((child.accuracy, child.arch))
|
||||
|
Reference in New Issue
Block a user