Add get_torch_home func for NATS-Bench

This commit is contained in:
D-X-Y
2020-12-01 22:25:23 +08:00
parent 8afb62ad2e
commit 46b92e37e2
7 changed files with 294 additions and 10 deletions

View File

@@ -9,6 +9,7 @@
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
@@ -119,10 +120,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size:
model = Model()
model.arch = random_arch()
if use_proxy:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
else:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
model.accuracy, _, _, total_cost = api.simulate_train_eval(
model.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
@@ -146,7 +145,8 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
child.accuracy, _, _, total_cost = api.simulate_train_eval(
child.arch, dataset, hp='12' if use_proxy else api.full_train_epochs)
# Append the info
population.append(child)
history.append((child.accuracy, child.arch))