# Caffe2 - Python API
# A deep learning, cross platform ML framework
# net_printer.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package net_printer
17 # Module caffe2.python.net_printer
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef
24 from caffe2.python.checkpoint import Job
25 from caffe2.python.core import Net, ExecutionStep, Plan
26 from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
27 from collections import defaultdict
28 from contextlib import contextmanager
29 from copy import copy
30 from future.utils import viewkeys
31 from itertools import chain
32 from six import binary_type, text_type
33 
34 
class Visitor(object):
    """Type-dispatched visitor: calls the handler registered for the
    visited object's type on the visitor subclass."""

    @classmethod
    def register(cls, Type):
        # Lazily create the registry the first time a handler is
        # registered on this class (hasattr also sees inherited registries).
        if not hasattr(cls, 'visitors'):
            cls.visitors = []

        def _register(func):
            cls.visitors.append((Type, func))
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        # Visiting None is a silent no-op.
        if obj is None:
            return
        # Dispatch to the first handler whose registered type matches.
        handler = next(
            (fn for typ, fn in type(self).visitors if isinstance(obj, typ)),
            None)
        if handler is None:
            raise TypeError('%s: unsupported object type: %s' % (
                type(self).__name__, type(obj)))
        return handler(self, obj, *args, **kwargs)
55 
56 
    # Blob-name prefixes that need_blob() skips: these blobs are assumed to
    # be provided externally (e.g. by distributed-context initialization),
    # so the analyzer does not flag them as undefined.
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}
59 
60  def __init__(self):
61  self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
62  self.workspace_ctx = []
63 
    @property
    def workspace(self):
        # The workspace currently in scope (top of the context stack).
        return self.workspace_ctx[-1]
67 
68  @contextmanager
69  def set_workspace(self, node=None, ws=None, do_copy=False):
70  if ws is not None:
71  ws = ws
72  elif node is not None:
73  ws = self.workspaces[str(node)]
74  else:
75  ws = self.workspace
76  if do_copy:
77  ws = copy(ws)
78  self.workspace_ctx.append(ws)
79  yield ws
80  del self.workspace_ctx[-1]
81 
    def define_blob(self, blob):
        # Record one (more) definition of `blob` in the current workspace.
        self.workspace[blob] += 1
84 
    def need_blob(self, blob):
        # Blobs with ignored prefixes are assumed to be provided externally,
        # so they are never reported as undefined.
        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
            return
        assert blob in self.workspace, 'Blob undefined: %s' % blob
89 
90 
@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    """An operator requires all of its inputs and defines all outputs."""
    for blob in op.input:
        analyzer.need_blob(blob)
    for blob in op.output:
        analyzer.define_blob(blob)
97 
98 
@Analyzer.register(Net)
def analyze_net(analyzer, net):
    """A net is analyzed as the in-order sequence of its operators."""
    for op_def in net.Proto().op:
        analyzer(op_def)
103 
104 
@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    """
    Analyze an ExecutionStep: visit the report net and all substeps/nets,
    tracking workspace creation and blob definitions across concurrent
    substeps.
    """
    proto = step.Proto()
    with analyzer.set_workspace(do_copy=proto.create_workspace):
        if proto.report_net:
            # The report net runs in a throw-away copy of the workspace,
            # so blobs it defines do not leak into the main analysis.
            with analyzer.set_workspace(do_copy=True):
                analyzer(step.get_net(proto.report_net))
        all_new_blobs = set()
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            # When substeps are concurrent, each one is analyzed in a copy
            # of the parent workspace; ws_in holds what the substep saw.
            with analyzer.set_workspace(
                    do_copy=proto.concurrent_substeps) as ws_in:
                analyzer(substep)
                if proto.should_stop_blob:
                    analyzer.need_blob(proto.should_stop_blob)
            if proto.concurrent_substeps:
                # Blobs the substep defined that are not in the parent
                # workspace; parallel substeps must not define the same blob.
                new_blobs = set(viewkeys(ws_in)) - set(viewkeys(analyzer.workspace))
                assert len(all_new_blobs & new_blobs) == 0, (
                    'Error: Blobs created by multiple parallel steps: %s' % (
                        ', '.join(all_new_blobs & new_blobs)))
                all_new_blobs |= new_blobs
        # Once all parallel substeps finish, their new blobs become visible
        # in the enclosing workspace.
        for x in all_new_blobs:
            analyzer.define_blob(x)
128 
129 
@Analyzer.register(Task)
def analyze_task(analyzer, task):
    """Check the task's serialized size limit, then analyze its step."""
    # check that our plan protobuf is not too large (limit of 64Mb)
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    # Fix: the original formatted a '{}' placeholder with the '%' operator,
    # which raises TypeError instead of producing this message.
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64Mb, but this task has {} bytes.'.format(proto_len))

    # Tasks with a private workspace must not leak blob definitions into
    # the global workspace, so analyze them in a copy.
    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)
144 
145 
@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    # Analyze each task inside the workspace of the node it will run on.
    for task in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task.node):
            analyzer(task)
151 
152 
@Analyzer.register(Job)
def analyze_job(analyzer, job):
    # Analyze the init group first so blobs it defines are visible to the
    # epoch group.
    # NOTE(review): download_group/exit_group (printed by print_job) are not
    # analyzed here — confirm this is intentional.
    analyzer(job.init_group)
    analyzer(job.epoch_group)
157 
158 
def analyze(obj):
    """
    Given a Job, visits all the execution steps making sure that:
      - no undefined blobs will be found during execution
      - no blob with same name is defined in concurrent steps
    """
    Analyzer()(obj)
166 
167 
class Text(object):
    """Accumulates indented lines of text, rendered as python-like
    pseudo-code with `with ...:` blocks."""

    def __init__(self):
        self._indent = 0
        # Number of lines emitted in each open context; used to decide
        # whether a closing block needs a `pass` placeholder.
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        """Open a `with <text>:` block for the duration of the `with`.
        A None text opens no block at all."""
        opened = text is not None
        if opened:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if opened:
            if not self._lines_in_context[-1]:
                # Keep the generated pseudo-python well-formed.
                self.add('pass')
            self._indent -= 4
            self._lines_in_context.pop()

    def add(self, text):
        """Append one line at the current indentation level."""
        self._lines_in_context[-1] += 1
        self.lines.append(' ' * self._indent + text)

    def __str__(self):
        return '\n'.join(self.lines)
193 
194 
    def __init__(self, factor_prefixes=False, c2_syntax=True):
        # NOTE(review): the enclosing class appears to inherit from both
        # Visitor and Text; with that MRO, super(Visitor, self).__init__()
        # resolves to Text.__init__ and super(Text, self).__init__() to
        # object.__init__ — confirm against the class definition.
        super(Visitor, self).__init__()
        super(Text, self).__init__()
        # When True, common blob-name prefixes are factored out in output.
        self.factor_prefixes = factor_prefixes
        # When True, nets are printed using caffe2 python-API syntax.
        self.c2_syntax = c2_syntax
        # Name of the net currently being printed in c2 syntax, if any.
        self.c2_net_name = None
202 
203 
def _sanitize_str(s):
    """Quote `s` for printing, decoding bytes and truncating long values."""
    if isinstance(s, binary_type):
        # Bytes are decoded permissively; non-ascii characters are dropped.
        sanitized = s.decode('ascii', errors='ignore')
    elif isinstance(s, text_type):
        sanitized = s
    else:
        sanitized = str(s)
    if len(sanitized) >= 64:
        # Truncate, noting how many characters were elided.
        return "'%s'" % sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64)
    return "'%s'" % sanitized
215 
216 
217 def _arg_val(arg):
218  if arg.HasField('f'):
219  return str(arg.f)
220  if arg.HasField('i'):
221  return str(arg.i)
222  if arg.HasField('s'):
223  return _sanitize_str(arg.s)
224  if arg.floats:
225  return str(list(arg.floats))
226  if arg.ints:
227  return str(list(arg.ints))
228  if arg.strings:
229  return str([_sanitize_str(s) for s in arg.strings])
230  return '[]'
231 
232 
def commonprefix(m):
    """Given a list of strings, returns the longest common prefix."""
    if not m:
        return ''
    # The min and max strings (lexicographically) bound all others, so
    # their common prefix is the common prefix of the whole list.
    lo, hi = min(m), max(m)
    for idx in range(len(lo)):
        if lo[idx] != hi[idx]:
            return lo[:idx]
    return lo
243 
244 
def format_value(val):
    """Lists become ['a', 'b', ...] with quoted items; anything else is str()."""
    if not isinstance(val, list):
        return str(val)
    quoted = ["'%s'" % str(v) for v in val]
    return '[%s]' % ', '.join(quoted)
250 
251 
def factor_prefix(vals, do_it):
    """Join formatted `vals`, optionally factoring their common prefix out
    as `prefix[a, b, ...]`."""
    vals = [format_value(v) for v in vals]
    # Only factor when requested and there is more than one value.
    prefix = commonprefix(vals) if do_it and len(vals) > 1 else ''
    suffixes = [v[len(prefix):] for v in vals]
    joined = ', '.join(suffixes)
    return '%s[%s]' % (prefix, joined) if prefix else joined
257 
258 
def call(op, inputs=None, outputs=None, factor_prefixes=False):
    """Render a call string, e.g. `out1, out2 = op(in1, key=val)`."""
    if not inputs:
        inputs = ''
    else:
        # Positional arguments are plain values; keyword arguments are
        # passed as (name, value) tuples.
        positional = [a for a in inputs if not isinstance(a, tuple)]
        keyword = [a for a in inputs if isinstance(a, tuple)]
        parts = [factor_prefix(positional, factor_prefixes)]
        parts.extend('%s=%s' % kv for kv in keyword)
        # Drop empty fragments (e.g. when there are no positional args).
        inputs = ', '.join(p for p in parts if p)
    expr = '%s(%s)' % (op, inputs)
    if not outputs:
        return expr
    return '%s = %s' % (factor_prefix(outputs, factor_prefixes), expr)
276 
277 
def format_device_option(dev_opt):
    """Return a DeviceOption(...) call string, or None when the option is
    absent or entirely default."""
    if not dev_opt:
        return None
    if not (dev_opt.device_type or dev_opt.cuda_gpu_id or dev_opt.node_name):
        return None
    return call(
        'DeviceOption',
        [dev_opt.device_type, dev_opt.cuda_gpu_id, "'%s'" % dev_opt.node_name])
285 
286 
@Printer.register(OperatorDef)
def print_op(text, op):
    """Print a single operator, either in caffe2 python-API syntax
    (net.OpType([...inputs], [...outputs], arg=val)) or as a plain call."""
    args = [(a.name, _arg_val(a)) for a in op.arg]
    dev_opt_txt = format_device_option(op.device_option)
    if dev_opt_txt:
        # A non-default device option is printed as an extra kwarg.
        args.append(('device_option', dev_opt_txt))

    if text.c2_net_name:
        # c2 syntax: inputs and outputs are passed as two list literals.
        text.add(call(
            text.c2_net_name + '.' + op.type,
            [list(op.input), list(op.output)] + args))
    else:
        # Plain syntax: outputs on the left of '=', inputs as arguments.
        text.add(call(
            op.type,
            list(op.input) + args,
            op.output,
            factor_prefixes=text.factor_prefixes))
    # Nested-net arguments (field 'n') are printed as indented sub-blocks.
    for arg in op.arg:
        if arg.HasField('n'):
            with text.context('arg: %s' % arg.name):
                text(arg.n)
308 
@Printer.register(NetDef)
def print_net_def(text, net_def):
    """Print a NetDef header followed by each of its operators."""
    if not text.c2_syntax:
        text.add('# net: %s' % net_def.name)
    else:
        # Emit the net constructor and make subsequent ops print as
        # method calls on this net.
        text.add(call('core.Net', ["'%s'" % net_def.name], [net_def.name]))
        text.c2_net_name = net_def.name
    for op in net_def.op:
        text(op)
    if text.c2_syntax:
        text.c2_net_name = None
320 
321 
@Printer.register(Net)
def print_net(text, net):
    # Delegate to the NetDef printer on the net's underlying protobuf.
    text(net.Proto())
325 
326 
def _get_step_context(step):
    """Return (context string, print_substep_headers) for a step.

    The checks are ordered by priority; only the first matching execution
    mode determines the printed context.
    """
    proto = step.Proto()
    if proto.should_stop_blob:
        # A stop blob implies the step loops until told to stop.
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    if proto.num_concurrent_instances > 1:
        return (
            call('parallel',
                 [('num_instances', proto.num_concurrent_instances)]),
            # Only label individual substeps when there is more than one.
            len(step.Substeps()) > 1)
    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
    if concurrent:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    # Default: no wrapping context, no per-substep headers.
    return None, False
344 
345 
@Printer.register(ExecutionStep)
def print_step(text, step):
    """Print an ExecutionStep as nested pseudo-python `with` blocks."""
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            # Nets in `proto.network` are not ExecutionSteps, so they have
            # no step proto of their own.
            sub_proto = (
                substep.Proto() if isinstance(substep, ExecutionStep) else None)
            if sub_proto is not None and sub_proto.run_every_ms:
                # Periodic substeps are rendered as reporters.
                substep_ctx = call(
                    'reporter',
                    [str(substep), ('interval_ms', sub_proto.run_every_ms)])
            elif do_substep:
                # Label each substep, noting whether it creates a workspace.
                title = (
                    'workspace'
                    if sub_proto is not None and sub_proto.create_workspace else
                    'step')
                substep_ctx = call(title, [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
                if proto.should_stop_blob:
                    text.add(call('yield stop_if', [proto.should_stop_blob]))
374 
375 
def _print_task_output(x):
    """Render a TaskOutput as `Output[name1, name2, ...]`."""
    assert isinstance(x, TaskOutput)
    names = [str(name) for name in x.names]
    return 'Output[%s]' % ', '.join(names)
379 
380 
@Printer.register(Task)
def print_task(text, task):
    """Print a task header (node, name, outputs) and its execution step."""
    outputs_txt = ', '.join(_print_task_output(o) for o in task.outputs())
    header = call('Task', [
        ('node', task.node), ('name', task.name), ('outputs', outputs_txt)])
    with text.context(header):
        text(task.get_step())
387 
388 
@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    """Print every task in the group under one (optionally custom) header."""
    ctx = header or call('TaskGroup')
    with text.context(ctx):
        for task in tg.tasks_by_node().tasks():
            text(task)
394 
395 
@Printer.register(Job)
def print_job(text, job):
    """Print all of a Job's task groups and its stop signals, labeled with
    the attribute they come from."""
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_signals'):
        for stop_signal in job.stop_signals:
            text.add(_print_task_output(stop_signal))
    text(job.download_group, 'Job.current().download_group')
    text(job.exit_group, 'Job.current().exit_group')
405 
406 
def to_string(obj, **kwargs):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with detailed description of the execution steps.
    """
    # kwargs are forwarded to Printer (factor_prefixes, c2_syntax).
    p = Printer(**kwargs)
    p(obj)
    return str(p)
415 
416 
def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.
    """
    assert isinstance(net, Net)
    # Renamed the local (it used to shadow this function's own name) and
    # dropped a duplicated isinstance assert.
    diagnostic_net = Net(str(net))
    for op in net.Proto().op:
        text = Text()
        # Fix: print_op's signature is (text, op); the arguments were
        # swapped in the original, which crashed at runtime.
        print_op(text, op)
        diagnostic_net.LogInfo(str(text))
        diagnostic_net.Proto().op.extend([op])
    return diagnostic_net
# Module caffe2.python.workspace.
# Module caffe2.python.context.