add the idea of guidance
This commit is contained in:
@@ -8,6 +8,7 @@ import os
|
||||
import os.path as osp
|
||||
import pathlib
|
||||
import json
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -49,6 +50,9 @@ op_type = {
|
||||
'none': 5,
|
||||
'output': 6,
|
||||
}
|
||||
|
||||
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
|
||||
|
||||
class DataModule(AbstractDataModule):
|
||||
def __init__(self, cfg):
|
||||
self.datadir = cfg.dataset.datadir
|
||||
@@ -676,6 +680,52 @@ class Dataset(InMemoryDataset):
|
||||
|
||||
data_list = []
|
||||
len_data = len(self.api)
|
||||
def check_valid_graph(nodes, edges):
|
||||
if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
|
||||
return False
|
||||
if nodes[0] != 'input' or nodes[-1] != 'output':
|
||||
return False
|
||||
for i in range(0, len(nodes)):
|
||||
if edges[i][i] == 1:
|
||||
return False
|
||||
for i in range(1, len(nodes) - 1):
|
||||
if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
|
||||
return False
|
||||
for i in range(0, len(nodes)):
|
||||
for j in range(i, len(nodes)):
|
||||
if edges[i, j] == 1 and nodes[j] == 'input':
|
||||
return False
|
||||
for i in range(0, len(nodes)):
|
||||
for j in range(i, len(nodes)):
|
||||
if edges[i, j] == 1 and nodes[i] == 'output':
|
||||
return False
|
||||
flag = 0
|
||||
for i in range(0,len(nodes)):
|
||||
if edges[i,-1] == 1:
|
||||
flag = 1
|
||||
break
|
||||
if flag == 0: return False
|
||||
return True
|
||||
|
||||
def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8,random_ratio=0.5):
|
||||
nasbench_201_node_num = 8
|
||||
# random.seed(random_seed)
|
||||
nodes_num = random.randint(min_nodes, max_nodes)
|
||||
# print(f'arch_str: {arch_str}, \nmax_nodes: {max_nodes}, min_nodes: {min_nodes}, nodes_num: {nodes_num},random_seed: {random_seed},random_ratio: {random_ratio}')
|
||||
add_num = nodes_num - nasbench_201_node_num
|
||||
# ori_nodes, ori_edges = parse_architecture_string(arch_str)
|
||||
add_nodes = [op for op in random.choices(num_to_op[1:-1], k=add_num)]
|
||||
# print(add_nodes)
|
||||
nodes = ori_nodes[:-1] + add_nodes + ['output']
|
||||
edges = np.zeros((nodes_num , nodes_num))
|
||||
edges[:6, :6] = ori_edges[:6, :6]
|
||||
edges[0:8, -1] = ori_edges[0:8 , -1]
|
||||
for i in range(0, nodes_num):
|
||||
for j in range(max(7,i + 1), nodes_num):
|
||||
rand = random.random()
|
||||
if rand < random_ratio:
|
||||
edges[i, j] = 1
|
||||
return nodes, edges
|
||||
|
||||
def graph_to_graph_data(graph):
|
||||
ops = graph[1]
|
||||
@@ -746,6 +796,9 @@ class Dataset(InMemoryDataset):
|
||||
})
|
||||
data = graph_to_graph_data((adj_matrix, ops))
|
||||
data_list.append(data)
|
||||
|
||||
# new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ops, ori_edges=adj_matrix, max_nodes=12, min_nodes=8, random_ratio=0.5)
|
||||
# data_list.append(graph_to_graph_data((new_adj, new_ops)))
|
||||
pbar.update(1)
|
||||
|
||||
for graph in graph_list:
|
||||
|
Reference in New Issue
Block a user