first commit
This commit is contained in:
58
MobileNetV3/evaluation/evaluator.py
Normal file
58
MobileNetV3/evaluation/evaluator.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import networkx as nx
|
||||
from .structure_evaluator import mmd_eval
|
||||
from .gin_evaluator import nn_based_eval
|
||||
from torch_geometric.utils import to_networkx
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import dgl
|
||||
|
||||
|
||||
def get_stats_eval(config):
|
||||
|
||||
if config.eval.mmd_distance.lower() == 'rbf':
|
||||
method = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'),
|
||||
('spectral', 1., 'argmax')]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def eval_stats_fn(test_dataset, pred_graph_list):
|
||||
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
|
||||
sub_pred_G = []
|
||||
if config.eval.max_subgraph:
|
||||
for G in pred_G:
|
||||
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
|
||||
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
||||
sub_pred_G += [CGs[0]]
|
||||
pred_G = sub_pred_G
|
||||
|
||||
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
|
||||
for i in range(len(test_dataset))]
|
||||
results = mmd_eval(test_G, pred_G, method)
|
||||
return results
|
||||
|
||||
return eval_stats_fn
|
||||
|
||||
|
||||
def get_nn_eval(config):
|
||||
|
||||
if hasattr(config.eval, "N_gin"):
|
||||
N_gin = config.eval.N_gin
|
||||
else:
|
||||
N_gin = 10
|
||||
|
||||
def nn_eval_fn(test_dataset, pred_graph_list):
|
||||
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
|
||||
sub_pred_G = []
|
||||
if config.eval.max_subgraph:
|
||||
for G in pred_G:
|
||||
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
|
||||
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
|
||||
sub_pred_G += [CGs[0]]
|
||||
pred_G = sub_pred_G
|
||||
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
|
||||
for i in range(len(test_dataset))]
|
||||
|
||||
results = nn_based_eval(test_G, pred_G, N_gin)
|
||||
return results
|
||||
|
||||
return nn_eval_fn
|
Reference in New Issue
Block a user