Add more algorithms
This commit is contained in:
69
others/GDAS/exps-cnn/vis-arch.py
Normal file
69
others/GDAS/exps-cnn/vis-arch.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from graphviz import Digraph
|
||||
|
||||
parser = argparse.ArgumentParser("Visualize the Networks")
|
||||
parser.add_argument('--checkpoint', type=str, help='The path to the checkpoint.')
|
||||
parser.add_argument('--save_dir', type=str, help='The directory to save the network plot.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def plot(genotype, filename):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='20', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
assert len(genotype) % 2 == 0
|
||||
steps = len(genotype) // 2
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for i in range(steps):
|
||||
for k in [2*i, 2*i + 1]:
|
||||
op, j, weight = genotype[k]
|
||||
if j == 0:
|
||||
u = "c_{k-2}"
|
||||
elif j == 1:
|
||||
u = "c_{k-1}"
|
||||
else:
|
||||
u = str(j-2)
|
||||
v = str(i)
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
checkpoint = args.checkpoint
|
||||
assert os.path.isfile(checkpoint), 'Invalid path for checkpoint : {:}'.format(checkpoint)
|
||||
checkpoint = torch.load( checkpoint, map_location='cpu' )
|
||||
genotypes = checkpoint['genotypes']
|
||||
save_dir = Path(args.save_dir)
|
||||
subs = ['normal', 'reduce']
|
||||
for sub in subs:
|
||||
if not (save_dir / sub).exists():
|
||||
(save_dir / sub).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for key, network in genotypes.items():
|
||||
save_path = str(save_dir / 'normal' / 'epoch-{:03d}'.format( int(key) ))
|
||||
print('save into {:}'.format(save_path))
|
||||
plot(network.normal, save_path)
|
||||
|
||||
save_path = str(save_dir / 'reduce' / 'epoch-{:03d}'.format( int(key) ))
|
||||
print('save into {:}'.format(save_path))
|
||||
plot(network.reduce, save_path)
|
Reference in New Issue
Block a user