blob: 9b94de5fcdb14e43d1423579eeb03a6232900880 [file] [log] [blame]
from collections import defaultdict
from pycaffe2 import utils
import sys
import subprocess
try:
import pydot
except ImportError:
print ('Cannot import pydot, which is required for drawing a network. This '
'can usually be installed in python with "pip install pydot". Also, '
'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
'can usually be installed with "sudo apt-get install graphviz".')
print ('net_drawer will not run correctly. Please install the correct '
'dependencies.')
pydot = None
from caffe2.proto import caffe2_pb2
from google.protobuf import text_format
OP_STYLE = {'shape': 'box', 'color': '#0F9D58', 'style': 'filled',
'fontcolor': '#FFFFFF'}
BLOB_STYLE = {'shape': 'octagon'}
def GetPydotGraph(operators, name, rankdir='LR'):
graph = pydot.Dot(name, rankdir=rankdir)
pydot_nodes = {}
pydot_node_counts = defaultdict(int)
node_id = 0
for op_id, op in enumerate(operators):
if op.name:
op_node = pydot.Node(
'%s/%s (op#%d)' % (op.name, op.type, op_id), **OP_STYLE)
else:
op_node = pydot.Node(
'%s (op#%d)' % (op.type, op_id), **OP_STYLE)
graph.add_node(op_node)
# print 'Op: %s' % op.name
# print 'inputs: %s' % str(op.input)
# print 'outputs: %s' % str(op.output)
for input_name in op.input:
if input_name not in pydot_nodes:
input_node = pydot.Node(
input_name + str(pydot_node_counts[input_name]),
label=input_name, **BLOB_STYLE)
pydot_nodes[input_name] = input_node
else:
input_node = pydot_nodes[input_name]
graph.add_node(input_node)
graph.add_edge(pydot.Edge(input_node, op_node))
for output_name in op.output:
if output_name in pydot_nodes:
# we are overwriting an existing blob. need to updat the count.
pydot_node_counts[output_name] += 1
output_node = pydot.Node(
output_name + str(pydot_node_counts[output_name]),
label=output_name, **BLOB_STYLE)
pydot_nodes[output_name] = output_node
graph.add_node(output_node)
graph.add_edge(pydot.Edge(op_node, output_node))
return graph
def GetPydotGraphMinimal(operators, name, rankdir='LR'):
"""Different from GetPydotGraph, hide all blob nodes and only show op nodes.
"""
graph = pydot.Dot(name, rankdir=rankdir)
pydot_nodes = {}
blob_parents = {}
pydot_node_counts = defaultdict(int)
node_id = 0
for op_id, op in enumerate(operators):
if op.name:
op_node = pydot.Node(
'%s/%s (op#%d)' % (op.name, op.type, op_id), **OP_STYLE)
else:
op_node = pydot.Node(
'%s (op#%d)' % (op.type, op_id), **OP_STYLE)
graph.add_node(op_node)
for input_name in op.input:
if input_name in blob_parents:
graph.add_edge(pydot.Edge(blob_parents[input_name], op_node))
for output_name in op.output:
blob_parents[output_name] = op_node
return graph
def GetOperatorMapForPlan(plan_def):
graphs = {}
for net_id, net in enumerate(plan_def.networks):
if net.HasField('name'):
graphs[plan_def.name + "_" + net.name] = net.operators
else:
graphs[plan_def.name + "_network_%d" % net_id] = net.operators
return graphs
def main():
with open(sys.argv[1], 'r') as fid:
content = fid.read()
graphs = utils.GetContentFromProtoString(
content,{
caffe2_pb2.PlanDef: lambda x: GetOperatorMapForPlan(x),
caffe2_pb2.NetDef: lambda x: {x.name: x.operators},
})
for key, operators in graphs.iteritems():
graph = GetPydotGraph(operators, key)
filename = graph.get_name() + '.dot'
graph.write(filename, format='raw')
pdf_filename = filename[:-3] + 'pdf'
with open(pdf_filename, 'w') as fid:
try:
subprocess.call(['dot', '-Tpdf', filename], stdout=fid)
except OSError:
print ('pydot requires graphviz to convert dot files to pdf: in ubuntu '
'this can usually be installed with "sudo apt-get install '
'graphviz". We have generated the .dot file but will not '
'generate pdf file for now due to missing graphviz binaries.')
if __name__ == '__main__':
main()