update
This commit is contained in:
@@ -223,20 +223,6 @@ def main():
|
||||
else:
|
||||
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
|
||||
network.requires_feature = False
|
||||
|
||||
if args.proj_crit == 'comb':
|
||||
synflow = predictive.find_measures(network,
|
||||
train_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['synflow'])
|
||||
var = predictive.find_measures(network,
|
||||
train_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=['var'])
|
||||
comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
|
||||
measures = {'comb': comb}
|
||||
else:
|
||||
measures = predictive.find_measures(network,
|
||||
train_queue,
|
||||
|
Reference in New Issue
Block a user