Caffe2 - Python API
A deep learning, cross-platform ML framework
net_drawer.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package net_drawer
17 # Module caffe2.python.net_drawer
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 import argparse
23 import json
24 import logging
25 from collections import defaultdict
26 from caffe2.python import utils
27 from future.utils import viewitems
28 
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is optional at import time: when it is missing the module still
# loads with ``pydot`` bound to None, but code paths that build pydot objects
# cannot work.
try:
    import pydot
except ImportError:
    pydot = None

if pydot is None:
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
46 
47 from caffe2.proto import caffe2_pb2
48 
# pydot node attributes for operator nodes: filled boxes in green (#0F9D58)
# with white (#FFFFFF) text.
OP_STYLE = {
    'shape': 'box',
    'color': '#0F9D58',
    'style': 'filled',
    'fontcolor': '#FFFFFF',
}
# pydot node attributes for blob (tensor) nodes.
BLOB_STYLE = {
    'shape': 'octagon',
}
56 
57 
def _rectify_operator_and_name(operators_or_net, name):
    """Return the ``(operators, name)`` pair used to build a pydot graph.

    Args:
        operators_or_net: a ``caffe2_pb2.NetDef``, an object with a
            ``Proto()`` method returning a NetDef, or any other value, which
            is treated as the operator sequence itself.
        name: explicit graph name, or None to derive one.

    Returns:
        ``(operators, name)``. When ``name`` is None it becomes the NetDef's
        ``name`` field, or "unnamed" for a bare operator sequence.

    Raises:
        RuntimeError: if ``Proto()`` returns something other than a NetDef.
    """
    net_def = None
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net_def = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net_def = operators_or_net.Proto()
        if not isinstance(net_def, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net_def)))

    if net_def is None:
        # A plain operator sequence carries no name of its own.
        return operators_or_net, ("unnamed" if name is None else name)
    return net_def.op, (net_def.name if name is None else name)
77 
78 
79 def _escape_label(name):
80  # json.dumps is poor man's escaping
81  return json.dumps(name)
82 
83 
def GetOpNodeProducer(append_output, **kwargs):
    """Return a callable ``(op, op_id) -> pydot.Node`` for operator nodes.

    The node name is ``"<op.name>/<op.type> (op#<op_id>)"``, or
    ``"<op.type> (op#<op_id>)"`` when the op has no name.

    Args:
        append_output: if true, each of the op's output blob names is
            appended to the node name on its own line.
        **kwargs: extra attributes forwarded to ``pydot.Node``
            (e.g. ``OP_STYLE``).
    """
    def ReallyGetOpNode(op, op_id):
        prefix = '%s/%s' % (op.name, op.type) if op.name else op.type
        lines = ['%s (op#%d)' % (prefix, op_id)]
        if append_output:
            lines.extend(op.output)
        return pydot.Node('\n'.join(lines), **kwargs)
    return ReallyGetOpNode
95 
96 
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    node_producer=None
):
    """Build a pydot graph showing both operator nodes and blob nodes.

    Args:
        operators_or_net: a caffe2_pb2.NetDef, an object exposing ``Proto()``
            that returns a NetDef, or an iterable of operator protos.
        name: graph name; defaults to the net's name, or "unnamed" for a
            bare operator sequence.
        rankdir: pydot ``rankdir`` layout attribute (e.g. 'LR', 'TB').
        node_producer: callable ``(op, op_id) -> pydot.Node`` used for
            operator nodes; defaults to ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        A ``pydot.Dot`` with an edge from each input blob to the op that
        consumes it, and from each op to each of its output blobs. When an op
        writes a blob that already has a node, a new node is created for that
        blob, so successive versions of an overwritten blob stay distinct.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Most recent pydot node for each blob name.
    pydot_nodes = {}
    # How many times each blob has been overwritten; used as a version
    # suffix in the node id so that each version gets its own node.
    # NOTE(review): ids are name + str(count), so e.g. blob "x1" at version 0
    # and blob "x" at version 10 both map to "x10". Left unchanged because
    # these ids appear in the emitted .dot output.
    pydot_node_counts = defaultdict(int)

    def _new_blob_node(blob_name):
        # Create, register and add the node for the current version of blob.
        node = pydot.Node(
            _escape_label(blob_name + str(pydot_node_counts[blob_name])),
            label=_escape_label(blob_name),
            **BLOB_STYLE
        )
        pydot_nodes[blob_name] = node
        graph.add_node(node)
        return node

    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)
        for input_name in op.input:
            input_node = pydot_nodes.get(input_name)
            if input_node is None:
                # Not produced by any earlier op in this list. Add the node
                # once here. Re-adding it on every later use only emitted
                # duplicate node statements.
                input_node = _new_blob_node(input_name)
            graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if output_name in pydot_nodes:
                # We are overwriting an existing blob; bump its version count.
                pydot_node_counts[output_name] += 1
            output_node = _new_blob_node(output_name)
            graph.add_edge(pydot.Edge(op_node, output_node))
    return graph
142 
143 
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    node_producer=None,
):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    An edge a -> b is drawn when op b reads a blob most recently written by
    an earlier op a. Each such pair gets one edge, even if b reads several
    blobs written by a.

    If minimal_dependency is set as well, for each op, we will only draw the
    edges to the minimal necessary ancestors. For example, if op c depends on
    op a and b, and op b depends on a, then only the edge b->c will be drawn
    because a->c will be implied.

    Args:
        operators_or_net: a caffe2_pb2.NetDef, an object exposing ``Proto()``
            that returns a NetDef, or an iterable of operator protos.
        name: graph name; defaults to the net's name, or "unnamed".
        rankdir: pydot ``rankdir`` layout attribute.
        minimal_dependency: drop edges already implied transitively.
        node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        A ``pydot.Dot`` containing only operator nodes.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # blob_parents maps each blob name to the op node that last wrote it.
    blob_parents = {}
    # op_ancestry maps each op node to the set of all its transitive ancestors.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)
        # Distinct producer ops of this op's inputs, in first-use order.
        # Without de-duplication, an op reading two blobs written by the
        # same op would get two parallel edges from it.
        parents = []
        seen = set()
        for input_name in op.input:
            parent = blob_parents.get(input_name)
            if parent is None or parent in seen:
                continue
            seen.add(parent)
            parents.append(parent)
        # Ancestry = direct parents plus everything they descend from.
        ancestry = op_ancestry[op_node]
        ancestry.update(parents)
        for parent in parents:
            ancestry.update(op_ancestry[parent])
        for parent in parents:
            # In minimal mode, skip a parent that is already an ancestor of
            # another parent: its edge is implied transitively.
            if minimal_dependency and any(
                parent in op_ancestry[other] for other in parents
            ):
                continue
            graph.add_edge(pydot.Edge(parent, op_node))
        # Update blob_parents to reflect that this op created the blobs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
193 
194 
def GetOperatorMapForPlan(plan_def):
    """Map each network in ``plan_def`` to its list of operators.

    Keys are ``"<plan name>_<net name>"`` when the net has its ``name`` field
    set, and ``"<plan name>_network_<index>"`` otherwise.
    """
    def _key(index, net):
        if net.HasField('name'):
            return plan_def.name + "_" + net.name
        return plan_def.name + "_network_%d" % index

    return {
        _key(index, net): net.op
        for index, net in enumerate(plan_def.network)
    }
203 
204 
205 def _draw_nets(nets, g):
206  nodes = []
207  for i, net in enumerate(nets):
208  nodes.append(pydot.Node(_escape_label(net)))
209  g.add_node(nodes[-1])
210  if i > 0:
211  g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
212  return nodes
213 
214 
215 def _draw_steps(steps, g, skip_step_edges=False): # noqa
216  kMaxParallelSteps = 3
217 
218  def get_label():
219  label = [step.name + '\n']
220  if step.report_net:
221  label.append('Reporter: {}'.format(step.report_net))
222  if step.should_stop_blob:
223  label.append('Stopper: {}'.format(step.should_stop_blob))
224  if step.concurrent_substeps:
225  label.append('Concurrent')
226  if step.only_once:
227  label.append('Once')
228  return '\n'.join(label)
229 
230  def substep_edge(start, end):
231  return pydot.Edge(start, end, arrowhead='dot', style='dashed')
232 
233  nodes = []
234  for i, step in enumerate(steps):
235  parallel = step.concurrent_substeps
236 
237  nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
238  g.add_node(nodes[-1])
239 
240  if i > 0 and not skip_step_edges:
241  g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
242 
243  if step.network:
244  sub_nodes = _draw_nets(step.network, g)
245  elif step.substep:
246  if parallel:
247  sub_nodes = _draw_steps(
248  step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
249  else:
250  sub_nodes = _draw_steps(step.substep, g)
251  else:
252  raise ValueError('invalid step')
253 
254  if parallel:
255  for sn in sub_nodes:
256  g.add_edge(substep_edge(nodes[-1], sn))
257  if len(step.substep) > kMaxParallelSteps:
258  ellipsis = pydot.Node('{} more steps'.format(
259  len(step.substep) - kMaxParallelSteps), **OP_STYLE)
260  g.add_node(ellipsis)
261  g.add_edge(substep_edge(nodes[-1], ellipsis))
262  else:
263  g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
264 
265  return nodes
266 
267 
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Build a pydot graph of a PlanDef's execution steps.

    Args:
        plan_def: the PlanDef proto whose ``execution_step`` list is drawn.
        name: graph name. Defaults to ``plan_def.name``; previously
            ``None`` itself was passed to ``pydot.Dot``.
        rankdir: pydot ``rankdir`` layout attribute; top-to-bottom by default.

    Returns:
        A ``pydot.Dot`` containing one node per step and per referenced net.
    """
    if name is None:
        # Mirror _rectify_operator_and_name, which defaults the graph name to
        # the NetDef's own name.
        name = plan_def.name
    graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, graph)
    return graph
272 
273 
def GetGraphInJson(operators_or_net, output_filepath):
    """Write the op/blob dataflow graph of a net to a JSON file.

    The file contains ``{"nodes": [...], "edges": [...]}``:
    - op nodes are ``{"id", "label", "op_id", "type": "op"}``, with label
      ``name/type``, or just ``type`` for unnamed ops;
    - blob nodes are ``{"id", "label", "type": "blob"}``;
    - each node's ``id`` is its index in ``nodes``;
    - edges are ``{"source", "target"}`` node-id pairs, running from each
      input blob to the op that consumes it and from each op to each of its
      output blobs.

    Writing to a blob that already has a node creates a fresh node for the
    new version, so an in-place update (read x, write x) does not form a
    cycle.

    Args:
        operators_or_net: a caffe2_pb2.NetDef, an object exposing ``Proto()``
            returning one, or an iterable of operator protos.
        output_filepath: path of the JSON file to write (truncated if present).
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    # (blob_name, version) -> index into ``nodes``. A tuple key cannot collide
    # the way the concatenated string name + str(version) could: blob "x1" at
    # version 0 and blob "x" at version 10 were both "x10", which merged two
    # distinct blobs into one node.
    blob_node_ids = {}
    # Number of times each blob has been overwritten so far.
    blob_versions = defaultdict(int)
    nodes = []
    edges = []

    def _blob_node_id(blob_name):
        # Node id for the current version of blob_name, creating it if needed.
        key = (blob_name, blob_versions[blob_name])
        if key not in blob_node_ids:
            blob_node_ids[key] = len(nodes)
            nodes.append({
                'id': len(nodes),
                'label': blob_name,
                'type': 'blob'
            })
        return blob_node_ids[key]

    for op_id, op in enumerate(operators):
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op.name + '/' + op.type if op.name else op.type,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            edges.append({
                'source': _blob_node_id(input_name),
                'target': op_node_id
            })
        for output_name in op.output:
            if (output_name, blob_versions[output_name]) in blob_node_ids:
                # We are overwriting an existing blob: start a new version.
                blob_versions[output_name] += 1
            edges.append({
                'source': op_node_id,
                'target': _blob_node_id(output_name)
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
330 
331 
332 # A dummy minimal PNG image used by GetGraphPngSafe as a
333 # placeholder when rendering fail to run.
334 _DummyPngImage = (
335  b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
336  b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
337  b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
338 
339 
def GetGraphPngSafe(func, *args, **kwargs):
    """Render ``func(*args, **kwargs)`` to PNG bytes without ever raising.

    ``func`` (e.g. GetPydotGraph) must return a ``pydot.Dot``. If calling it,
    type-checking its result, or rendering fails for any reason, the error
    is logged and a placeholder 1x1 PNG (``_DummyPngImage``) is returned
    instead of an exception being thrown.
    """
    try:
        graph = func(*args, **kwargs)
        if isinstance(graph, pydot.Dot):
            return graph.create_png()
        raise ValueError("func is expected to return pydot.Dot")
    except Exception as e:
        # Deliberately best-effort: callers get an image either way.
        logger.error("Failed to draw graph: {}".format(e))
        return _DummyPngImage
353 
354 
def main():
    """Command-line entry point: draw every net found in ``--input``.

    The input is parsed as a PlanDef (each of its networks is drawn) or a
    NetDef. For each resulting graph this writes
    ``<output_prefix><graph name>.dot`` and then tries to render a matching
    ``.pdf``. If PDF rendering fails, a hint about installing graphviz is
    printed and the ``.dot`` file is kept.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input", type=str, required=True,
        help="The input protobuf file.")
    parser.add_argument(
        "--output_prefix", type=str, default="",
        help="The prefix to be added to the output filename.")
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization.")
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency.")
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph.")
    args = parser.parse_args()

    with open(args.input, 'r') as fid:
        content = fid.read()
    # Each handler turns the parsed proto into {graph name: operators}.
    graphs = utils.GetContentFromProtoString(
        content,
        {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda net_def: {net_def.name: net_def.op},
        }
    )

    for key, operators in viewitems(graphs):
        node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=node_producer)

        dot_filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(dot_filename, format='raw')
        # Swap the trailing "dot" extension for "pdf".
        pdf_filename = dot_filename[:-len('dot')] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            # PDF rendering needs the graphviz binaries; keep the .dot file
            # and explain how to get the PDF rather than failing.
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )


if __name__ == '__main__':
    main()