Caffe2 - Python API
A deep learning, cross-platform ML framework
net_drawer.py
1 ## @package net_drawer
2 # Module caffe2.python.net_drawer
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import json
9 import logging
10 from collections import defaultdict
11 from caffe2.python import utils
12 from future.utils import viewitems
13 
# Module-level logger with its level set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is optional at import time: when it cannot be imported the module
# still loads with `pydot` bound to None, and installation hints are both
# logged and printed.
pydot = None
try:
    import pydot
except ImportError:
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
31 
32 from caffe2.proto import caffe2_pb2
33 
# Default pydot node attributes for operators: white text on a filled green
# box. Passed as keyword arguments to pydot.Node by the node producers.
OP_STYLE = {
    'shape': 'box',
    'color': '#0F9D58',
    'style': 'filled',
    'fontcolor': '#FFFFFF',
}
# Default pydot node attributes for blobs.
BLOB_STYLE = {
    'shape': 'octagon',
}
41 
42 
def _rectify_operator_and_name(operators_or_net, name):
    """Resolve the operator sequence and the graph name to draw.

    Args:
        operators_or_net: a caffe2_pb2.NetDef, an object exposing a
            ``Proto()`` method that returns a NetDef, or any other object,
            which is then used as the operator collection as-is.
        name: graph name requested by the caller, or None to derive one.

    Returns:
        An ``(operators, name)`` tuple. When ``name`` is None it falls back
        to the NetDef's own name, or to "unnamed" when no NetDef is involved.

    Raises:
        RuntimeError: if ``Proto()`` returns something that is not a NetDef.
    """
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net_def = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net_def = operators_or_net.Proto()
        if not isinstance(net_def, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net_def)))
    else:
        net_def = None

    if net_def is None:
        # Bare operator collection: nothing to derive a name from.
        return operators_or_net, ("unnamed" if name is None else name)
    return net_def.op, (net_def.name if name is None else name)
62 
63 
def _escape_label(name):
    """Return `name` JSON-encoded, i.e. as a quoted and escaped string.

    JSON encoding is used as a cheap substitute for proper label escaping.
    """
    escaped = json.dumps(name)
    return escaped
67 
68 
def GetOpNodeProducer(append_output, **kwargs):
    """Build a factory that turns an operator into a pydot node.

    Args:
        append_output: if true, every output blob name of the operator is
            appended to the node name, one per line.
        **kwargs: extra attributes forwarded unchanged to ``pydot.Node``.

    Returns:
        A callable ``(op, op_id) -> pydot.Node``. The node name is
        "<op.name>/<op.type> (op#<op_id>)", or "<op.type> (op#<op_id>)" when
        the operator has no name.
    """
    def ReallyGetOpNode(op, op_id):
        title = '%s (op#%d)' % (op.type, op_id)
        if op.name:
            title = '%s/%s' % (op.name, title)
        name_lines = [title]
        if append_output:
            name_lines.extend(op.output)
        return pydot.Node('\n'.join(name_lines), **kwargs)
    return ReallyGetOpNode
80 
81 
def GetBlobNodeProducer(**kwargs):
    """Build a factory that creates pydot nodes for blobs.

    Args:
        **kwargs: extra attributes forwarded unchanged to ``pydot.Node`` for
            every node the factory creates.

    Returns:
        A callable ``(node_name, label) -> pydot.Node``.
    """
    def ReallyGetBlobNode(node_name, label):
        # `label` is passed explicitly, so supplying a `label` entry in
        # kwargs as well raises a TypeError at call time.
        return pydot.Node(node_name, label=label, **kwargs)

    return ReallyGetBlobNode
86 
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    op_node_producer=None,
    blob_node_producer=None
):
    """Build a pydot graph showing both operators and blobs.

    Each operator becomes a node, each blob version becomes a node, and edges
    run from input blobs to the operator and from the operator to its output
    blobs. Every time an already-known blob is written again, a new node is
    created for it, and later readers connect to that newest node.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()``, or a
            collection of operators (see _rectify_operator_and_name).
        name: graph name; derived from the net when None.
        rankdir: value of the graph's ``rankdir`` attribute.
        op_node_producer: callable ``(op, op_id) -> node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.
        blob_node_producer: callable ``(node_name, label=...) -> node``;
            defaults to ``GetBlobNodeProducer(**BLOB_STYLE)``.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    if blob_node_producer is None:
        blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)

    # Blob name -> node of the most recent version of that blob.
    latest_blob_node = {}
    # Blob name -> how many times an existing blob has been overwritten.
    blob_versions = defaultdict(int)

    def make_blob_node(blob_name):
        # The node id is the blob name followed by its version number.
        # NOTE(review): there is no separator between name and version, so
        # e.g. blob "x1" at version 0 and blob "x" at version 10 share the
        # id "x10".
        versioned_name = blob_name + str(blob_versions[blob_name])
        return blob_node_producer(
            _escape_label(versioned_name),
            label=_escape_label(blob_name),
        )

    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)

        for input_name in op.input:
            if input_name not in latest_blob_node:
                latest_blob_node[input_name] = make_blob_node(input_name)
            blob_node = latest_blob_node[input_name]
            graph.add_node(blob_node)
            graph.add_edge(pydot.Edge(blob_node, op_node))

        for output_name in op.output:
            if output_name in latest_blob_node:
                # We are overwriting an existing blob: bump its version so
                # the write gets a node of its own.
                blob_versions[output_name] += 1
            blob_node = make_blob_node(output_name)
            latest_blob_node[output_name] = blob_node
            graph.add_node(blob_node)
            graph.add_edge(pydot.Edge(op_node, blob_node))

    return graph
133 
134 
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    op_node_producer=None,
):
    """Build a pydot graph that shows operator nodes only (no blob nodes).

    An edge is drawn from op A to op B when B reads a blob most recently
    written by A.

    If minimal_dependency is set, only the edges to the minimal necessary
    ancestors are drawn for each op. For example, if op c depends on op a
    and b, and op b depends on a, then only the edge b->c is drawn because
    a->c is implied.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()``, or a
            collection of operators (see _rectify_operator_and_name).
        name: graph name; derived from the net when None.
        rankdir: value of the graph's ``rankdir`` attribute.
        minimal_dependency: drop edges implied by transitive ancestry.
        op_node_producer: callable ``(op, op_id) -> node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)

    # Blob name -> node of the op that most recently produced it.
    producer_of_blob = {}
    # Op node -> set of all op nodes it transitively depends on.
    ancestors_of = defaultdict(set)

    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)

        # Direct parents are the producers of the inputs; inputs that no op
        # has produced so far contribute nothing.
        parents = [
            producer_of_blob[blob_name] for blob_name in op.input
            if blob_name in producer_of_blob
        ]
        ancestors = ancestors_of[op_node]
        ancestors.update(parents)
        for parent in parents:
            ancestors.update(ancestors_of[parent])

        if minimal_dependency:
            # Keep only parents that are not already an ancestor of another
            # parent; those edges are implied transitively.
            edge_sources = [
                parent for parent in parents
                if not any(
                    parent in ancestors_of[other] for other in parents)
            ]
        else:
            edge_sources = parents
        for parent in edge_sources:
            graph.add_edge(pydot.Edge(parent, op_node))

        # From now on this op is the producer of all its outputs.
        producer_of_blob.update(dict.fromkeys(op.output, op_node))

    return graph
184 
185 
def GetOperatorMapForPlan(plan_def):
    """Map a graph name to the operators of every network in a plan.

    The key is "<plan name>_<net name>" when the net has its ``name`` field
    set, and "<plan name>_network_<index>" otherwise.

    Args:
        plan_def: plan whose ``network`` entries are collected.

    Returns:
        A dict from graph name to the net's ``op`` collection.
    """
    def graph_key(net_id, net):
        if net.HasField('name'):
            suffix = net.name
        else:
            suffix = "network_%d" % net_id
        return plan_def.name + "_" + suffix

    return {
        graph_key(net_id, net): net.op
        for net_id, net in enumerate(plan_def.network)
    }
194 
195 
def _draw_nets(nets, g):
    """Add one node per net name to graph `g`, chained in order.

    Each net becomes a node named by its escaped name, and consecutive nets
    are linked by an edge from the earlier to the later one.

    Returns:
        The list of created nodes, in input order.
    """
    nodes = []
    previous = None
    for net in nets:
        current = pydot.Node(_escape_label(net))
        nodes.append(current)
        g.add_node(current)
        if previous is not None:
            g.add_edge(pydot.Edge(previous, current))
        previous = current
    return nodes
204 
205 
def _draw_steps(steps, g, skip_step_edges=False):  # noqa
    """Recursively draw execution steps and their children into graph `g`.

    Every step becomes an OP_STYLE node whose label lists the step name and
    its reporter, stopper, "Concurrent" and "Once" markers where set.
    Consecutive steps are linked unless `skip_step_edges` is true. A step's
    children are either its nets or its substeps:

    * sequential step: one dashed edge to the first child;
    * concurrent step: a dashed edge to every drawn child. Only the first
      three substeps are drawn; the rest are summarized by a single
      "<N> more steps" node.

    Returns:
        The list of nodes created for `steps` (children not included).

    Raises:
        ValueError: if a step has neither a network nor a substep.
    """
    kMaxParallelSteps = 3

    def describe(step):
        lines = [step.name + '\n']
        if step.report_net:
            lines.append('Reporter: {}'.format(step.report_net))
        if step.should_stop_blob:
            lines.append('Stopper: {}'.format(step.should_stop_blob))
        if step.concurrent_substeps:
            lines.append('Concurrent')
        if step.only_once:
            lines.append('Once')
        return '\n'.join(lines)

    def dashed_edge(src, dst):
        return pydot.Edge(src, dst, arrowhead='dot', style='dashed')

    nodes = []
    for index, step in enumerate(steps):
        is_parallel = step.concurrent_substeps

        step_node = pydot.Node(_escape_label(describe(step)), **OP_STYLE)
        nodes.append(step_node)
        g.add_node(step_node)

        if index > 0 and not skip_step_edges:
            g.add_edge(pydot.Edge(nodes[-2], step_node))

        if step.network:
            children = _draw_nets(step.network, g)
        elif step.substep:
            if is_parallel:
                children = _draw_steps(
                    step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
            else:
                children = _draw_steps(step.substep, g)
        else:
            raise ValueError('invalid step')

        if not is_parallel:
            # Sequential children are already chained; link to the first.
            g.add_edge(dashed_edge(step_node, children[0]))
            continue

        for child in children:
            g.add_edge(dashed_edge(step_node, child))
        hidden = len(step.substep) - kMaxParallelSteps
        if hidden > 0:
            more_node = pydot.Node(
                '{} more steps'.format(hidden), **OP_STYLE)
            g.add_node(more_node)
            g.add_edge(dashed_edge(step_node, more_node))

    return nodes
257 
258 
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Build a pydot graph of a plan's execution steps.

    Args:
        plan_def: plan whose ``execution_step`` entries are drawn.
        name: graph name passed to ``pydot.Dot``.
        rankdir: value of the graph's ``rankdir`` attribute.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    plan_graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, plan_graph)
    return plan_graph
263 
264 
def GetGraphInJson(operators_or_net, output_filepath):
    """Write the operator/blob graph of a net to a JSON file.

    The file contains ``{"nodes": [...], "edges": [...]}``. Operator nodes
    have the keys ``id``, ``label``, ``op_id`` and ``type`` ("op"); blob
    nodes have ``id``, ``label`` and ``type`` ("blob"). Edges have ``source``
    and ``target`` node ids, running from input blobs to the operator and
    from the operator to its output blobs.

    Every write to a blob creates a new blob node, and every read connects
    to the newest node of that blob (creating one if the blob has not been
    seen yet).

    Args:
        operators_or_net: a NetDef, an object with ``Proto()``, or a
            collection of operators (see _rectify_operator_and_name).
        output_filepath: path of the JSON file to write.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    nodes = []
    edges = []
    # Blob name -> node id of the most recent version of that blob. Keying
    # on the name itself (rather than on a "name + version" string) avoids
    # ambiguous keys such as blob "x1" version 0 vs. blob "x" version 10.
    latest_blob_node_id = {}

    def add_blob_node(blob_name):
        # Register a new version of the blob and make it the current one.
        node_id = len(nodes)
        nodes.append({
            'id': node_id,
            'label': blob_name,
            'type': 'blob'
        })
        latest_blob_node_id[blob_name] = node_id
        return node_id

    for op_id, op in enumerate(operators):
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op.name + '/' + op.type if op.name else op.type,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            if input_name in latest_blob_node_id:
                blob_node_id = latest_blob_node_id[input_name]
            else:
                blob_node_id = add_blob_node(input_name)
            edges.append({
                'source': blob_node_id,
                'target': op_node_id
            })
        for output_name in op.output:
            # A write always produces a fresh node, even when the blob
            # already exists (it is being overwritten).
            blob_node_id = add_blob_node(output_name)
            edges.append({
                'source': op_node_id,
                'target': blob_node_id
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
321 
322 
# A minimal dummy PNG image that GetGraphPngSafe returns as a placeholder
# when rendering fails. The literal is split per PNG chunk.
_DummyPngImage = (
    # PNG signature
    b'\x89PNG\r\n\x1a\n'
    # IHDR chunk
    b'\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x00\x00\x00\x007n\xf9$'
    # IDAT chunk
    b'\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q'
    # IEND chunk
    b'\x00\x00\x00\x00IEND\xaeB`\x82'
)
329 
330 
def GetGraphPngSafe(func, *args, **kwargs):
    """
    Invokes `func` (e.g. GetPydotGraph) with the given arguments and renders
    the returned graph to PNG bytes. If anything fails, including `func` not
    returning a pydot.Dot, the error is logged and a placeholder image
    (_DummyPngImage) is returned instead of raising an exception.
    """
    try:
        graph = func(*args, **kwargs)
        if isinstance(graph, pydot.Dot):
            return graph.create_png()
        raise ValueError("func is expected to return pydot.Dot")
    except Exception as err:
        logger.error("Failed to draw graph: {}".format(err))
    return _DummyPngImage
344 
345 
def main():
    """Command-line entry point: draw the nets found in a protobuf file.

    Reads the file given by --input, extracts either the networks of a
    PlanDef or a single NetDef, and for each resulting graph writes
    "<output_prefix><graph name>.dot" followed by the matching ".pdf". If
    writing the pdf fails, a hint about installing graphviz is printed and
    the .dot file is kept.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input",
        type=str, required=True,
        help="The input protobuf file."
    )
    parser.add_argument(
        "--output_prefix",
        type=str, default="",
        help="The prefix to be added to the output filename."
    )
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization."
    )
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency."
    )
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph."
    )
    args = parser.parse_args()
    with open(args.input, 'r') as fid:
        content = fid.read()
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: lambda x: GetOperatorMapForPlan(x),
            caffe2_pb2.NetDef: lambda x: {x.name: x.op},
        }
    )
    # The graph builders take the operator node factory as
    # `op_node_producer`; passing it under any other keyword raises a
    # TypeError.
    op_node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
    for key, operators in viewitems(graphs):
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer)
        filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        # Replace the trailing "dot" extension with "pdf".
        pdf_filename = filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            # Best effort: the .dot file has already been written.
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )


if __name__ == '__main__':
    main()