update vis
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
# python ./exps/vis/test.py
|
||||
import os, sys
|
||||
import os, sys, random
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
def test_nas_api():
|
||||
@@ -23,5 +24,35 @@ def test_nas_api():
|
||||
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True))
|
||||
print(archRes.query('cifar10-valid', 777))
|
||||
|
||||
|
||||
OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
|
||||
COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']
|
||||
|
||||
def plot(filename):
|
||||
g = Digraph(
|
||||
format='png',
|
||||
edge_attr=dict(fontsize='20', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
|
||||
steps = 5
|
||||
for i in range(0, steps):
|
||||
if i == 0:
|
||||
g.node(str(i), fillcolor='darkseagreen2')
|
||||
elif i+1 == steps:
|
||||
g.node(str(i), fillcolor='palegoldenrod')
|
||||
else: g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for i in range(1, steps):
|
||||
for xin in range(i):
|
||||
op_i = random.randint(0, len(OPS)-1)
|
||||
#g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
|
||||
g.edge(str(xin), str(i), label=OPS[op_i], color=COLORS[op_i], fillcolor=COLORS[op_i])
|
||||
#import pdb; pdb.set_trace()
|
||||
g.render(filename, cleanup=True, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_nas_api()
|
||||
for i in range(200): plot('{:04d}'.format(i))
|
||||
|
Reference in New Issue
Block a user