Caffe2 - Python API
A deep-learning, cross-platform ML framework.
Module: net_printer.py
1 ## @package net_printer
2 # Module caffe2.python.net_printer
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef
9 from caffe2.python.checkpoint import Job
10 from caffe2.python.core import Net, ExecutionStep, Plan
11 from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
12 from collections import defaultdict
13 from contextlib import contextmanager
14 from copy import copy
15 from future.utils import viewkeys
16 from itertools import chain
17 from six import binary_type, text_type
18 
19 
class Visitor(object):
    """Base class for type-dispatched visitors.

    Subclasses accumulate a ``visitors`` mapping from object type to a
    handler function via the ``register`` decorator; calling an instance
    dispatches on the exact ``type(obj)``.
    """

    @classmethod
    def register(cls, Type):
        """Return a decorator that registers a handler for ``Type``."""
        # Lazily create the registry on first registration; afterwards,
        # refuse duplicate registrations for the same type.
        if hasattr(cls, 'visitors'):
            assert Type not in cls.visitors, \
                '{} already registered!'.format(Type)
        else:
            cls.visitors = {}

        def _register(func):
            cls.visitors[Type] = func
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        """Dispatch ``obj`` to the handler registered for its type.

        ``None`` is silently ignored; an unregistered type raises TypeError.
        """
        if obj is None:
            return

        handler = self.__class__.visitors.get(type(obj))
        if handler is None:
            raise TypeError('%s: unsupported object type: %s' % (
                self.__class__.__name__, type(obj)))
        return handler(self, obj, *args, **kwargs)
46 
47 
    # Blob-name prefixes exempt from the "must be defined" check in
    # need_blob(); presumably these blobs are created by the distributed
    # runtime outside the steps under analysis -- TODO confirm.
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}
50 
51  def __init__(self):
52  self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
53  self.workspace_ctx = []
54 
    @property
    def workspace(self):
        # The innermost (currently active) workspace on the context stack.
        return self.workspace_ctx[-1]
58 
59  @contextmanager
60  def set_workspace(self, node=None, ws=None, do_copy=False):
61  if ws is not None:
62  ws = ws
63  elif node is not None:
64  ws = self.workspaces[str(node)]
65  else:
66  ws = self.workspace
67  if do_copy:
68  ws = copy(ws)
69  self.workspace_ctx.append(ws)
70  yield ws
71  del self.workspace_ctx[-1]
72 
    def define_blob(self, blob):
        # Count one more definition of `blob` in the active workspace.
        self.workspace[blob] += 1
75 
76  def need_blob(self, blob):
77  if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
78  return
79  assert blob in self.workspace, 'Blob undefined: %s' % blob
80 
81 
@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    """Record the blobs an operator reads (inputs) and writes (outputs)."""
    for blob in op.input:
        analyzer.need_blob(blob)
    for blob in op.output:
        analyzer.define_blob(blob)
88 
89 
@Analyzer.register(Net)
def analyze_net(analyzer, net):
    """Visit every operator of the net, in program order."""
    for op in net.Proto().op:
        analyzer(op)
94 
95 
@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    """Analyze blob definitions/uses within an execution step.

    Handles report nets, nested substeps and, for concurrent substeps,
    verifies that no two parallel substeps create a blob of the same name.
    """
    proto = step.Proto()
    with analyzer.set_workspace(do_copy=proto.create_workspace):
        if proto.report_net:
            # The report net runs in its own scope; its blobs don't leak out.
            with analyzer.set_workspace(do_copy=True):
                analyzer(step.get_net(proto.report_net))
        all_new_blobs = set()
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            # Concurrent substeps each see a private copy of the workspace.
            with analyzer.set_workspace(
                    do_copy=proto.concurrent_substeps) as ws_in:
                analyzer(substep)
                if proto.should_stop_blob:
                    analyzer.need_blob(proto.should_stop_blob)
            if proto.concurrent_substeps:
                # Blobs created inside this substep (present in ws_in but not
                # in the parent workspace) must be unique across the parallel
                # substeps, otherwise they would race on the same name.
                new_blobs = set(viewkeys(ws_in)) - set(viewkeys(analyzer.workspace))
                assert len(all_new_blobs & new_blobs) == 0, (
                    'Error: Blobs created by multiple parallel steps: %s' % (
                        ', '.join(all_new_blobs & new_blobs)))
                all_new_blobs |= new_blobs
    # Propagate the union of concurrently created blobs to the parent scope.
    for x in all_new_blobs:
        analyzer.define_blob(x)
119 
120 
@Analyzer.register(Task)
def analyze_task(analyzer, task):
    """Analyze a task's execution step and validate its serialized size.

    Plans are shipped as serialized protobufs, which are capped at 64Mb,
    so fail early with a clear message when the task exceeds that limit.
    """
    # check that our plan protobuf is not too large (limit of 64Mb)
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    # Fix: the original mixed a '{}' placeholder with the '%' operator,
    # which raises TypeError instead of producing the intended message.
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64Mb, but this task has {} bytes.'.format(proto_len))

    # Tasks with a private workspace must not leak blob definitions.
    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)
135 
136 
@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    """Analyze each task of the group within its own node's workspace."""
    for task_obj in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task_obj.node):
            analyzer(task_obj)
142 
143 
@Analyzer.register(Job)
def analyze_job(analyzer, job):
    """Analyze a job's init group followed by its epoch group."""
    for group in (job.init_group, job.epoch_group):
        analyzer(group)
148 
149 
def analyze(obj):
    """
    Given a Job, visits all the execution steps making sure that:
    - no undefined blobs will be found during execution
    - no blob with same name is defined in concurrent steps
    """
    analyzer = Analyzer()
    analyzer(obj)
157 
158 
class Text(object):
    """Builds an indented, python-like textual representation, line by line.

    ``context`` opens a ``with <text>:`` scope; lines added inside it are
    indented one level deeper, and an empty scope is closed with ``pass`` so
    the output always reads like valid pseudo-python.
    """

    def __init__(self):
        self._indent = 0
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        """Open an indented ``with text:`` scope; a no-op when text is None."""
        opened = text is not None
        if opened:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if opened:
            # An empty scope still has to be syntactically valid.
            if not self._lines_in_context[-1]:
                self.add('pass')
            self._indent -= 4
            self._lines_in_context.pop()

    def add(self, text):
        """Append one line at the current indentation level."""
        self._lines_in_context[-1] += 1
        self.lines.append(' ' * self._indent + text)

    def __str__(self):
        return '\n'.join(self.lines)
184 
185 
    def __init__(self, factor_prefixes=False, c2_syntax=True):
        # NOTE(review): Printer presumably inherits from both Visitor and
        # Text (the class header is not visible in this chunk) -- confirm.
        # With MRO (Printer, Visitor, Text, ...), super(Visitor, self)
        # resolves to Text and super(Text, self) to object, so the two
        # calls below initialize the Text state and then object.
        super(Visitor, self).__init__()
        super(Text, self).__init__()
        # Factor common blob-name prefixes when printing, if True.
        self.factor_prefixes = factor_prefixes
        # Emit caffe2 python API calls instead of plain pseudo-code, if True.
        self.c2_syntax = c2_syntax
        # Name of the net currently being printed in c2_syntax mode.
        self.c2_net_name = None
193 
194 
195 def _sanitize_str(s):
196  if isinstance(s, text_type):
197  sanitized = s
198  elif isinstance(s, binary_type):
199  sanitized = s.decode('ascii', errors='ignore')
200  else:
201  sanitized = str(s)
202  if len(sanitized) < 64:
203  return "'%s'" % sanitized
204  else:
205  return "'%s'" % sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64)
206 
207 
208 def _arg_val(arg):
209  if arg.HasField('f'):
210  return str(arg.f)
211  if arg.HasField('i'):
212  return str(arg.i)
213  if arg.HasField('s'):
214  return _sanitize_str(arg.s)
215  if arg.floats:
216  return str(list(arg.floats))
217  if arg.ints:
218  return str(list(arg.ints))
219  if arg.strings:
220  return str([_sanitize_str(s) for s in arg.strings])
221  return '[]'
222 
223 
def commonprefix(m):
    "Given a list of strings, returns the longest common prefix"
    if not m:
        return ''
    # The lexicographic min and max differ from each other at least as
    # early as any other pair, so comparing just those two suffices.
    lo, hi = min(m), max(m)
    for idx, ch in enumerate(lo):
        if ch != hi[idx]:
            return lo[:idx]
    return lo
234 
235 
def format_value(val):
    """Quote each element of a list; anything else falls back to str()."""
    if not isinstance(val, list):
        return str(val)
    return '[%s]' % ', '.join("'%s'" % str(item) for item in val)
241 
242 
def factor_prefix(vals, do_it):
    """Join formatted vals; when do_it, factor out their common prefix
    as 'prefix[suffix1, suffix2, ...]'."""
    formatted = [format_value(v) for v in vals]
    prefix = commonprefix(formatted) if do_it and len(formatted) > 1 else ''
    joined = ', '.join(v[len(prefix):] for v in formatted)
    return '%s[%s]' % (prefix, joined) if prefix else joined
248 
249 
def call(op, inputs=None, outputs=None, factor_prefixes=False):
    """Format a pseudo-python call: 'outputs = op(positional, key=value)'.

    Tuple elements of `inputs` become keyword arguments; the rest are
    positional. Empty pieces are dropped from the argument list.
    """
    if inputs:
        positional = []
        keyword = []
        for arg in inputs:
            (keyword if isinstance(arg, tuple) else positional).append(arg)
        parts = [factor_prefix(positional, factor_prefixes)]
        parts.extend('%s=%s' % kv for kv in keyword)
        inputs = ', '.join(part for part in parts if part)
    else:
        inputs = ''
    expr = '%s(%s)' % (op, inputs)
    if not outputs:
        return expr
    return '%s = %s' % (factor_prefix(outputs, factor_prefixes), expr)
267 
268 
def format_device_option(dev_opt):
    """Render a DeviceOption proto, or None when it carries no information."""
    if not dev_opt:
        return None
    if not (dev_opt.device_type or dev_opt.device_id or dev_opt.node_name):
        return None
    return call(
        'DeviceOption',
        [dev_opt.device_type, dev_opt.device_id, "'%s'" % dev_opt.node_name])
276 
277 
@Printer.register(OperatorDef)
def print_op(text, op):
    """Print one operator, either as a caffe2 API call (c2 syntax) or as
    pseudo-code, then recursively print any net-valued arguments."""
    args = [(a.name, _arg_val(a)) for a in op.arg]
    dev_opt_txt = format_device_option(op.device_option)
    if dev_opt_txt:
        # Only shown when the device option carries actual information.
        args.append(('device_option', dev_opt_txt))

    if text.c2_net_name:
        # c2 syntax: net.OpType([inputs], [outputs], arg=..., ...)
        text.add(call(
            text.c2_net_name + '.' + op.type,
            [list(op.input), list(op.output)] + args))
    else:
        # pseudo-code: outputs = OpType(inputs, arg=..., ...)
        text.add(call(
            op.type,
            list(op.input) + args,
            op.output,
            factor_prefixes=text.factor_prefixes))
    # Net-valued arguments ('n' field) get their own indented scope.
    for arg in op.arg:
        if arg.HasField('n'):
            with text.context('arg: %s' % arg.name):
                text(arg.n)
299 
300 
@Printer.register(NetDef)
def print_net_def(text, net_def):
    """Print a NetDef, optionally wrapped in core.Net c2 syntax."""
    use_c2 = text.c2_syntax
    if use_c2:
        text.add(call('core.Net', ["'%s'" % net_def.name], [net_def.name]))
        text.c2_net_name = net_def.name
    else:
        text.add('# net: %s' % net_def.name)
    for op in net_def.op:
        text(op)
    if use_c2:
        text.c2_net_name = None
312 
313 
@Printer.register(Net)
def print_net(text, net):
    # A Net is printed via its underlying NetDef protobuf.
    text(net.Proto())
317 
318 
def _get_step_context(step):
    """Return (context line, print substeps in individual scopes?) for a
    step, based on its control-flow attributes."""
    proto = step.Proto()
    if proto.should_stop_blob:
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    if proto.num_concurrent_instances > 1:
        ctx = call(
            'parallel',
            [('num_instances', proto.num_concurrent_instances)])
        return ctx, len(step.Substeps()) > 1
    if proto.concurrent_substeps and len(step.Substeps()) > 1:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    return None, False
336 
337 
@Printer.register(ExecutionStep)
def print_step(text, step):
    """Print an execution step: its control-flow context, the report net,
    its substeps and nets, and any stop condition."""
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            sub_proto = (
                substep.Proto() if isinstance(substep, ExecutionStep) else None)
            if sub_proto is not None and sub_proto.run_every_ms:
                # Periodically-run substeps are shown as reporters.
                substep_ctx = call(
                    'reporter',
                    [str(substep), ('interval_ms', sub_proto.run_every_ms)])
            elif do_substep:
                # Parallel execution: one labeled scope per substep.
                title = (
                    'workspace'
                    if sub_proto is not None and sub_proto.create_workspace else
                    'step')
                substep_ctx = call(title, [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
                if proto.should_stop_blob:
                    text.add(call('yield stop_if', [proto.should_stop_blob]))
366 
367 
def _print_task_output(x):
    """Render a TaskOutput as 'Output[name, ...]'."""
    assert isinstance(x, TaskOutput)
    return 'Output[%s]' % ', '.join(str(name) for name in x.names)
371 
372 
@Printer.register(Task)
def print_task(text, task):
    """Print a task as a 'Task(node=..., name=..., outputs=...)' scope."""
    outs = ', '.join(_print_task_output(o) for o in task.outputs())
    params = [('node', task.node), ('name', task.name), ('outputs', outs)]
    with text.context(call('Task', params)):
        text(task.get_step())
379 
380 
@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    """Print every task of the group inside a TaskGroup scope."""
    with text.context(header or call('TaskGroup')):
        for task_obj in tg.tasks_by_node().tasks():
            text(task_obj)
386 
387 
@Printer.register(Job)
def print_job(text, job):
    """Print the phases of a job in execution order: init, epoch, stop
    conditions, download and exit."""
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_conditions'):
        for condition in job.stop_conditions:
            text.add(_print_task_output(condition))
    text(job.download_group, 'Job.current().download_group')
    text(job.exit_group, 'Job.current().exit_group')
397 
398 
def to_string(obj, **kwargs):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with detailed description of the execution steps.

    Keyword arguments are forwarded to the Printer constructor.
    """
    p = Printer(**kwargs)
    p(obj)
    return str(p)
407 
408 
def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.

    Args:
        net: the Net to instrument.
    Returns:
        A new Net that logs each operator's printed form, then runs it.
    """
    assert isinstance(net, Net)
    diagnosed = Net(str(net))
    for op in net.Proto().op:
        text = Text()
        # Fix: print_op takes (text, op); the original passed them swapped,
        # which crashes at runtime (the op proto has no Text methods).
        print_op(text, op)
        diagnosed.LogInfo(str(text))
        diagnosed.Proto().op.extend([op])
    return diagnosed
Related modules: caffe2.python.workspace, caffe2.python.context.