add autodl

mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions


@@ -10,6 +10,21 @@ from src.utils.utilities import *
from src.metrics.swap import SWAP
from src.datasets.utilities import get_datasets
from src.search_space.networks import *
import time
# NASBench-201
from nas_201_api import NASBench201API as API
# xautodl
from xautodl.models import get_cell_based_tiny_net
# initialize nasbench-201
nas_201_path = 'datasets/NAS-Bench-201-v1_1-096897.pth'
print(f'Loading NAS-Bench-201 from {nas_201_path}')
start_time = time.time()
api = API(nas_201_path)
end_time = time.time()
print(f'Loaded NAS-Bench-201 in {end_time - start_time:.2f} seconds')
# Settings for console outputs
import warnings
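For reference, the loading-and-query pattern this hunk introduces can be exercised on its own. The sketch below only uses calls that already appear in the diff (nas_201_api's API, get_net_config, query_by_index and xautodl's get_cell_based_tiny_net) and assumes the same benchmark file path; treat it as an illustration rather than part of the commit.

    import time
    from nas_201_api import NASBench201API as API
    from xautodl.models import get_cell_based_tiny_net

    # Same path as hard-coded in the diff (assumed to exist locally).
    nas_201_path = 'datasets/NAS-Bench-201-v1_1-096897.pth'

    start = time.time()
    api = API(nas_201_path)  # loads metadata for all 15625 architectures
    print(f'Loaded NAS-Bench-201 in {time.time() - start:.2f} seconds')

    # Instantiate architecture 0 as a PyTorch module.
    config = api.get_net_config(0, 'cifar10')
    network = get_cell_based_tiny_net(config)

    # Stored training results for the same architecture, keyed by training seed.
    nas_results = api.query_by_index(0, 'cifar10')
    print(type(network), list(nas_results.keys()))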
@@ -20,8 +35,8 @@ parser = argparse.ArgumentParser()
# general setting
parser.add_argument('--data_path', default="datasets", type=str, nargs='?', help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--device', default="cuda:2", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
parser.add_argument('--seed', default=111, type=int, help='random seed')
parser.add_argument('--device', default="cuda:1", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
@@ -31,7 +46,7 @@ if __name__ == "__main__":
device = torch.device(args.device)
arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
# arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
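The training-free metric only needs a single fixed batch, which the script draws once from the CIFAR-10 loader and then reuses for every architecture. A minimal stand-alone sketch of that step, substituting torchvision's CIFAR10 dataset for the repo's get_datasets helper (which applies its own preprocessing), could look like this:

    import torch
    import torchvision
    import torchvision.transforms as T

    # Stand-in for get_datasets('cifar10', ...); illustration only.
    train_data = torchvision.datasets.CIFAR10(
        root='datasets', train=True, download=True, transform=T.ToTensor())
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=16, num_workers=0, pin_memory=True)

    # One fixed batch, reused across all scored architectures.
    inputs, _ = next(iter(train_loader))
    print(inputs.shape)  # torch.Size([16, 3, 32, 32])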
@@ -39,24 +54,46 @@ if __name__ == "__main__":
inputs, _ = next(loader)
results = []
nasbench_len = 15625
for index, i in arch_info.iterrows():
print(f'Evaluating network: {index}')
# for index, i in arch_info.iterrows():
for i in range(nasbench_len):
# print(f'Evaluating network: {index}')
print(f'Evaluating network: {i}')
network = Network(3, 10, 1, eval(i.genotype))
config = api.get_net_config(i, 'cifar10')
network = get_cell_based_tiny_net(config)
nas_results = api.query_by_index(i, 'cifar10')
acc = nas_results[111].get_eval('ori-test')
print(type(network))
start_time = time.time()
# network = Network(3, 10, 1, eval(i.genotype))
network = network.to(device)
end_time = time.time()
print(f'Loaded network in {end_time - start_time:.2f} seconds')
print(f'Initializing SWAP')
swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)
swap_score = []
for _ in range(args.repeats):
print(f'Calculating SWAP score')
start_time = time.time()
for rep in range(args.repeats):  # renamed from i to avoid shadowing the architecture index
print(f'Iteration: {rep+1}/{args.repeats}', end='\r')
network = network.apply(network_weight_gaussian_init)
swap.reinit()
swap_score.append(swap.forward())
swap.clear()
end_time = time.time()
print(f'Average SWAP score: {np.mean(swap_score)}')
print(f'Elapsed time: {end_time - start_time:.2f} seconds')
results.append([np.mean(swap_score), i.valid_acc])
results.append([np.mean(swap_score), acc])
results = pd.DataFrame(results, columns=['swap_score', 'valid_acc'])
print()
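Once the results DataFrame is filled, the usual way to judge a training-free proxy such as SWAP is rank correlation between its score and the benchmark accuracy. This analysis is not part of the commit; the sketch below assumes valid_acc holds the final test accuracy as a float and uses scipy's correlation routines:

    import pandas as pd
    from scipy import stats

    # 'results' as built above: one [swap_score, valid_acc] row per architecture.
    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc'])

    rho, _ = stats.spearmanr(results['swap_score'], results['valid_acc'])
    tau, _ = stats.kendalltau(results['swap_score'], results['valid_acc'])
    print(f'Spearman rho: {rho:.3f}, Kendall tau: {tau:.3f}')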