add a datsets option to specify the datset you want, add a plot script
This commit is contained in:
@@ -39,6 +39,7 @@ parser.add_argument('--seed', default=0, type=int, help='random seed')
|
||||
parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
|
||||
parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
|
||||
parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
|
||||
parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -48,7 +49,7 @@ if __name__ == "__main__":
|
||||
|
||||
# arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
|
||||
|
||||
train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
|
||||
train_data, _, _ = get_datasets(args.datasets, args.data_path, (args.input_samples, 3, 32, 32), -1)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
|
||||
loader = iter(train_loader)
|
||||
inputs, _ = next(loader)
|
||||
@@ -63,11 +64,11 @@ if __name__ == "__main__":
|
||||
# print(f'Evaluating network: {index}')
|
||||
print(f'Evaluating network: {ind}')
|
||||
|
||||
config = api.get_net_config(ind, 'cifar10')
|
||||
config = api.get_net_config(ind, args.datasets)
|
||||
network = get_cell_based_tiny_net(config)
|
||||
# nas_results = api.query_by_index(i, 'cifar10')
|
||||
# acc = nas_results[111].get_eval('ori-test')
|
||||
nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
|
||||
nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
|
||||
acc = nas_results['test-accuracy']
|
||||
|
||||
# print(type(network))
|
||||
|
||||
Reference in New Issue
Block a user