change the device
This commit is contained in:
@@ -10,6 +10,21 @@ from src.utils.utilities import *
|
||||
from src.metrics.swap import SWAP
|
||||
from src.datasets.utilities import get_datasets
|
||||
from src.search_space.networks import *
|
||||
import time
|
||||
|
||||
# NASBench-201
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
# xautodl
|
||||
from xautodl.models import get_cell_based_tiny_net
|
||||
|
||||
# initalize nasbench-201
|
||||
nas_201_path = 'datasets/NAS-Bench-201-v1_1-096897.pth'
|
||||
print(f'Loading NAS-Bench-201 from {nas_201_path}')
|
||||
start_time = time.time()
|
||||
api = API(nas_201_path)
|
||||
end_time = time.time()
|
||||
print(f'Loaded NAS-Bench-201 in {end_time - start_time:.2f} seconds')
|
||||
|
||||
# Settings for console outputs
|
||||
import warnings
|
||||
@@ -31,7 +46,7 @@ if __name__ == "__main__":
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
|
||||
# arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
|
||||
|
||||
train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
|
||||
@@ -39,24 +54,46 @@ if __name__ == "__main__":
|
||||
inputs, _ = next(loader)
|
||||
|
||||
results = []
|
||||
|
||||
nasbench_len = 15625
|
||||
|
||||
for index, i in arch_info.iterrows():
|
||||
print(f'Evaluating network: {index}')
|
||||
# for index, i in arch_info.iterrows():
|
||||
for i in range(nasbench_len):
|
||||
# print(f'Evaluating network: {index}')
|
||||
print(f'Evaluating network: {i}')
|
||||
|
||||
network = Network(3, 10, 1, eval(i.genotype))
|
||||
config = api.get_net_config(i, 'cifar10')
|
||||
network = get_cell_based_tiny_net(config)
|
||||
nas_results = api.query_by_index(i, 'cifar10')
|
||||
acc = nas_results[111].get_eval('ori-test')
|
||||
|
||||
print(type(network))
|
||||
start_time = time.time()
|
||||
|
||||
# network = Network(3, 10, 1, eval(i.genotype))
|
||||
network = network.to(device)
|
||||
|
||||
end_time = time.time()
|
||||
print(f'Loaded network in {end_time - start_time:.2f} seconds')
|
||||
|
||||
print(f'initiliazing SWAP')
|
||||
swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)
|
||||
|
||||
swap_score = []
|
||||
|
||||
for _ in range(args.repeats):
|
||||
print(f'Calculating SWAP score')
|
||||
start_time = time.time()
|
||||
for i in range(args.repeats):
|
||||
print(f'Iteration: {i+1}/{args.repeats}', end='\r')
|
||||
network = network.apply(network_weight_gaussian_init)
|
||||
swap.reinit()
|
||||
swap_score.append(swap.forward())
|
||||
swap.clear()
|
||||
end_time = time.time()
|
||||
print(f'Average SWAP score: {np.mean(swap_score)}')
|
||||
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
|
||||
|
||||
results.append([np.mean(swap_score), i.valid_acc])
|
||||
results.append([np.mean(swap_score), acc])
|
||||
|
||||
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc'])
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user