From 551abc31f370ada75b8349a5ab0a13be8c4a405e Mon Sep 17 00:00:00 2001
From: Mhrooz
Date: Wed, 28 Aug 2024 17:11:17 +0200
Subject: [PATCH] add a --datasets option to specify the dataset to use, and
 add a plot script

---
 analyze.py     | 37 +++++++++++++++++++++++++++++++++++++
 correlation.py |  7 ++++---
 2 files changed, 41 insertions(+), 3 deletions(-)
 create mode 100644 analyze.py

diff --git a/analyze.py b/analyze.py
new file mode 100644
index 0000000..d809373
--- /dev/null
+++ b/analyze.py
@@ -0,0 +1,37 @@
+import csv
+
+import matplotlib.pyplot as plt
+import pandas as pd
+from scipy import stats
+
+
+def plot(bins):
+    # 15625 is the number of architectures in the NATS-Bench topology search
+    # space, so each bar shows the fraction falling in that score range.
+    labels = ['0-10k', '10k-20k', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k']
+    fractions = [count / 15625 for count in bins[:7]]
+    plt.bar(labels, fractions)
+    plt.savefig('plot.png')
+
+
+def analyse(filename):
+    bins = [0 for _ in range(10)]  # histogram: one bin per 10k of SWAP score
+    rows = []
+    best_value = -1.0
+    with open(filename) as file:
+        reader = csv.reader(file)
+        next(reader)  # skip the header row
+        for row in reader:
+            score = float(row[0])
+            best_value = max(best_value, score)
+            bins[int(score // 10000)] += 1
+            rows.append((score, float(row[1]), row[2]))
+    results = pd.DataFrame(rows, columns=['swap_score', 'valid_acc', 'index'])
+    print(results['swap_score'].max())
+    print(best_value)
+    plot(bins)
+    return stats.spearmanr(results.swap_score, results.valid_acc)[0]
+
+
+if __name__ == '__main__':
+    print(analyse('output/swap_results.csv'))
diff --git a/correlation.py b/correlation.py
index c45f54b..49adb47 100644
--- a/correlation.py
+++ b/correlation.py
@@ -39,6 +39,7 @@
 parser.add_argument('--seed', default=0, type=int, help='random seed')
 parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
 parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
 parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
+parser.add_argument('--datasets', default='cifar10', type=str, help='dataset to evaluate on (e.g. cifar10, cifar100)')
 
 args = parser.parse_args()
@@ -48,7 +49,7 @@
 if __name__ == "__main__":
 
     # arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
-    train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
+    train_data, _, _ = get_datasets(args.datasets, args.data_path, (args.input_samples, 3, 32, 32), -1)
     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
     loader = iter(train_loader)
     inputs, _ = next(loader)
@@ -63,11 +64,11 @@
 
         # print(f'Evaluating network: {index}')
         print(f'Evaluating network: {ind}')
-        config = api.get_net_config(ind, 'cifar10')
+        config = api.get_net_config(ind, args.datasets)
         network = get_cell_based_tiny_net(config)
         # nas_results = api.query_by_index(i, 'cifar10')
         # acc = nas_results[111].get_eval('ori-test')
-        nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
+        nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
         acc = nas_results['test-accuracy']
         # print(type(network))
 
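
Note (outside the patch): analyse() returns stats.spearmanr(...)[0] because
scipy's spearmanr yields a (correlation, p-value) pair. A minimal sketch with
hypothetical numbers, just to illustrate the indexing:

    from scipy import stats

    swap_scores = [54000.0, 61000.0, 68000.0]  # hypothetical SWAP scores
    valid_accs = [90.4, 92.1, 93.7]            # hypothetical accuracies
    rho, pvalue = stats.spearmanr(swap_scores, valid_accs)
    print(rho)  # 1.0 here, since both lists rank the networks identically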
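
Similarly, a quick sketch (hypothetical, outside the patch) of how the new
--datasets flag behaves once parsed; correlation.py reads it as args.datasets
everywhere 'cifar10' was previously hard-coded:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--datasets', default='cifar10', type=str, help='dataset to evaluate on')
    args = parser.parse_args(['--datasets', 'cifar100'])
    print(args.datasets)  # cifar100 overrides the cifar10 default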