update print and output json statements
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
### packages for visualization
|
||||
from analysis.rdkit_functions import compute_molecular_metrics
|
||||
from analysis.rdkit_functions import compute_graph_metrics
|
||||
from mini_moses.metrics.metrics import compute_intermediate_statistics
|
||||
from metrics.property_metric import TaskModel
|
||||
|
||||
@@ -49,8 +50,8 @@ class SamplingGraphMetrics(nn.Module):
|
||||
|
||||
self.task_evaluator = {
|
||||
'meta_taskname': dataset_infos.task,
|
||||
'sas': None,
|
||||
'scs': None
|
||||
# 'sas': None,
|
||||
# 'scs': None
|
||||
}
|
||||
|
||||
for cur_task in dataset_infos.task.split("-")[:]:
|
||||
@@ -62,13 +63,14 @@ class SamplingGraphMetrics(nn.Module):
|
||||
self.task_evaluator[cur_task] = evaluator
|
||||
|
||||
def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
|
||||
test = True
|
||||
if isinstance(targets, list):
|
||||
targets_cat = torch.cat(targets, dim=0)
|
||||
targets_np = targets_cat.detach().cpu().numpy()
|
||||
else:
|
||||
targets_np = targets.detach().cpu().numpy()
|
||||
|
||||
unique_graphs, all_graphs, all_graphs, targets_log = compute_molecular_metrics(
|
||||
unique_graphs, all_graphs, all_metrics, targets_log = compute_graph_metrics(
|
||||
graphs,
|
||||
targets_np,
|
||||
self.train_graphs,
|
||||
@@ -77,6 +79,22 @@ class SamplingGraphMetrics(nn.Module):
|
||||
self.task_evaluator,
|
||||
self.compute_config,
|
||||
)
|
||||
print(f"all graphs: {all_graphs}")
|
||||
print(f"all graphs[0]: {all_graphs[0]}")
|
||||
tmp_graphs = all_graphs.copy()
|
||||
str_graphs = []
|
||||
for graph in tmp_graphs:
|
||||
node_types = graph[0]
|
||||
edge_types = graph[1]
|
||||
node_str = " ".join([str(node) for node in node_types])
|
||||
edge_str_list = []
|
||||
for i in range(len(node_types)):
|
||||
for j in range(len(node_types)):
|
||||
edge_str_list.append(str(edge_types[i][j]))
|
||||
edge_str_list.append("/n")
|
||||
edge_str = " ".join(edge_str_list)
|
||||
str_graphs.append(f"nodes: {node_str} /n edges: /n{edge_str}")
|
||||
|
||||
|
||||
if test:
|
||||
file_name = "final_graphs.txt"
|
||||
@@ -88,7 +106,7 @@ class SamplingGraphMetrics(nn.Module):
|
||||
|
||||
all_tasks_str = "graph, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
|
||||
fp.write(all_tasks_str + "\n")
|
||||
for i, graph in enumerate(all_graphs):
|
||||
for i, graph in enumerate(str_graphs):
|
||||
if targets_log is not None:
|
||||
all_result_str = f"{graph}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
|
||||
fp.write(all_result_str + "\n")
|
||||
@@ -107,7 +125,7 @@ class SamplingGraphMetrics(nn.Module):
|
||||
textfile.write(graph + "\n")
|
||||
textfile.close()
|
||||
|
||||
all_logs = all_graphs
|
||||
all_logs = all_metrics
|
||||
if test:
|
||||
all_logs["log_name"] = "test"
|
||||
else:
|
||||
@@ -116,7 +134,7 @@ class SamplingGraphMetrics(nn.Module):
|
||||
)
|
||||
|
||||
result_to_csv("output.csv", all_logs)
|
||||
return all_graphs
|
||||
return str_graphs
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user