from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import logging
from collections import defaultdict

from future.utils import viewitems

from caffe2.python import utils
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is optional at import time so that the rest of the module (e.g. the
# JSON export) can still be imported; drawing functions need it at call time.
try:
    import pydot
except ImportError:
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
    pydot = None

from caffe2.proto import caffe2_pb2
# Default pydot node attributes for operator nodes (filled boxes, white text).
OP_STYLE = {
    'shape': 'box',
    'color': '#0F9D58',
    'style': 'filled',
    'fontcolor': '#FFFFFF'
}
# Default pydot node attributes for blob nodes.
BLOB_STYLE = {'shape': 'octagon'}
def _rectify_operator_and_name(operators_or_net, name):
    """Gets the operators and name for the pydot graph.

    Args:
        operators_or_net: a ``caffe2_pb2.NetDef``, an object exposing a
            ``Proto()`` method that returns a NetDef, or a plain sequence of
            operator protos.
        name: the graph name to use. When None it is taken from the net, or
            set to "unnamed" for a plain operator sequence.

    Returns:
        A ``(operators, name)`` tuple.

    Raises:
        RuntimeError: if ``Proto()`` returns something other than a NetDef.
    """
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        operators = operators_or_net.op
        if name is None:
            name = operators_or_net.name
    elif hasattr(operators_or_net, 'Proto'):
        net = operators_or_net.Proto()
        if not isinstance(net, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net)))
        operators = net.op
        if name is None:
            name = net.name
    else:
        operators = operators_or_net
        if name is None:
            name = "unnamed"
    return operators, name
64 def _escape_label(name):
66 return json.dumps(name)
def GetOpNodeProducer(append_output, **kwargs):
    """Returns a factory that builds a pydot.Node for an operator.

    Args:
        append_output: if True, each output blob name is appended to the node
            name on its own line.
        **kwargs: extra attributes forwarded to ``pydot.Node`` (e.g.
            OP_STYLE).

    Returns:
        A callable ``(op, op_id) -> pydot.Node``.
    """
    def ReallyGetOpNode(op, op_id):
        # The op index keeps node names unique even when several ops share
        # the same name and type.
        if op.name:
            node_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        else:
            node_name = '%s (op#%d)' % (op.type, op_id)
        if append_output:
            node_name = '\n'.join([node_name] + list(op.output))
        return pydot.Node(node_name, **kwargs)
    return ReallyGetOpNode
def GetBlobNodeProducer(**kwargs):
    """Returns a factory that builds a pydot.Node for a blob.

    The returned callable takes ``(node_name, label)`` and forwards ``kwargs``
    (e.g. BLOB_STYLE) to ``pydot.Node`` as extra node attributes.
    """
    def ReallyGetBlobNode(node_name, label):
        return pydot.Node(node_name, label=label, **kwargs)
    return ReallyGetBlobNode
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    op_node_producer=None,
    blob_node_producer=None
):
    """Builds a pydot graph showing both operator nodes and blob nodes.

    Every operator becomes a node with edges from its input blobs and to its
    output blobs. When an op writes a blob name that already has a node, a new
    node (a new "version" of that blob) is created, so overwrites are visible.

    Args:
        operators_or_net: a NetDef, an object with a ``Proto()`` method, or a
            sequence of operator protos.
        name: graph name; derived from the net when None.
        rankdir: graphviz rank direction, e.g. 'LR' or 'TB'.
        op_node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.
        blob_node_producer: callable ``(node_name, label) -> pydot.Node``;
            defaults to ``GetBlobNodeProducer(**BLOB_STYLE)``.

    Returns:
        The ``pydot.Dot`` graph.
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    if blob_node_producer is None:
        blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Maps each blob name to the node of its most recent version.
    pydot_nodes = {}
    # Number of times each blob name has been overwritten so far; appended to
    # the node name so every version gets a distinct node.
    pydot_node_counts = defaultdict(int)
    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)
        for input_name in op.input:
            if input_name not in pydot_nodes:
                input_node = blob_node_producer(
                    _escape_label(
                        input_name + str(pydot_node_counts[input_name])),
                    label=_escape_label(input_name),
                )
                pydot_nodes[input_name] = input_node
                # Only add a blob node when it is created; re-adding a known
                # node would emit duplicate node statements in the .dot file.
                graph.add_node(input_node)
            else:
                input_node = pydot_nodes[input_name]
            graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if output_name in pydot_nodes:
                # Overwriting an existing blob: bump its version counter so
                # the new version gets its own node.
                pydot_node_counts[output_name] += 1
            output_node = blob_node_producer(
                _escape_label(
                    output_name + str(pydot_node_counts[output_name])),
                label=_escape_label(output_name),
            )
            pydot_nodes[output_name] = output_node
            graph.add_node(output_node)
            graph.add_edge(pydot.Edge(op_node, output_node))
    return graph
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    op_node_producer=None,
):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    If minimal_dependency is set as well, for each op, we will only draw the
    edges to the minimal necessary ancestors. For example, if op c depends on
    op a and b, and op b depends on a, then only the edge b->c will be drawn
    because a->c will be implied.

    Args:
        operators_or_net: a NetDef, an object with a ``Proto()`` method, or a
            sequence of operator protos.
        name: graph name; derived from the net when None.
        rankdir: graphviz rank direction, e.g. 'LR' or 'TB'.
        minimal_dependency: draw only edges that are not implied transitively.
        op_node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        The ``pydot.Dot`` graph.
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # blob_parents maps each blob name to the op node that last wrote it.
    blob_parents = {}
    # op_ancestry maps each op node to the set of all its ancestor op nodes.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)
        # Direct parents are the ops that produced this op's inputs. An op can
        # consume several blobs from the same parent, so de-duplicate (keeping
        # first-seen order) to avoid drawing the same edge more than once.
        parents = []
        seen = set()
        for input_name in op.input:
            parent = blob_parents.get(input_name)
            if parent is not None and parent not in seen:
                seen.add(parent)
                parents.append(parent)
        op_ancestry[op_node].update(parents)
        for node in parents:
            op_ancestry[op_node].update(op_ancestry[node])
        if minimal_dependency:
            # Skip a parent that is already an ancestor of another parent;
            # its edge is implied transitively.
            for node in parents:
                if all(node not in op_ancestry[other_node]
                       for other_node in parents):
                    graph.add_edge(pydot.Edge(node, op_node))
        else:
            for node in parents:
                graph.add_edge(pydot.Edge(node, op_node))
        # This op is now the producer of each of its outputs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
def GetOperatorMapForPlan(plan_def):
    """Maps a graph name to the operators of each network in a PlanDef.

    Networks that have a ``name`` field are keyed ``<plan name>_<net name>``.
    Unnamed networks are keyed ``<plan name>_network_<index>``.
    """
    operator_map = {}
    for net_id, net in enumerate(plan_def.network):
        if net.HasField('name'):
            operator_map[plan_def.name + "_" + net.name] = net.op
        else:
            operator_map[plan_def.name + "_network_%d" % net_id] = net.op
    return operator_map
196 def _draw_nets(nets, g):
198 for i, net
in enumerate(nets):
199 nodes.append(pydot.Node(_escape_label(net)))
200 g.add_node(nodes[-1])
202 g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
206 def _draw_steps(steps, g, skip_step_edges=False):
207 kMaxParallelSteps = 3
210 label = [step.name +
'\n']
212 label.append(
'Reporter: {}'.format(step.report_net))
213 if step.should_stop_blob:
214 label.append(
'Stopper: {}'.format(step.should_stop_blob))
215 if step.concurrent_substeps:
216 label.append(
'Concurrent')
219 return '\n'.join(label)
221 def substep_edge(start, end):
222 return pydot.Edge(start, end, arrowhead=
'dot', style=
'dashed')
225 for i, step
in enumerate(steps):
226 parallel = step.concurrent_substeps
228 nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
229 g.add_node(nodes[-1])
231 if i > 0
and not skip_step_edges:
232 g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
235 sub_nodes = _draw_nets(step.network, g)
238 sub_nodes = _draw_steps(
239 step.substep[:kMaxParallelSteps], g, skip_step_edges=
True)
241 sub_nodes = _draw_steps(step.substep, g)
243 raise ValueError(
'invalid step')
247 g.add_edge(substep_edge(nodes[-1], sn))
248 if len(step.substep) > kMaxParallelSteps:
249 ellipsis = pydot.Node(
'{} more steps'.format(
250 len(step.substep) - kMaxParallelSteps), **OP_STYLE)
252 g.add_edge(substep_edge(nodes[-1], ellipsis))
254 g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Builds and returns a pydot graph of a PlanDef's execution steps."""
    graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, graph)
    return graph
def GetGraphInJson(operators_or_net, output_filepath):
    """Writes the op/blob graph of a net to ``output_filepath`` as JSON.

    The file contains ``{"nodes": [...], "edges": [...]}``. Each node is a
    dict with ``id``, ``label`` and ``type`` ("op" or "blob"); op nodes also
    carry ``op_id``. Each edge is ``{"source": id, "target": id}``. Every
    version of a blob (i.e. after each overwrite) gets its own node.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    # Maps (blob name, version) to the id of that blob version's node. A
    # tuple key avoids collisions between e.g. blob "x" at version 10 and
    # blob "x1" at version 0, which a concatenated string key would merge.
    blob_key_to_node_id = {}
    # Current version (number of overwrites so far) of each blob name.
    node_name_counts = defaultdict(int)
    nodes = []
    edges = []
    for op_id, op in enumerate(operators):
        op_label = op.name + '/' + op.type if op.name else op.type
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op_label,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            key = (input_name, node_name_counts[input_name])
            if key not in blob_key_to_node_id:
                node_id = len(nodes)
                blob_key_to_node_id[key] = node_id
                nodes.append({
                    'id': node_id,
                    'label': input_name,
                    'type': 'blob'
                })
            edges.append({
                'source': blob_key_to_node_id[key],
                'target': op_node_id
            })
        for output_name in op.output:
            if (output_name, node_name_counts[output_name]) in \
                    blob_key_to_node_id:
                # Overwriting an existing blob: start a new version.
                node_name_counts[output_name] += 1
            # The current version has no node yet, so always create one.
            node_id = len(nodes)
            blob_key_to_node_id[
                (output_name, node_name_counts[output_name])] = node_id
            nodes.append({
                'id': node_id,
                'label': output_name,
                'type': 'blob'
            })
            edges.append({
                'source': op_node_id,
                'target': node_id
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
326 b
'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00' 327 b
'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00' 328 b
'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
def GetGraphPngSafe(func, *args, **kwargs):
    """Renders the graph returned by ``func`` to PNG bytes, never raising.

    Invokes `func` (e.g. GetPydotGraph) with ``*args``/``**kwargs``. If
    anything fails, including ``func`` not returning a ``pydot.Dot``, the
    error is logged and the 1x1 placeholder ``_DummyPngImage`` is returned
    instead of raising.
    """
    try:
        graph = func(*args, **kwargs)
        if not isinstance(graph, pydot.Dot):
            raise ValueError("func is expected to return pydot.Dot")
        return graph.create_png()
    except Exception as e:
        # Deliberate best-effort boundary: callers always get image bytes.
        logger.error("Failed to draw graph: %s", e)
        return _DummyPngImage
def main():
    """Command-line entry point: draws each net in a protobuf file.

    Writes ``<output_prefix><graph name>.dot`` for every net and tries to
    convert it to PDF with graphviz, printing a hint if that fails.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input",
        type=str, required=True,
        help="The input protobuf file."
    )
    parser.add_argument(
        "--output_prefix",
        type=str, default="",
        help="The prefix to be added to the output filename."
    )
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization."
    )
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency."
    )
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph."
    )
    args = parser.parse_args()
    with open(args.input, 'r') as fid:
        content = fid.read()
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda x: {x.name: x.op},
        }
    )
    # The producer is stateless, so one instance serves every graph. The
    # drawing functions take it as `op_node_producer` (passing
    # `node_producer=` would raise TypeError).
    op_node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
    for key, operators in viewitems(graphs):
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer)
        filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        # Swap the trailing "dot" extension for "pdf".
        pdf_filename = filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )


if __name__ == '__main__':
    main()