Caffe2 - Python API
A deep learning, cross platform ML framework
core.py
1 ## @package core
2 # Module caffe2.python.core
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from collections import namedtuple, OrderedDict, defaultdict
9 from past.builtins import basestring
10 from future.utils import viewitems, viewkeys, viewvalues
11 from itertools import chain
12 from six import binary_type, string_types, text_type
13 
14 from caffe2.proto import caffe2_pb2
15 from caffe2.python import scope, utils, workspace
17  gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output
18 
20 
21 import copy
22 import pickle
23 import numpy as np
24 import sys
25 import traceback
26 import os
27 
28 # Mac os specific message
29 if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
30  print('If you are using homebrew leveldb on a Mac OS, you might see an '
31  'error warning you that malloc_zone_unregister() failed. This is '
32  'not a caffe2 issue but is due to the homebrew leveldb having an '
33  'incompatible memory allocator. It does not affect usage.')
34 
35 # Convenience redirections to functions inside scope.
36 DeviceScope = scope.DeviceScope
37 NameScope = scope.NameScope
38 
39 
40 # Bring datatype enums to the main namespace
41 class DataType:
42  pass
43 
44 
45 def _InitDataType():
46  for name, value in caffe2_pb2.TensorProto.DataType.items():
47  setattr(DataType, name, value)
48 
49 
50 _InitDataType()
51 
52 
53 def _GetRegisteredOperators():
54  return set(workspace.RegisteredOperators())
55 
56 
57 _REGISTERED_OPERATORS = _GetRegisteredOperators()
58 
59 
60 def RefreshRegisteredOperators():
61  global _REGISTERED_OPERATORS
62  _REGISTERED_OPERATORS = _GetRegisteredOperators()
63 
64 
65 _GLOBAL_INIT_ARGS = []
66 
67 
68 def GlobalInit(args):
69  _GLOBAL_INIT_ARGS.extend(args[1:])
70  C.global_init(args)
71 
72 
73 def GetGlobalInitArgs():
74  return _GLOBAL_INIT_ARGS[:]
75 
76 
77 def IsOperator(op_type):
78  return IsOperatorWithEngine(op_type, engine='DEFAULT')
79 
80 
81 def IsOperatorWithEngine(op_type, engine):
82  return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS
83 
84 
85 def IsGPUDeviceType(device_type):
86  return device_type in {caffe2_pb2.CUDA, caffe2_pb2.HIP}
87 
88 
89 def DeviceOption(
90  device_type,
91  device_id=0,
92  random_seed=None,
93  node_name=None,
94  numa_node_id=None,
95  extra_info=None,
96 ):
97  option = caffe2_pb2.DeviceOption()
98  option.device_type = device_type
99  option.device_id = device_id
100  if node_name is not None:
101  option.node_name = node_name
102  if random_seed is not None:
103  option.random_seed = random_seed
104  if numa_node_id is not None:
105  assert device_type == caffe2_pb2.CPU
106  option.numa_node_id = numa_node_id
107  if extra_info is not None:
108  option.extra_info.extend(extra_info)
109  return option
110 
111 
112 def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
113  if not opt1 or not opt2:
114  return opt1 == opt2
115  if not ignore_node_name and opt1.node_name != opt2.node_name:
116  return False
117  if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
118  return False
119  if not opt1.device_type or not opt2.device_type:
120  # At least one option is for CPU, check if both are for CPU.
121  return not opt1.device_type and not opt2.device_type
122  return opt1.device_id == opt2.device_id
123 
124 
125 def InferBlobDevices(net):
126  '''
127  Compute mapping from parameters to devices by looking at the
128  device option of the op that creates the blob has
129  '''
130  mapping = {}
131  for op in net.Proto().op:
132  op_device = op.device_option
133  if op_device is None:
134  op_device = caffe2_pb2.DeviceOption(caffe2_pb2.CPU)
135  # TODO: T18892922, use device annotations
136  for b in op.output:
137  mapping[b] = op_device
138  return mapping
139 
140 
141 def InferOpBlobDevicesAsDict(op):
142  input_dev_list, output_dev_list = InferOpBlobDevices(op)
143  input_dict = {
144  op.input[i]: input_dev_list[i]
145  for i in range(len(op.input))
146  }
147  output_dict = {
148  op.output[i]: output_dev_list[i]
149  for i in range(len(op.output))
150  }
151  return input_dict, output_dict
152 
153 
154 def InferOpBlobDevices(op):
155  device_info = C.infer_op_input_output_device(op.SerializeToString())
156  input_info = []
157  output_info = []
158  for dev_str in device_info[0]:
159  device_option = caffe2_pb2.DeviceOption()
160  device_option.ParseFromString(dev_str)
161  input_info.append(device_option)
162  for dev_str in device_info[1]:
163  device_option = caffe2_pb2.DeviceOption()
164  device_option.ParseFromString(dev_str)
165  output_info.append(device_option)
166  return input_info, output_info
167 
168 
169 def InferOpDeviceAsBlobDevices(op):
170  op_dev = op.device_option if op.device_option else caffe2_pb2.DeviceOption()
171  input_dev = [op_dev] * len(op.input)
172  output_dev = [op_dev] * len(op.output)
173  return input_dev, output_dev
174 
175 
176 GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])
177 
178 
179 class BlobReference(object):
180  """A wrapper around a blob in a net.
181 
182  BlobReference gives us a way to refer to the network that the blob is
183  generated from. Note that blobs are, essentially, just strings in the
184  current workspace.
185  """
186 
187  def __init__(self, name, net=None):
188  """Initializes a blob reference.
189 
190  Note that this does not prepends the namescope. If needed, use
191  ScopedBlobReference() to prepend the existing namespace.
192  """
193  if isinstance(name, string_types):
194  self._name = name
195  elif isinstance(name, binary_type):
196  self._name = name.decode('utf-8')
197  else:
198  self._name = str(name)
199  self._from_net = net
200  # meta allows helper functions to put whatever metainformation needed
201  # there.
202  self.meta = {}
203 
204  def __hash__(self):
205  return hash(self._name)
206 
207  def __eq__(self, other):
208  if isinstance(other, string_types):
209  return self._name == other
210  elif isinstance(other, binary_type):
211  return self._name == other.decode('utf-8')
212  elif isinstance(other, BlobReference):
213  return self._name == other._name
214  else:
215  return False
216 
217  def __ne__(self, other):
218  return not(self == other)
219 
220  def __str__(self):
221  return self._name
222 
223  def __repr__(self):
224  return 'BlobReference("{}")'.format(self._name)
225 
226  def __add__(self, other):
227  if not isinstance(other, string_types):
228  raise RuntimeError('Cannot add BlobReference to a non-string.')
229  return BlobReference(self._name + other, self._from_net)
230 
231  def __radd__(self, other):
232  if not isinstance(other, string_types):
233  raise RuntimeError('Cannot add a non-string to BlobReference.')
234  return BlobReference(other + self._name, self._from_net)
235 
236  def Net(self):
237  return self._from_net
238 
239  def GetNameScope(self):
240  return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]
241 
242  def GetUnscopedName(self):
243  return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]
244 
245  def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
246  """Internal function that routes the operator generation to the
247  network's __getattr__ function.
248  """
249  inputs = [] if inputs is None else inputs
250  if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
251  inputs = [inputs]
252  # add self to the input list.
253  inputs.insert(0, self)
254  return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)
255 
256  def __getattr__(self, op_type):
257  """A wrapper allowing one to initiate operators from a blob reference.
258 
259  Example: for a blob reference b that comes from network n, doing
260  b.Relu(...)
261  is equivalent to doing
262  net.Relu([b], ...)
263  """
264  if op_type.startswith('__'):
265  raise AttributeError('Attribute {} not found.'.format(op_type))
266  if self._from_net is None:
267  raise RuntimeError(
268  'You cannot use a blob reference that does not have a net '
269  'source to create operators. Create the operator from an '
270  'explicit net object.')
271  if not IsOperator(op_type):
272  raise RuntimeError(
273  'Method ' + op_type + ' is not a registered operator.' +
274  ' Did you mean: [' +
275  ",".join(workspace.C.nearby_opnames(op_type)) + ']'
276  )
277  return lambda *args, **kwargs: self._CreateAndAddToNet(
278  op_type, *args, **kwargs)
279 
280  def __dir__(self):
281  additional_methods = [
282  op
283  for op in _REGISTERED_OPERATORS
284  if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
285  return sorted(set(chain(
286  dir(type(self)),
287  viewkeys(self.__dict__),
288  additional_methods
289  )))
290 
291 
292 def ScopedName(name):
293  """prefix the name with the current scope."""
294  if isinstance(name, binary_type):
295  name = name.decode('ascii')
296  return scope.CurrentNameScope() + name
297 
298 
299 def ScopedBlobReference(name, *args, **kwargs):
300  """Returns a blob reference with scope prefixed."""
301  return BlobReference(ScopedName(name), *args, **kwargs)
302 
303 
304 def _RectifyInputOutput(blobs, net=None):
305  """A helper function to rectify the input or output of the CreateOperator
306  interface.
307  """
308  if isinstance(blobs, string_types) or isinstance(blobs, binary_type):
309  # If blobs is a single string, prepend scope.CurrentNameScope()
310  # and put it as a list.
311  # TODO(jiayq): enforce using BlobReference instead of raw strings.
312  return [ScopedBlobReference(blobs, net=net)]
313  elif type(blobs) is BlobReference:
314  # If blob is a BlobReference, simply put it as a list.
315  return [blobs]
316  elif type(blobs) in (list, tuple):
317  # If blob is a list, we go through it and type check.
318  rectified = []
319  for blob in blobs:
320  if isinstance(blob, string_types) or isinstance(blob, binary_type):
321  rectified.append(ScopedBlobReference(blob, net=net))
322  elif type(blob) is BlobReference:
323  rectified.append(blob)
324  else:
325  raise TypeError(
326  "I/O blob #{} of unsupported type: {} of type {}"
327  .format(len(rectified), str(blob), type(blob)))
328  return rectified
329  else:
330  raise TypeError(
331  "Unknown input/output type: %s of type %s." %
332  (str(blobs), type(blobs))
333  )
334 
335 
336 def CreateOperator(
337  operator_type,
338  inputs,
339  outputs,
340  name='',
341  control_input=None,
342  device_option=None,
343  arg=None,
344  engine=None,
345  debug_info=None,
346  **kwargs
347 ):
348  """A function wrapper that allows one to create operators based on the
349  operator type. The type should be a string corresponding to an operator
350  registered with Caffe2.
351  """
352  operator = caffe2_pb2.OperatorDef()
353  if (os.environ.get('CAFFE2_DEBUG')):
354  stack = traceback.format_stack()
355  operator.debug_info = "".join(stack[:-1])
356 
357  operator.type = operator_type
358  operator.name = name
359  # Add rectified inputs and outputs
360  inputs = _RectifyInputOutput(inputs)
361  outputs = _RectifyInputOutput(outputs)
362  operator.input.extend([text_type(i) for i in inputs])
363  operator.output.extend([text_type(o) for o in outputs])
364  if control_input:
365  control_input = _RectifyInputOutput(control_input)
366  operator.control_input.extend([text_type(i) for i in control_input])
367  # Set device option:
368  # (1) If device_option is explicitly set, use device_option.
369  # (2) If not, but scope.CurrentDeviceScope() is set,
370  # then we use scope.CurrentDeviceScope().
371  # (3) Otherwise, do not set device option.
372  if device_option is not None:
373  operator.device_option.CopyFrom(device_option)
374  elif scope.CurrentDeviceScope() is not None:
375  operator.device_option.CopyFrom(scope.CurrentDeviceScope())
376  if engine is not None:
377  operator.engine = engine
378  if debug_info is not None:
379  operator.debug_info = debug_info
380  # random seed is defined in the device option, so we need to do special
381  # care.
382 
383  if 'random_seed' in kwargs:
384  operator.device_option.random_seed = kwargs['random_seed']
385  del kwargs['random_seed']
386  # Add given arguments that do not need parsing
387  if arg is not None:
388  operator.arg.extend(arg)
389  # Add all other arguments
390  for key, value in viewitems(kwargs):
391  if value is not None:
392  operator.arg.add().CopyFrom(utils.MakeArgument(key, value))
393 
394  if workspace.IsImmediate():
395  workspace.RunOperatorImmediate(operator)
396  return operator
397 
398 
399 def _RegisterPythonImpl(
400  f, grad_f=None, python_func_type=None, pass_workspace=False
401 ):
402  if python_func_type:
403  func = python_func_type(f)
404  f = func.forward
405  grad_f = func.backward
406  else:
407  if isinstance(f, tuple):
408  f = f[0](*f[1], **f[2])
409  if isinstance(grad_f, tuple):
410  grad_f = grad_f[0](*grad_f[1], **grad_f[2])
411 
412  token = C.register_python_op(f, pass_workspace, '')
413  if grad_f:
414  C.register_python_gradient_op(token, grad_f)
415  return token
416 
417 
418 def CreatePythonOperator(
419  f, inputs,
420  outputs,
421  grad_f=None,
422  pass_workspace=False,
423  python_func_type=None,
424  *args,
425  **kwargs
426 ):
427  """
428  `f` should have a signature (inputs, outputs)
429 
430  If `pass_workspace` is True, the signature is changed to
431  (inputs, outputs, workspace) where `workspace` is the workspace the op
432  is going to run on. This is potentially dangerous (as the op can manipulate
433  the workspace directly), use on your own risk.
434  """
435  kwargs["token"] = _RegisterPythonImpl(
436  f, grad_f, python_func_type, pass_workspace=pass_workspace
437  )
438  return CreateOperator("Python", inputs, outputs, *args, **kwargs)
439 
440 
441 def GetIndexFromGradientList(g_list, name):
442  """A helper function to get the index from a gradient list, None if not
443  matching."""
444  for i, g in enumerate(g_list):
445  if g == name:
446  return i
447  elif type(g) is GradientSlice:
448  if (g.indices == name or g.values == name):
449  return i
450  return None
451 
452 
453 OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
454 GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient'])
455 SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
456  'grad_op_indices', 'idx_indices',
457  'grad_op_values', 'idx_values',
458  'gradient',
459 ])
460 
461 
462 class IR(object):
463  """A simple IR class to keep track of all intermediate representations used
464  in the gradient computation.
465  """
466 
467  def __init__(self, operators):
468  # The IR class holds multiple metadata from the forward pass:
469  # a) ssa: a list of [op, in_versions, out_versions] recording the
470  # input and the output version of each operator, similar
471  # to a normal SSA form.
472  # b) input_usages: a dictionary specifying for each blob and
473  # each of its version, how many times it is used as input for another
474  # op.
475  # c) frontier: maintaining the current versions of the blobs
476  # we are having in the workspace, after the execution of all the ops
477  # added to the IR so far. This is useful because if a gradient is
478  # trying to access an earlier version of a blob, we can sanity check
479  # that it is no longer there, and thus throw an error.
480  # d) gradient_frontier: maps the names of blobs to its version that the
481  # gradient corresponds to.
482  # e) gradient_generators: for each blob and each of its version, maps to
483  # a list of operators that generates its gradient together with the
484  # gradient name.
485  self.ssa = []
486  self.input_usages = defaultdict(lambda: defaultdict(list))
487  self.frontier = defaultdict(int)
488  self.gradient_frontier = {}
489  self.gradient_generators = defaultdict(lambda: defaultdict(list))
490  self.out_version_history = defaultdict(list)
491  self.in_version_history = defaultdict(list)
492 
493  for op in operators:
494  self.Play(op)
495 
496  self.SanityCheck(operators)
497 
498  def SanityCheck(self, operators):
499  # Validate StopGradient usage by checking that StopGradient's output
500  # is actually passed forward
501  for op in operators:
502  if op.type == 'StopGradient':
503  if op.output[0] not in self.input_usages:
504  raise ValueError("""StopGradient's output '{}' is orphan.
505 You typically want to specify same input and output for
506 StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
507 
508  def Play(self, op):
509  """"Adds an op to the current IR, and update the internal states to
510  reflect the blobs and versions after the execution of the op.
511  """
512  # For input, they are the current version in the dict.
513  in_versions = {}
514  for s in op.input:
515  in_versions[s] = self.frontier[s]
516  self.input_usages[s][self.frontier[s]].append(len(self.ssa))
517  self.in_version_history[s].append((op, self.frontier[s]))
518  # For output, they are the current version plus one. If this is a
519  # newly created blob, its version starts with zero.
520  out_versions = {}
521  for s in op.output:
522  if s in self.frontier:
523  self.frontier[s] += 1
524  out_versions[s] = self.frontier[s]
525  self.out_version_history[s].append((op, self.frontier[s]))
526  # Add to SSA for bookkeeping.
527  self.ssa.append(OpSSA(op, in_versions, out_versions))
528 
530  self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
531  """Checks if the gradient operators can be correctly carried out."""
532  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
533  original_index = GetIndexFromGradientList(g_output, grad_op_input)
534 
535  # Functions to generate debug help for version-mismatches
536  def versionMismatchInfoOut(name):
537  s = "DEBUG HELP:\n"
538  s += "Maybe you use same output blob twice for different ops?\n"
539  s += "== Version history of blob [{}]\n".format(name)
540  for (op, vers) in self.out_version_history[name]:
541  s += "Version (out) {} <-- {}".format(vers, op)
542  s += "\n"
543  return s
544 
545  def versionMismatchInfoIn(name):
546  s = "DEBUG HELP:\n"
547  s += "Maybe the blob was overwritten by another op?\n"
548  s += "== Version history of blob [{}]\n".format(name)
549  for (op, vers) in self.in_version_history[name]:
550  s += "version (in) {} <-- {}".format(vers, op)
551  s += "\n"
552  return s
553 
554  # If it is a dense or sparse gradient name, it should match the
555  # version of the corresponding output.
556  if original_index is not None:
557  original_name = forward_op.output[original_index]
558  if (out_versions[original_name] !=
559  self.gradient_frontier[original_name]):
560  raise RuntimeError(
561  'Gradient name "%s" is expected to correspond '
562  'to version %d of "%s", but currently we have '
563  'version %d.\n\n' % (
564  grad_op_input, out_versions[original_name],
565  original_name,
566  self.gradient_frontier[original_name]) +
567  versionMismatchInfoOut(original_name))
568  # If it is an output name, the current version should match the
569  # version when the operator was run.
570  elif grad_op_input in out_versions:
571  if self.frontier[grad_op_input] != out_versions[grad_op_input]:
572  raise RuntimeError(
573  'Gradient operator needs output "%s" at version'
574  ' %d, but currently we have version %d.\n\n' % (
575  grad_op_input, out_versions[grad_op_input],
576  self.frontier[grad_op_input]
577  ) + versionMismatchInfoOut(grad_op_input)
578  )
579  # If it is an input name, the current version should match the
580  # version when the operator was run.
581  elif grad_op_input in in_versions:
582  if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
583  raise RuntimeError(
584  'Gradient operator needs input "%s" at version '
585  '%d, but currently we have version %d.\n\n' % (
586  grad_op_input, in_versions[grad_op_input],
587  self.frontier[grad_op_input]
588  ) + versionMismatchInfoIn(grad_op_input)
589  )
590  # If it is none of the above, it should be a blob that is
591  # generated locally by one of the previous gradient operators.
592  else:
593  if grad_op_input not in locally_generated_blobs:
594  raise RuntimeError(
595  'Blob name "%s" not in the scope of operator: '
596  '%s\nand is not generated by any of the local '
597  'gradient operators.' % (grad_op_input, str(forward_op))
598  )
599 
600  def AppendSparseGenerators(self, sparse_generators):
601  # merge indices and values generators for sparse gradients
602  for name, input_generators in viewitems(sparse_generators):
603  for version, generators in viewitems(input_generators):
604  if len(generators) == 1:
605  # either indices or values are generated (but not both)
606  generator = generators[0]
607  else:
608  # both indices and values are generated
609  assert(len(generators) == 2)
610  op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0]
611  op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1]
612  assert(g1 == g2)
613  assert(op1_i is None or op2_i is None)
614  assert(op1_v is None or op2_v is None)
615  assert(idx1_i == 0 or idx2_i == 0)
616  assert(idx1_v == 0 or idx2_v == 0)
617  generator = SparseGradGenMeta(
618  op1_i or op2_i, idx1_i + idx2_i,
619  op1_v or op2_v, idx1_v + idx2_v,
620  g1)
621  self.gradient_generators[name][version].append(generator)
622 
623  def BuildGradientGenerators( # NOQA
624  self, fwd_op_idx, gradient_ops, g_output, g_input):
625  """Updates gradient_generators and gradient_frontier"""
626  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
627  locally_generated_blobs = []
628  sparse_generators = defaultdict(lambda: defaultdict(list))
629 
630  for grad_op in gradient_ops:
631  # (1) check that inputs are valid
632  for s in grad_op.input:
634  s, g_output, fwd_op_idx, locally_generated_blobs)
635 
636  # (2) add outputs to the locally generated blobs
637  # If an output corresponds to the gradient of an input, we also
638  # record it to gradient_generators
639  locally_generated_blobs.extend([str(s) for s in grad_op.output])
640  for i, output in enumerate(grad_op.output):
641  input_index = GetIndexFromGradientList(g_input, output)
642  if input_index is not None:
643  input_name = forward_op.input[input_index]
644  input_version = in_versions[input_name]
645  g = g_input[input_index]
646  if type(g) is GradientSlice:
647  # the output corresponds either to the indices or the
648  # values of the sparse gradient. In either case we
649  # create a (partial) SparseGradGenMeta. If necessary,
650  # we'll merge indices and values generators
651  # corresponding to the same gradient in step (3)
652  if g.indices == output:
653  m = SparseGradGenMeta(grad_op, i, None, 0, g)
654  else:
655  assert(g.values == output)
656  m = SparseGradGenMeta(None, 0, grad_op, i, g)
657  sparse_generators[input_name][input_version].append(m)
658  else:
659  self.gradient_generators[input_name][input_version] \
660  .append(GradGenMeta(
661  grad_op, i, g))
662 
663  # (3) merge indices and values generators for sparse gradients, and
664  # add them to gradient_generators
665  self.AppendSparseGenerators(sparse_generators)
666 
667  # (4) for ops (e.g., Add, Sum, Sub) which have gradient outputs directly
668  # passed from inputs (not computed from gradient ops), we create an
669  # GradGenMeta with None grad_op and idx so that the gradient_generators
670  # knows where the gradients are coming from. This is needed for creating
671  # Sum op to accumulate the gradients from multiple parents.
672  for input_index, g in enumerate(g_input):
673  input_name = forward_op.input[input_index]
674  input_version = in_versions[input_name]
675  if not g:
676  continue
677  if type(g) is GradientSlice:
678  if str(g.indices) not in locally_generated_blobs and \
679  str(g.values) not in locally_generated_blobs:
680  self.gradient_generators[input_name][input_version].append(
681  SparseGradGenMeta(None, 0, None, 0, g))
682  else:
683  if str(g) not in locally_generated_blobs:
684  self.gradient_generators[input_name][input_version].append(
685  GradGenMeta(None, 0, g))
686 
687  # Finally, for the gradients specified in g_input, we update the
688  # gradient frontier to reflect the input versions that the gradients
689  # correspond to.
690  for i, g in enumerate(g_input):
691  if g is not None:
692  input_name = forward_op.input[i]
693  input_version = in_versions[input_name]
694  self.gradient_frontier[input_name] = input_version
695 
696  def _GetSumOpOutputName(self, generator, input_name):
697  def remove_suffix(s, suffix):
698  if s.endswith(suffix):
699  return s[:-len(suffix)]
700  return s
701 
702  for g in generator:
703  if type(g) is GradGenMeta:
704  grad_op, idx, _ = g
705  if grad_op:
706  return grad_op.output[idx]
707  else:
708  assert(type(g) is SparseGradGenMeta)
709  op_i, idx_i, op_v, idx_v, _ = g
710  if op_i:
711  return remove_suffix(op_i.output[idx_i], '_indices')
712  if op_v:
713  return remove_suffix(op_v.output[idx_v], '_values')
714 
715  return input_name + '_grad'
716 
717  def _SetSumOpsDeviceOption(self, sum_ops, generators):
718  # we already checked that device options are consistent so we can just
719  # use the first one we find
720  for generator in generators:
721  grad_op = generator.grad_op if type(generator) is GradGenMeta \
722  else generator.grad_op_values or generator.grad_op_indices
723  if grad_op:
724  if grad_op.HasField('device_option'):
725  for op in sum_ops:
726  op.device_option.CopyFrom(grad_op.device_option)
727  del op.device_option.extra_info[:]
728  break
729 
730  def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
731  new_grad_output = (
732  '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
733  if grad_op.type == "If":
734  disambiguate_grad_if_op_output(grad_op, idx, new_grad_output)
735  else:
736  grad_op.output[idx] = new_grad_output
737  return grad_op.output[idx], cnt + 1
738 
739  def _CheckSumOpsConflict(self, out_base_name, g):
740  if str(out_base_name) == str(g):
741  # TODO not sure what this message really means
742  raise RuntimeError(
743  'The gradient output of empty gradient op can not '
744  'be the same as the normal name of the current '
745  'input gradient.')
746 
747  def _MakeDenseSumOps(self, generators, out_base_name):
748  sum_op_input = []
749  cnt = 0
750 
751  assert len(generators) > 1
752 
753  first_grad_op = True
754  for generator in generators:
755  grad_op, idx, g = generator
756  assert(type(g) is not GradientSlice)
757  if grad_op:
758  if first_grad_op:
759  first_grad_op = False
760  out = grad_op.output[idx]
761  else:
762  out, cnt = self._DisambiguateGradOpOutput(grad_op, idx, cnt)
763  sum_op_input.append(out)
764  else:
765  self._CheckSumOpsConflict(out_base_name, g)
766  sum_op_input.append(str(g))
767 
768  if out_base_name in sum_op_input:
769  # Sum inplace mode works only for the first input
770  # So we do a swap
771  idx = sum_op_input.index(out_base_name)
772  sum_op_input[0], sum_op_input[idx] = (
773  sum_op_input[idx], sum_op_input[0]
774  )
775  sum_ops = [CreateOperator(
776  "Sum",
777  [BlobReference(x) for x in sum_op_input],
778  BlobReference(out_base_name))]
779  return sum_ops, out_base_name
780 
781  def _MakeSparseSumOps(self, generators, out_base_name):
782  indices_concat_input = []
783  values_concat_input = []
784  cnt_i = 0
785  cnt_v = 0
786 
787  for generator in generators:
788  assert(type(generator) is SparseGradGenMeta)
789  op_i, idx_i, op_v, idx_v, g = generator
790  if op_i:
791  out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i)
792  indices_concat_input.append(out)
793  else:
794  self._CheckSumOpsConflict(out_base_name, g.indices)
795  indices_concat_input.append(g.indices)
796  if op_v:
797  out, cnt_v = self._DisambiguateGradOpOutput(op_v, idx_v, cnt_v)
798  values_concat_input.append(out)
799  else:
800  self._CheckSumOpsConflict(out_base_name, g.values)
801  values_concat_input.append(g.values)
802 
803  indices_concat_output = out_base_name + '_indices_concat'
804  indices_concat_split = out_base_name + '_indices_concat_split'
805  values_concat_output = out_base_name + '_values_concat'
806  values_concat_split = out_base_name + '_values_concat_split'
807  # Sum the given sparse representations by simply concatenating the
808  # indices (resp. values) tensors together. We don't do any deduplication
809  # of indices at this point. This will be done as needed before the
810  # optimizer is called
811  sum_ops = [
812  CreateOperator(
813  "Concat",
814  [BlobReference(x) for x in indices_concat_input],
815  [BlobReference(x) for x in
816  [indices_concat_output, indices_concat_split]],
817  axis=0
818  ),
819  CreateOperator(
820  "Concat",
821  [BlobReference(x) for x in values_concat_input],
822  [BlobReference(x) for x in
823  [values_concat_output, values_concat_split]],
824  axis=0
825  ),
826  ]
827  sum_op_output = GradientSlice(
828  indices=indices_concat_output,
829  values=values_concat_output,
830  )
831  return sum_ops, sum_op_output
832 
833  def _MakeSumOps(self, input_name, input_version):
834  generators = self.gradient_generators[input_name][input_version]
835  out_base_name = self._GetSumOpOutputName(generators, input_name)
836  types = list(set(type(x) for x in generators))
837  assert(len(types) == 1)
838  if types[0] is GradGenMeta:
839  sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
840  else:
841  assert(types[0] is SparseGradGenMeta)
842  sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
843  self._SetSumOpsDeviceOption(sum_ops, generators)
844  return sum_ops, g
845 
846  def _VerifyGradientGenerators(self, generator):
847  # (1) check if all gradients are of the same type. Aggregating a mix of
848  # sparse and dense gradients is not supported yet
849  if len({type(g) for g in generator}) > 1:
850  raise RuntimeError(
851  'Automatic aggregation of a mix of sparse and dense gradients '
852  'is not supported yet')
853 
854  # If for all the operators that used the operator, none or only one
855  # produced the gradient, then no additional sum needs to be carried
856  # out.
857  if len(generator) < 2:
858  return False
859 
860  all_gradient_names = []
861  all_device_options = []
862  for g in generator:
863  if type(g) is GradGenMeta:
864  if g.grad_op:
865  all_gradient_names.append(g.gradient)
866  all_device_options.append(g.grad_op.device_option)
867  else:
868  assert(type(g) is SparseGradGenMeta)
869  if g.grad_op_indices:
870  all_device_options.append(g.grad_op_indices.device_option)
871  if g.grad_op_values:
872  all_device_options.append(g.grad_op_values.device_option)
873  all_gradient_names.append(g.gradient.values)
874 
875  # Check if all grad op device options are the same.
876  if len(all_device_options) >= 2 and not all(
877  device_option_equal(d, all_device_options[0])
878  for d in all_device_options[1:]):
879  raise RuntimeError('Unexpected behavior: not all grad ops '
880  'have the same device option.')
881  return True
882 
883  def DoGradientAccumulation(self, fwd_op_idx):
884  """For each input name in the forward op, check if we will need to
885  add gradient accumulation. If so, do gradient accumulation and return
886  the list of gradient operators.
887 
888  The criteria for doing gradient accumulation is:
889  (1) the specific input version has been used by multiple operators.
890  (2) the current fwd_op_idx is the first to use that input, i.e. in the
891  backward pass, is the last to optionally generate the gradient for
892  the op.
893  (3) For the operators that used the input, their gradient operators
894  have generated more than 1 gradient.
895 
896  When accumulating operators, our current solution is to rename all the
897  created gradients with an internal intermediate name, and then add a
898  Sum() operator that adds up all the gradients. This may use more memory
899  due to intermediate storage, but is usually the fastest approach as one
900  can do one single sum for multiple intermediate gradients.
901  """
902  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
903  additional_sum_ops = []
904  grad_map = {}
905  for _i, input_name in enumerate(set(forward_op.input)):
906  input_version = in_versions[input_name]
907  input_usage = self.input_usages[input_name][input_version]
908  if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]):
909  # We do not need to do gradient accumulation yet.
910  continue
911  generator = self.gradient_generators[input_name][input_version]
912  try:
913  if not self._VerifyGradientGenerators(generator):
914  continue
915  except RuntimeError as err:
916  raise RuntimeError(
917  "Gradients for param ''{}'' failed to verify: {}".format(
918  input_name,
919  err
920  )
921  )
922 
923  # Finally, let's create the sum operator.
924  sum_ops, g = self._MakeSumOps(input_name, input_version)
925  additional_sum_ops.extend(sum_ops)
926  grad_map[input_name] = g
927  return additional_sum_ops, grad_map
928 
929  def _AppendAutoGradGenerator(self, y, grad, autograd_op):
930  # Gradient here is not sparse as it was generated by
931  # a ConstantFill operator. Autogeneration for sparse gradients is
932  # not supported
933  generator = GradGenMeta(
934  autograd_op, 0 if autograd_op else None, str(grad))
935 
936  self.gradient_generators[str(y)][self.frontier[str(y)]].append(
937  generator)
938 
939 
940  def _GetInitGradients(self, ys):
941  input_to_grad = {}
942  gradient_ops = []
943 
944  for y, g in viewitems(ys):
945  autograd_op = None
946  if g is None:
947  autograd_op = CreateOperator(
948  "ConstantFill", [y], [str(y) + "_autogen_grad"],
949  value=1.0)
950  gradient_ops.append(autograd_op)
951  g = autograd_op.output[0]
952  # Since the C++ gradient registry does not have notion of
953  # NameScopes, we will convert all references to strings.
954  input_to_grad[str(y)] = (
955  GradientSlice(str(g[0]), str(g[1]))
956  if isinstance(g, GradientSlice) else str(g))
957  # Autogenerated gradients are assumed to be provided for the last
958  # input version
959  if autograd_op is not None:
960  self._AppendAutoGradGenerator(y, g, autograd_op)
961 
962  return input_to_grad, gradient_ops
963 
964  def _GenerateGradientsForForwardOp(
965  self, forward_op_idx, input_to_grad):
966  new_input_to_grad = {}
967  gradient_ops = []
968  forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
969  g_output = list(
970  input_to_grad.get(name, None) for name in forward_op.output)
971 
972  if not all(g is None for g in g_output) or (
973  forward_op.type == "ZeroGradient"):
974  gradient_ops, g_input = GradientRegistry.GetGradientForOp(
975  forward_op, g_output)
976  # Check if the gradient operators are legal, and update
977  # gradient_generators and gradient_frontier
979  forward_op_idx, gradient_ops, g_output, g_input)
980  # Record the gradient map to all_input_to_grad.
981  for name, grad in zip(forward_op.input, g_input):
982  # Do not overwrite an existing gradient with a None
983  # unless the input is also an output of the op, since
984  # we update the blob version when blob is output of an
985  # operator.
986  if grad is not None or \
987  name not in input_to_grad or \
988  name in list(forward_op.output):
989  new_input_to_grad[name] = grad
990 
991  return new_input_to_grad, gradient_ops
992 
993  def GetBackwardPass(self, ys):
994  """Gets the backward pass that computes the derivatives of given blobs.
995 
996  Inputs:
997  ys: a list or a dictionary specifying what blobs we want to compute
998  derivatives of. If the input is a list, we will automatically
999  generate their gradients with all-one values; if the input is a
1000  dictionary, for any dictionary entries that are not None, we will
1001  take the corresponding blobs as their gradients; for all those
1002  that are None, we will auto-fill them with 1.
1003  """
1004  if isinstance(ys, list):
1005  ys = dict((y, None) for y in ys)
1006  elif not isinstance(ys, dict):
1007  raise TypeError("ys should either be a list or a dict.")
1008 
1009  # Set the gradient frontier with the initialized external
1010  # gradients.
1011  for y in viewkeys(ys):
1012  self.gradient_frontier[y] = self.frontier[y]
1013  self.input_usages[str(y)][self.frontier[str(y)]].append(
1014  len(self.ssa))
1015 
1016  all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys)
1017 
1018  # (2) Now, after having the virtual play above, we now play the ops
1019  # backwards, creating the gradients along the path. Note that although
1020  # we are playing it backwards, we cannot refer to variables that are
1021  # at a version older than current_versions because it is already been
1022  # overwritten.
1023  for forward_op_idx in reversed(range(len(self.ssa))):
1024  input_to_grad, gradient_ops = self._GenerateGradientsForForwardOp(
1025  forward_op_idx, all_input_to_grad)
1026  all_input_to_grad.update(input_to_grad)
1027  all_gradient_ops += gradient_ops
1028 
1029  # If there are multiple use blobs, do gradient accumulation.
1030  additional_sum_ops, grad_map = self.DoGradientAccumulation(
1031  forward_op_idx)
1032  # This line is so that if in an accumulation some of the operators
1033  # have not produced gradients, they still do not overwrite the
1034  # general all_input_to_grad map.
1035  all_input_to_grad.update(grad_map)
1036  all_gradient_ops += additional_sum_ops
1037 
1038  # (3) Post-processing.
1039  # After we have done computation for each op, we now have the gradient
1040  # operators ready. For the output map, we will convert everything to
1041  # BlobReferences for easier handling in python.
1042  all_input_to_grad_out = {}
1043  for key, val in viewitems(all_input_to_grad):
1044  if val is not None:
1045  if (isinstance(val, string_types) or
1046  isinstance(val, binary_type)):
1047  grad_out = BlobReference(val)
1048  else:
1049  grad_out = GradientSlice(BlobReference(val[0]),
1050  BlobReference(val[1]))
1051  all_input_to_grad_out[BlobReference(key)] = grad_out
1052  return all_gradient_ops, all_input_to_grad_out
1053 
1054 
1055 class GradientRegistry(object):
1056  """GradientRegistry holds the mapping from operators to their gradients."""
1057  gradient_registry_ = {}
1058 
1059  @classmethod
1060  def RegisterGradient(cls, op_type):
1061  """A decorator for registering gradient mappings."""
1062 
1063  def Wrapper(func):
1064  cls.gradient_registry_[op_type] = func
1065  return func
1066 
1067  return Wrapper
1068 
1069  @classmethod
1070  def _GetGradientForOpCC(cls, op_def, g_output):
1071  # TODO(tulloch) - Propagate GradientWrapper up through the stack.
1072  def from_untyped(grad):
1073  if grad is None:
1074  w = C.GradientWrapper()
1075  assert w.is_empty()
1076  return w
1077  try:
1078  (indices, values) = grad
1079  w = C.GradientWrapper()
1080  w.indices = indices
1081  w.values = values
1082  assert w.is_sparse()
1083  return w
1084  except ValueError:
1085  w = C.GradientWrapper()
1086  w.dense = grad
1087  assert w.is_dense()
1088  return w
1089 
1090  g_output = [from_untyped(grad) for grad in g_output]
1091  grad_defs_str, g_input = C.get_gradient_defs(
1092  op_def.SerializeToString(), g_output)
1093 
1094  def to_untyped(grad_wrapper):
1095  if grad_wrapper.is_empty():
1096  return None
1097  if grad_wrapper.is_sparse():
1098  return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
1099  assert grad_wrapper.is_dense()
1100  return grad_wrapper.dense
1101 
1102  g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
1103  grad_defs = []
1104  for grad_def_str in grad_defs_str:
1105  grad_def = caffe2_pb2.OperatorDef()
1106  grad_def.ParseFromString(grad_def_str)
1107  grad_defs.append(grad_def)
1108  return grad_defs, g_input
1109 
1110  @classmethod
1111  def GetGradientForOp(cls, op, g_output):
1112  try:
1113  gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
1114  except Exception as e:
1115  # Not supported in C++; will try python registration next.
1116  if op.type in cls.gradient_registry_:
1117  gradient_ops, g_input = cls.gradient_registry_[op.type](
1118  op, g_output
1119  )
1120  else:
1121  raise Exception(
1122  "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
1123  format(op.type, e, str(op))
1124  )
1125 
1126  if gradient_ops is None:
1127  return [], g_input
1128  if type(gradient_ops) is not list:
1129  gradient_ops = [gradient_ops]
1130  return gradient_ops, g_input
1131 
1132  @classmethod
1133  def GetBackwardPass(cls, operators, ys, ys_generate_gradient=False):
1134  """Gets the backward pass for the list of operators.
1135 
1136  Args:
1137  operators: a list of operators constituting the forward pass.
1138  ys: a list or a dictionary specifying what blobs we want to compute
1139  derivatives of. If the input is a list, we will automatically
1140  generate their gradients with all-one values; if the input is a
1141  dictionary, for any dictionary entries that are not None, we'll
1142  take the corresponding blobs as their gradients; for all those
1143  that are None, we will auto-fill them with 1.
1144  Returns:
1145  gradient_ops: a list of gradient operators to run.
1146  all_input_to_grads: a map from input to their corresponding
1147  gradients.
1148  """
1149  ir = IR(operators)
1150  return ir.GetBackwardPass(ys)
1151 
1152 
1153 GradientRegistry.RegisterGradient('Do')(gen_do_gradient)
1154 GradientRegistry.RegisterGradient('If')(gen_if_gradient)
1155 GradientRegistry.RegisterGradient('While')(gen_while_gradient)
1156 
1157 
1158 def get_ssa(net, blob_versions=None):
1159  """
1160  Given a net, return a structure containing the version of each input and
1161  output blob used by each operator.
1162 
1163  Args:
1164  net: either a Net or a NetDef
1165  blob_versions: (optional) map with current version number for given
1166  blob names. If not provided or blob not found, start
1167  from version 0.
1168  Returns:
1169  Tuple (ssa, blob_versions)
1170  ssa: list of tuples (versioned_inputs, versioned_outputs)
1171  for each op in the net. A versioned input is a tuple
1172  (blob_name, version).
1173  blob_versions: updated map with latest version of each blob found in
1174  the net.
1175  """
1176  proto = net.Proto() if isinstance(net, Net) else net
1177  assert isinstance(proto, caffe2_pb2.NetDef)
1178  if blob_versions is None:
1179  blob_versions = {}
1180  if isinstance(net, list):
1181  return [get_ssa(n, blob_versions) for n in net], blob_versions
1182  for i in proto.external_input:
1183  if i not in blob_versions:
1184  blob_versions[str(i)] = 0
1185  ssa = []
1186  for op in proto.op:
1187  if not proto.external_input:
1188  for i in op.input:
1189  if i not in blob_versions:
1190  blob_versions[i] = 0
1191  inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
1192  for o in op.output:
1193  blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
1194  outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
1195  ssa.append((inputs, outputs))
1196  return ssa, blob_versions
1197 
1198 
1199 def get_undefined_blobs(ssa):
1200  """
1201  Given a ssa in the format produced by get_ssa(), return a set of blobs that
1202  are used before they are defined, which corresponds to inputs at version 0.
1203  """
1204  undef_blobs = set()
1205  for inputs, _outputs in ssa:
1206  undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
1207  return undef_blobs
1208 
1209 
1210 def get_output_producers(ssa):
1211  """
1212  Given a ssa in the format produced by get_ssa(), returns a map from
1213  versioned blob into the operator index that produces that version of
1214  the blob. A versioned blob is a tuple (blob_name, version).
1215  """
1216  producers = {}
1217  for i, (_inputs, outputs) in enumerate(ssa):
1218  for o in outputs:
1219  producers[o] = i
1220  return producers
1221 
1222 
1223 def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
1224  """
1225  Given a ssa and blob_versions as produced by get_ssa(), returns the list
1226  of op indices that are necessary in order to generate the blobs in
1227  `outputs`, given blobs in `inputs`.
1228  Consider that the `inputs` are given in their latest version.
1229  """
1230  inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
1231  producers = get_output_producers(ssa)
1232  queue = [(str(o), blob_versions[str(o)]) for o in outputs]
1233  used_op_ids = set()
1234  while len(queue) > 0:
1235  o = queue.pop()
1236  if (o not in inputs_set) and (o in producers):
1237  op_id = producers[o]
1238  if op_id not in used_op_ids:
1239  used_op_ids |= {op_id}
1240  inputs, _ = ssa[op_id]
1241  queue.extend(inputs)
1242  return sorted(used_op_ids)
1243 
1244 
1245 def recurrent_network_op_remap(op, prefix, blob_remap):
1246  """
1247  Parameters
1248  ----------
1249  op : Caffe2 operator (RecurrentNetworkOp or RecurrentNetworkGradientOp).
1250  prefix: this argument is not used in this function, just for legacy support.
1251  blob_remap : Dictionary that represents the map from old blob name to new.
1252 
1253  Updates blob names in arguments of RecurrentNetworkOp and
1254  RecurrentNetworkGradientOp to conform to cloned input and output of both
1255  operators and also makes sure names of locally generated blobs in arguments
1256  have the same prefix as the input and output of the operators.
1257  """
1258 
1259  def get_remapped_str(blob_str):
1260  if isinstance(blob_str, binary_type):
1261  blob_str = blob_str.decode('utf-8')
1262  return blob_remap.get(blob_str, blob_str).encode('utf-8')
1263 
1264  for argument in op.arg:
1265  if len(argument.strings) > 0:
1266  for i in range(len(argument.strings)):
1267  argument.strings[i] = get_remapped_str(argument.strings[i])
1268  elif argument.name == 'timestep':
1269  argument.s = get_remapped_str(argument.s)
1270  elif argument.name.endswith('step_net'):
1271  # argument is a proto
1272  remap_proto(argument, blob_remap)
1273 
1274 
1275 def control_op_remap(op, prefix, blob_remap):
1276  net_arg_names = []
1277  if op.type == "If":
1278  net_arg_names = ['then_net', 'else_net']
1279  else:
1280  net_arg_names = ['loop_net', 'cond_net']
1281  for argument in op.arg:
1282  if argument.name in net_arg_names:
1283  assert argument.n, \
1284  "Expected non empty net in " + op.type + "'s " + argument.name + " argument"
1285  subnet = Net(argument.n)
1286  remapped_subnet = subnet.Clone(
1287  name=(subnet._net.name if subnet._net.name else '') + '_remapped',
1288  blob_remap=blob_remap)
1289  argument.n.CopyFrom(remapped_subnet.Proto())
1290 
1291 
1292 DEFAULT_REMAP_FUNCS = {
1293  'RecurrentNetwork': recurrent_network_op_remap,
1294  'RecurrentNetworkGradient': recurrent_network_op_remap,
1295  'If': control_op_remap,
1296  'While': control_op_remap,
1297 }
1298 
1299 
1300 def remap_proto(argument, blob_remap):
1301  subnet = Net(argument.n)
1302 
1303  cloned_sub_net = subnet.Clone(
1304  'cloned_sub_net',
1305  blob_remap,
1306  )
1307 
1308  argument.n.CopyFrom(cloned_sub_net.Proto())
1309 
1310 
1311 def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
1312  keep_schema=True):
1313  """
1314  Clone the given Net, binding its input schema to the given `inputs` record.
1315  Blob names defined by the net are prepended with the given `prefix`.
1316 
1317  Args:
1318  net: the net to clone
1319  name: the name of the new net
1320  prefix: the prefix to append to local blobs
1321  blob_remap: (optional) dict with additional blob name remapping.
1322  inputs: (optional) input record that will provide actual input
1323  values for the cloned net. Must be compatible with the
1324  net's input schema or be a strict superset of it
1325  keep_schema: by default (True), the original schema will be kept and
1326  remapped accordingly. otherwise, the schema will be set as
1327  inputs or left empty if inputs is not given.
1328  Returns:
1329  Tuple (cloned_net, blob_remap)
1330  clone_net: the cloned Net
1331  blob_remap: a map from original blob names into remapped blob names
1332  """
1333  from caffe2.python import schema
1334  assert isinstance(net, Net)
1335  if blob_remap is None:
1336  blob_remap = {}
1337  if inputs is not None:
1338  assert isinstance(inputs, schema.Field)
1339  original = net.input_record()
1340  assert original is not None
1341  # TODO(azzolini): improve schema type checking
1342  diff = set(original.field_names()) - set(inputs.field_names())
1343  assert len(diff) == 0, (
1344  "Schemas don't match, extra fields {diff} found in the net {name}. "
1345  "original: {original}; inputs: {inputs}"
1346  .format(
1347  diff=diff, name=net.Name(), original=original.field_names(),
1348  inputs=inputs.field_names()
1349  )
1350  )
1351  original_mapping = dict(zip(original.field_names(),
1352  original.field_blobs()))
1353  for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
1354  if fn in original_mapping:
1355  blob_remap[str(original_mapping[fn])] = str(fb)
1356  proto = net.Proto()
1357  ssa, blob_versions = get_ssa(proto)
1358  undef_blobs = get_undefined_blobs(ssa)
1359 
1360  for blob in viewkeys(blob_versions):
1361  if blob in blob_remap:
1362  continue
1363  elif blob in undef_blobs:
1364  blob_remap[blob] = blob
1365  else:
1366  blob_remap[blob] = prefix + blob
1367  cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
1368  if not keep_schema and inputs:
1369  cloned_net.set_input_record(inputs)
1370  return cloned_net, blob_remap
1371 
1372 
1373 def _get_blob_ref(blob_name_or_ref):
1374  return (
1375  blob_name_or_ref if isinstance(input, BlobReference)
1376  else BlobReference(blob_name_or_ref)
1377  )
1378 
1379 
1380 def _recover_record_by_prefix(names, prefix=''):
1381  """
1382  Tries to recover record by taking a subset of blob names with
1383  a given prefix name and interpreting them as schema column names
1384  """
1385  from caffe2.python import schema
1386  column_names = [name[len(prefix):] for name in names
1387  if name.startswith(prefix)]
1388  if not column_names:
1389  return None
1390  return schema.from_column_list(
1391  column_names,
1392  col_blobs=[_get_blob_ref(prefix + name) for name in column_names])
1393 
1394 
1395 class Net(object):
1396  _net_names_used = set()
1397  operator_registry_ = {}
1398 
1399  @staticmethod
1400  def current_prefix():
1401  from caffe2.python.net_builder import NetBuilder
1402  builder = NetBuilder.current(required=False)
1403  return builder.name if builder else ''
1404 
1405  @staticmethod
1406  def _get_next_net_name(basename):
1407  name = basename = '/'.join(
1408  x for x in [Net.current_prefix(), basename] if x
1409  )
1410  next_idx = 1
1411  while name in Net._net_names_used:
1412  name = basename + '_' + str(next_idx)
1413  next_idx += 1
1414  Net._net_names_used |= set([name])
1415  return name
1416 
1417  def __init__(self, name_or_proto):
1418  """
1419  Create a Net.
1420  Args:
1421  name_or_proto: If a NetDef is provided, clone it. Otherwise,
1422  create an empty net with the given name.
1423  """
1424  self._input_record = None
1425  self._output_record = None
1426  # Register blobs so that it's guaranteed that different calls to
1427  # NextBlob/NextScopedBlob always return blobs with different names
1428  self._registered_blob_names = set()
1429  self._recreate_lookup_tables = False
1430  self._op_outputs = set()
1431  self._external_input_map = set()
1432  self._attr_dict = defaultdict(list)
1433  if type(name_or_proto) is caffe2_pb2.NetDef:
1434  proto = name_or_proto
1435  # We rae initializing a network by a NetDef. In this case, we will
1436  # initialize our network with the given netdef.
1437  self._net = caffe2_pb2.NetDef()
1438  self._net.CopyFrom(proto)
1439 
1440  existing_outputs = [list(op.output) for op in self._net.op]
1441 
1442  self._external_input_map.update(list(self._net.external_input))
1443 
1444  # Set the next name index properly.
1445  existing_names = set(
1446  sum(
1447  [list(op.input) for op in self._net.op], []
1448  ) + sum(
1449  existing_outputs, []
1450  )
1451  )
1452  for outs in existing_outputs:
1453  self._op_outputs.update(outs)
1454 
1455  prefix_len = len(self._net.name + '_blob_')
1456  autogen_indices = []
1457  for s in existing_names:
1458  if s.startswith(self._net.name + '_blob_'):
1459  try:
1460  autogen_indices.append(int(s[prefix_len]))
1461  except ValueError:
1462  pass
1463  if len(autogen_indices):
1464  self._next_name_index = max(autogen_indices) + 1
1465  else:
1466  self._next_name_index = 0
1467  name = self._net.name
1468  else:
1469  name = name_or_proto
1470  self._net = caffe2_pb2.NetDef()
1471  self._next_name_index = 0
1472 
1473  # make sure that this net name hasn't been used before
1474  self._net.name = Net._get_next_net_name(name)
1475 
1476  def AppendNet(self, net, device_option=None):
1477  assert isinstance(net, Net)
1478  for i in net.Proto().external_input:
1479  if (
1480  i not in self.Proto().external_input and
1481  i not in self._op_outputs
1482  ):
1483  self.Proto().external_input.append(i)
1484 
1485  self.Proto().external_output.extend(
1486  [
1487  o for o in net.Proto().external_output
1488  if o not in self.Proto().external_output
1489  ]
1490  )
1491  ops = net.Proto().op
1492  if device_option is not None:
1493  ops = [copy.deepcopy(op) for op in ops]
1494  map(lambda x: x.device_option.CopyFrom(device_option), ops)
1495 
1496  self._ExtendOps(ops)
1497  return self
1498 
1499  def LogInfo(self, *msg_or_blobs):
1500  for msg_or_blob in msg_or_blobs:
1501  if not isinstance(msg_or_blob, BlobReference):
1502  blob = self.GivenTensorStringFill(
1503  [], self.NextName('log'),
1504  shape=[], values=[msg_or_blob])
1505  else:
1506  blob = msg_or_blob
1507  self.Print(blob, [])
1508 
1509  def add_attribute(self, name, obj):
1510  """
1511  Add `obj` to the list of attributes in this net under the given `name`.
1512  Attributes are user-defined objects and have no pre-defined semantics.
1513  """
1514  self._attr_dict[name].append(obj)
1515 
1516  def get_attributes(self, name):
1517  """
1518  Returns the list of attributes in this net for a given `name`.
1519  Attributes are user-defined objects added with `add_attribute'.
1520  """
1521  return self._attr_dict.get(name, [])
1522 
1523  def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
1524  """
1525  Adds a random seed to each op in the net.
1526  If sequence_seed is set, the i-th op has rand_seed=`seed + i`
1527  If seed_on_op_def is set, the op rand_seed=hash(str(op))
1528  sequence_seed and seed_on_op_def cannot be both set to True.
1529  """
1530  assert not (sequence_seed and seed_on_op_def), (
1531  'sequence_seed and seed_on_op_def cannot be both set to True.')
1532  for i, op in enumerate(self.Proto().op):
1533  if sequence_seed:
1534  curr_seed = seed + i
1535  elif seed_on_op_def:
1536  curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
1537  else:
1538  curr_seed = seed
1539  op.device_option.random_seed = curr_seed
1540 
1541  def Name(self):
1542  return self._net.name
1543 
1544  def __str__(self):
1545  return self.Name()
1546 
1547  def Const(self, array, blob_out=None, dtype=None):
1548  if isinstance(array, bool):
1549  return self.ConstantFill(
1550  [],
1551  blob_out or 1,
1552  dtype=DataType.BOOL,
1553  value=array)
1554 
1555  if dtype is None:
1556  array = np.array(array)
1557  else:
1558  array = np.array(array, dtype=dtype)
1559 
1560  def do_set(operator):
1561  return operator(
1562  [],
1563  blob_out or 1,
1564  shape=array.shape,
1565  values=array.flatten().tolist())
1566 
1567  if array.dtype == np.int32:
1568  return do_set(self.GivenTensorIntFill)
1569  elif array.dtype == np.int64:
1570  return do_set(self.GivenTensorInt64Fill)
1571  elif array.dtype == np.str:
1572  return do_set(self.GivenTensorStringFill)
1573  elif array.dtype == np.bool:
1574  return do_set(self.GivenTensorBoolFill)
1575  else:
1576  return do_set(self.GivenTensorFill)
1577 
1578  def BlobIsDefined(self, blob):
1579  """
1580  Returns true if the given BlobReference is produced as output of
1581  an operator in this net, or if it is provided as an external input.
1582  """
1583  if self._recreate_lookup_tables:
1584  self._RecreateLookupTables()
1585  name = str(blob)
1586  return (name in self._op_outputs) or (name in self._external_input_map)
1587 
1588  def UsesBlob(self, blob):
1589  """
1590  Returns true iff the given BlobReference is used by any operator
1591  or this net, or if it is one of the external inputs of the net.
1592  """
1593  blob_name = str(blob)
1594  for op in self._net.op:
1595  for input in op.input:
1596  if input == blob_name:
1597  return True
1598  return blob_name in self._external_input_map
1599 
1600  def UsedBlobNames(self):
1601  """
1602  Returns a set of blob names used in the net
1603  """
1604  blob_names = set()
1605  for op in self._net.op:
1606  blob_names |= set(op.input)
1607  blob_names |= set(op.output)
1608  if self._net.external_input:
1609  blob_names |= set(self._net.external_input)
1610  if self._net.external_output:
1611  blob_names |= set(self._net.external_output)
1612  return blob_names
1613 
1614  def GetBlobRef(self, blob_name):
1615  """
1616  Given the name of a blob produced by this net, return a BlobReference
1617  to it. If the blob is not produced by any op in this net,
1618  raises KeyError.
1619  """
1620  blob_name = str(blob_name)
1621  if not self.BlobIsDefined(blob_name):
1622  raise KeyError('Net does not define blob %s' % blob_name)
1623  return BlobReference(blob_name, self)
1624 
1625  def Clone(
1626  self,
1627  name,
1628  blob_remap=None,
1629  op_id_mask=None,
1630  remap_funcs=None,
1631  keep_schema=True,
1632  update_external_list=False,
1633  ):
1634  """
1635  Clone this net.
1636  Args:
1637  name: name of the cloned net
1638  blob_remap: optional map with list of blob names to replace
1639  op_id_mask: optional list of operator indices to include in
1640  the cloned net. If not provided, all ops are included.
1641  """
1642  orig_remap_funcs = {} if remap_funcs is None else remap_funcs
1643  # by default we want to put RecurrentNetworkOp and
1644  # RecurrentNetworkGradientOp into remap_funcs, as these two operators
1645  # also take blobs and proto into the arguments.
1646  remap_funcs = DEFAULT_REMAP_FUNCS.copy()
1647  remap_funcs.update(orig_remap_funcs)
1648  proto = self._net
1649  new_proto = caffe2_pb2.NetDef()
1650  new_proto.CopyFrom(proto)
1651  new_proto.name = name
1652 
1653  if blob_remap is None:
1654  blob_remap = {}
1655  if op_id_mask is None:
1656  op_id_mask = list(range(0, len(proto.op)))
1657 
1658  def get_remapped_str(blob):
1659  blob_str = str(blob)
1660  return str(blob_remap.get(blob_str, blob_str))
1661 
1662  def remap_list(proto_list):
1663  new_list = [get_remapped_str(b) for b in proto_list]
1664  del proto_list[:]
1665  proto_list.extend(new_list)
1666 
1667  def remap_op(op):
1668  new_op = caffe2_pb2.OperatorDef()
1669  new_op.CopyFrom(op)
1670  remap_list(new_op.input)
1671  remap_list(new_op.output)
1672  if new_op.type in remap_funcs:
1673  remap_funcs[new_op.type](
1674  new_op,
1675  (name + '/') if name else '',
1676  blob_remap,
1677  )
1678  return new_op
1679 
1680  del new_proto.op[:]
1681  new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
1682  remap_list(new_proto.external_input)
1683  remap_list(new_proto.external_output)
1684  new_net = Net(new_proto)
1685 
1686  if keep_schema:
1687  from caffe2.python import schema
1688  if self._input_record:
1689  new_net._input_record = schema.from_blob_list(
1690  self._input_record,
1691  [
1692  BlobReference(get_remapped_str(blob), net=new_net)
1693  for blob in self._input_record.field_blobs()
1694  ],
1695  )
1696  if self._output_record:
1697  new_net._output_record = schema.from_blob_list(
1698  self._output_record,
1699  [
1700  BlobReference(get_remapped_str(blob), net=new_net)
1701  for blob in self._output_record.field_blobs()
1702  ],
1703  )
1704 
1705  new_net._attr_dict.update(self._attr_dict)
1706  if update_external_list:
1707  # external input list
1708  existing_outputs = set()
1709  used_outputs = set()
1710  del new_net.Proto().external_input[:]
1711  del new_net.Proto().external_output[:]
1712  for op in new_net.Proto().op:
1713  for ib in op.input:
1714  if ib not in existing_outputs:
1715  new_net.Proto().external_input.extend([ib])
1716  else:
1717  used_outputs.add(ib)
1718  for ob in op.output:
1719  existing_outputs.add(ob)
1720  # external outputs
1721  for ob in existing_outputs:
1722  if ob not in used_outputs:
1723  new_net.Proto().external_output.extend([ob])
1724  return new_net
1725 
1726  def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
1727  """
1728  Clone this net, including only ops that are necessary in order to
1729  compute `outputs` given `inputs`. Return references to the cloned
1730  outputs. Internal blobs (blobs that are produced and consumed inside
1731  the net but not used as outputs) will be remapped to avoid name
1732  conflict.
1733 
1734  Args:
1735  name: the name of the cloned net
1736  inputs: map where the keys correspond to BlobReferences in the
1737  original net, and the values correspond to external inputs
1738  in the partially cloned net. If `inputs` is a list, don't
1739  remap input names.
1740  outputs: outputs to be produced by the cloned net.
1741 
1742  Returns:
1743  Tuple (new_net, new_outputs)
1744  new_net: a new Net object.
1745  new_outputs: list of BlobReferences corresponding to the
1746  outputs produced by new_net.
1747  """
1748  input_is_pair_list = isinstance(inputs, list) and all(
1749  isinstance(i, tuple) and len(i) == 2 for i in inputs)
1750  inputs = (
1751  inputs if isinstance(inputs, (dict, OrderedDict)) else
1752  OrderedDict(inputs) if input_is_pair_list else
1753  OrderedDict(zip(inputs, inputs)))
1754  for output in outputs:
1755  assert self.BlobIsDefined(output), "{} is not defined".format(output)
1756  input_names = {str(k): str(v) for k, v in viewitems(inputs)}
1757  output_names = [str(o) for o in outputs]
1758  proto = self._net
1759  blob_versions = {str(i): 0 for i in inputs}
1760  ssa, blob_versions = get_ssa(proto, blob_versions)
1761  used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
1762  disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
1763  assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
1764  'Cannot partially clone net: some of the ops required would ' +
1765  'generate the given input.')
1766 
1767  sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
1768  undef_blobs = get_undefined_blobs(sub_ssa) - set(viewkeys(input_names))
1769  prefix = (name + '/') if name else ''
1770 
1771  def remap(blob_name):
1772  if blob_name in input_names:
1773  return input_names[blob_name]
1774  elif blob_name in undef_blobs:
1775  return blob_name
1776  else:
1777  return prefix + blob_name
1778 
1779  blob_mapping = {b: remap(b) for b in viewkeys(blob_versions)}
1780  new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
1781  new_in = [
1782  blob_mapping[i] for i in viewkeys(input_names)] + list(undef_blobs)
1783  new_out = [blob_mapping[o] for o in output_names]
1784  del new_net.Proto().external_input[:]
1785  new_net.Proto().external_input.extend(new_in)
1786  new_net._external_input_map = set(list(new_in))
1787  del new_net.Proto().external_output[:]
1788  new_net.Proto().external_output.extend(new_out)
1789  return new_net, [new_net.GetBlobRef(o) for o in new_out]
1790 
1791  def Proto(self):
1793  return self._net
1794 
1795  def insert_op_at_idx(self, op, op_idx):
1796  r""" inserting operator at index. Will update external blob list.
1797  """
1798  assert op_idx >= 0
1799  temp_ops = self.Proto().op[op_idx:]
1800  del self.Proto().op[op_idx:]
1801  self.Proto().op.extend([op])
1802  self.Proto().op.extend(temp_ops)
1803  self.external_outputs.extend(op.output)
1804  self.external_inputs.extend(op.input)
1805 
1806  def reroute_tensor(self, tensor, new_producer, can_modify=None):
1807  r""" reroute tensor to new_producer. And feed new tensor to consumers
1808  and interseciton with can_modify if provided.
1809  Inputs:
1810  tensor: str or blob_reference the tensor to reroute
1811  new_producer: an op takes in tensor gives new_tesnor
1812  can_modify: a list/set of operators that consumes tensor and can be
1813  modified
1814 
1815  Returns:
1816  reroute_cnt: how many consumer op has been changed
1817 
1818  Note: assume no inplace blob in net
1819  """
1820  def _find_tensor_input_op(tensor):
1821  if tensor in self.external_inputs:
1822  op_idx = -1
1823  else:
1824  assert tensor in new_producer.input, \
1825  "new producer {} is not taking in {}".format(
1826  new_producer.type, tensor)
1827  # assuming that the net has no inplace blob
1828  op_idx = -2
1829  for index, op in enumerate(self.Proto().op):
1830  if_found = False
1831  for o in op.output:
1832  if o == tensor:
1833  # tensor should not be modified yet.
1834  if_found = True
1835  op_idx = index
1836  break
1837  if if_found:
1838  break
1839  return op_idx
1840 
1841  # the place to inject new_producer is not just determined by tensor
1842  op_idx = max(_find_tensor_input_op(t) for t in new_producer.input)
1843  self.insert_op_at_idx(new_producer, op_idx + 1)
1844  new_tensor = new_producer.output[0]
1845  # modify external outputs
1846  if tensor in self.external_outputs:
1847  new_list = [new_tensor if b == tensor else b for b in self.external_outputs]
1848  del self.Proto().external_output[:]
1849  self.Proto().external_output.extend(new_list)
1850 
1851  # modify consumers
1852  reroute_cnt = 0
1853  if can_modify:
1854  for op in self.Proto().op:
1855  if op in can_modify: # this is not necessarily true
1856  remap_input(op, {tensor: new_tensor})
1857  reroute_cnt = reroute_cnt + 1
1858  return reroute_cnt
1859 
1860  def PopulateProtoWithFileName(self):
1861  net_tb = workspace.operator_tracebacks.get(self.Name(), None)
1862  if net_tb is not None:
1863  for idx, op in enumerate(self.Proto().op):
1864  if idx in net_tb:
1865  op.name = ':'.join(map(str, net_tb[idx][0]))
1866 
1867  def NextScopedBlob(self, prefix='unnamed'):
1868  """Return the blob that has not been defined or registered in the
1869  current net. It returns `ScopedBlobReference(prefix)`, if it's valid,
1870  otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`. Different calls
1871  is guaranteed to return blob with different names.
1872  """
1873  output_blob_base = ScopedName(prefix)
1874  return self.NextBlob(output_blob_base)
1875 
1876  def NextBlob(self, prefix='unnamed'):
1877  """Return the blob that has not been defined or registered in the
1878  current net. It returns `BlobReference(prefix)`, if it's valid,
1879  otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls
1880  is guaranteed to return blob with different names."""
1881  output_blob_base = BlobReference(prefix)
1882  output_blob = output_blob_base
1883  index = 0
1884  while str(output_blob) in self._registered_blob_names or (
1885  self.BlobIsDefined(output_blob)):
1886  output_blob = output_blob_base + '_auto_' + str(index)
1887  index += 1
1888 
1889  self._registered_blob_names.add(str(output_blob))
1890  return output_blob
1891 
1892  def NextName(self, prefix=None, output_id=None):
1893  """Returns the next name to be used, if you do not want to explicitly
1894  name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]"""
1895  if prefix:
1896  output_name_base = self._net.name + '/' + prefix
1897  output_name = output_name_base
1898  if output_id is not None:
1899  output_name += ':' + str(output_id)
1900  index = 2
1901  while self.BlobIsDefined(str(ScopedBlobReference(output_name))):
1902  output_name = output_name_base + '_' + str(index)
1903  if output_id is not None:
1904  output_name += ':' + str(output_id)
1905  index += 1
1906  else:
1907  output_name = self._net.name + '_blob_' + str(self._next_name_index)
1908  self._next_name_index += 1
1909  return str(output_name)
1910 
1911  def _ExtendOps(self, new_ops):
1912  self._net.op.extend(new_ops)
1913  for op in new_ops:
1914  self._op_outputs.update([text_type(o) for o in op.output])
1915 
1916  def _CheckLookupTables(self):
1917  '''
1918  Called from unit tests to validate the internal lookup tables
1919  match the protobuf contents.
1920  '''
1921  test_op_outputs = set()
1922  for op in self._net.op:
1923  for o in op.output:
1924  test_op_outputs.add(o)
1925 
1926  test_external_inp = set()
1927  for inp in self._net.external_input:
1928  test_external_inp.add(inp)
1929 
1930  assert test_op_outputs.difference(self._op_outputs) == set()
1931  assert test_external_inp.difference(self._external_input_map) == set()
1932 
1933  def _InvalidateLookupTables(self):
1934  self._recreate_lookup_tables = True
1935 
1936  def _RecreateLookupTables(self):
1937  self._op_outputs = set()
1938  for op in self._net.op:
1939  for o in op.output:
1940  self._op_outputs.add(o)
1941 
1942  self._external_input_map = set()
1943  for inp in self._net.external_input:
1944  self._external_input_map.add(inp)
1945 
1946  self._recreate_lookup_tables = False
1947 
1948  def AddGradientOperators(self, ys, skip=0):
1949  """Add the gradient for operators in the net.
1950 
1951  Inputs:
1952  ys: a list or a dictionary specifying what blobs we want to compute
1953  derivatives of. If the input is a list, we will automatically
1954  generate their gradients with all-one values; if the input is a
1955  dictionary, for any dictionary entries that are not None, we will
1956  take the corresponding blobs as their gradients; for all those
1957  that are None, we will auto-fill them with 1.
1958  skip: skips the first n operators. This is provided mainly because a
1959  lot of nets may use the first few operators for data generation
1960  like stuff which really do not need to have gradients.
1961 
1962  Outputs:
1963  returns a map from the blob name in the input network to a blob
1964  containing gradient or a GradientSlice in case of sparse gradient
1965 
1966  Currently, this is hard-coded for float operators if there are branches
1967  (i.e. a blob is used as input to multiple operators). This is because
1968  the gradient accumulation (Sum) is float only right now.
1969  """
1970 
1971  grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
1972  self._net.op[skip:], ys)
1973  # Check if in immediate mode: the grad_ops are actually being produced
1974  # by C++ and bypasses the CreateOperator() call, so in immediate mode
1975  # we will have to explicitly run them.
1976  if workspace.IsImmediate():
1977  for op in grad_ops:
1978  workspace.RunOperatorImmediate(op)
1979  self._ExtendOps(grad_ops)
1980  return input_to_grad
1981 
1982  def AddArgument(self, arg_name, arg_value):
1983  self._net.arg.extend([utils.MakeArgument(arg_name, arg_value)])
1984 
1985  def AddExternalInput(self, *inputs):
1986  assert len(inputs) > 0
1987  refs = []
1988  for input in inputs:
1989  input_name = str(input)
1990  assert str(input) not in self._external_input_map, (
1991  'Net already contains an input named %s' % input_name)
1992  for input in inputs:
1993  input_name = str(input)
1994  self._net.external_input.extend([input_name])
1995  self._external_input_map.update([input_name])
1996  refs.append(_get_blob_ref(input_name))
1997 
1998  return refs[0] if len(refs) == 1 else refs
1999 
2000  def AddExternalOutput(self, *outputs):
2001  for output in outputs:
2002  assert isinstance(output, BlobReference)
2003  assert self.BlobIsDefined(output), "{} is not defined".format(output)
2004  for output in outputs:
2005  self.Proto().external_output.extend([str(output)])
2006 
2007  def AddScopedExternalInputs(self, *inputs):
2008  res = self.AddExternalInput(
2009  * [ScopedBlobReference(b) for b in inputs]
2010  )
2011  if not isinstance(res, list):
2012  res = [res]
2013  return res
2014 
2015  def AddScopedExternalOutputs(self, *outputs):
2016  return self.AddExternalOutput(
2017  * [ScopedBlobReference(b) for b in outputs]
2018  )
2019 
2020  # This returns a reference to the observer
2021  def AddObserver(self, observer_type):
2022  return C.add_observer_to_net(self._net.name, observer_type)
2023 
2024  def RemoveObserver(self, observer):
2025  C.remove_observer_from_net(self._net.name, observer)
2026 
2027  def NumObservers(self):
2028  return C.num_observers_on_net(self._net.name)
2029 
2030  @property
2031  def external_inputs(self):
2032  return [_get_blob_ref(x) for x in self._net.external_input]
2033 
2034  @property
2035  def external_outputs(self):
2036  return [_get_blob_ref(x) for x in self._net.external_output]
2037 
2038  def set_input_record(self, input_record):
2039  from caffe2.python import schema
2040  assert self._input_record is None or (input_record.has_blobs() and
2041  set(input_record.field_blobs()) ==
2042  set(self._input_record.field_blobs())), (
2043  'Input schema cannot be reset')
2044  if not input_record.has_blobs():
2045  with NameScope(self.Name()):
2046  self._input_record = schema.NewRecord(self, input_record)
2047  else:
2048  self._input_record = input_record
2049 
2050  for blob in self._input_record.field_blobs():
2051  if blob not in self.external_inputs:
2052  self.AddExternalInput(blob)
2053  return self._input_record
2054 
2056  """
2057  Tries to recover input record by taking a subset of external_inputs with
2058  a given prefix name and interpreting them as schema column names
2059  """
2060  record = _recover_record_by_prefix(self._net.external_input, prefix)
2061  if record:
2062  self.set_input_record(record)
2063 
2064  def set_output_record(self, record):
2065  assert self._output_record is None or (record.has_blobs() and
2066  set(record.field_blobs()) ==
2067  set(self._output_record.field_blobs())), (
2068  'Output schema cannot be reset')
2069  for blob in record.field_blobs():
2070  assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
2071  for blob in record.field_blobs():
2072  if blob not in self.external_outputs:
2073  self.AddExternalOutput(blob)
2074  self._output_record = record
2075 
2077  """
2078  Tries to recover out record by taking a subset of external_outputs with
2079  a given prefix name and interpreting them as schema column names
2080  """
2081  record = _recover_record_by_prefix(self._net.external_output, prefix)
2082  if record:
2083  self.set_output_record(record)
2084 
2085  def AppendOutputRecordField(self, field_name, record):
2086  from caffe2.python import schema
2087  assert self._output_record is not None, (
2088  'Tried to append to missing output record'
2089  )
2090  for blob in record.field_blobs():
2091  assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
2092  for blob in record.field_blobs():
2093  self.AddExternalOutput(blob)
2095  (field_name, record)
2096  )
2097 
2098  def input_record(self):
2099  return self._input_record
2100 
2101  def output_record(self):
2102  return self._output_record
2103 
2104  def AddExternalInputs(self, *inputs):
2105  return self.AddExternalInput(*inputs)
2106 
2107  def AddExternalOutputs(self, *outputs):
2108  self.AddExternalOutput(*outputs)
2109 
2110  def DeduplicateGradientSlices(self, g, aggregator='sum'):
2111  assert isinstance(g, GradientSlice)
2112  unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
2113  if aggregator.lower() == 'sum':
2114  new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
2115  elif aggregator.lower() == 'mean':
2116  new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
2117  else:
2118  raise ValueError('{} is not supported'.format(aggregator))
2119  return GradientSlice(indices=unique, values=new_g)
2120 
2121  @staticmethod
2122  def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
2123  device_option = caffe2_pb2.DeviceOption()
2124  device_option.device_type = workspace.GpuDeviceType
2125  device_option.device_id = gpu_id
2126  net.device_option.CopyFrom(device_option)
2127  if use_cudnn:
2128  for op in net.op:
2129  op.engine = "CUDNN"
2130  # Move RecurrentNetwork operators on GPU as well
2131  for op in net.op:
2132  if op.type != "RecurrentNetwork":
2133  continue
2134  for arg in op.arg:
2135  if arg.name == "step_net":
2136  Net._RunAllOnGPU(arg.n, gpu_id, use_cudnn)
2137 
2138  def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
2139  """A convenient function to run everything on the GPU."""
2140  self._RunAllOnGPU(self._net, gpu_id, use_cudnn)
2141 
2142 
2143 
2144  def RunAllOnMKL(self):
2145  """A convenient function to run everything using MKLDNN."""
2146  device_option = caffe2_pb2.DeviceOption()
2147  device_option.device_type = caffe2_pb2.MKLDNN
2148  self._net.device_option.CopyFrom(device_option)
2149 
2150  def RunAllOnIDEEP(self):
2151  """A convenient function to run everything using IDEEP."""
2152  device_option = caffe2_pb2.DeviceOption()
2153  device_option.device_type = caffe2_pb2.IDEEP
2154  self._net.device_option.CopyFrom(device_option)
2155 
2156  def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
2157  """A helper function to create an operator and add it to self.
2158  """
2159  inputs = _RectifyInputOutput(inputs)
2160  for input in inputs:
2161  if not self.BlobIsDefined(input):
2162  assert input.Net() != self
2163  self.AddExternalInput(input)
2164  if outputs is None:
2165  # If we do not specify an output, we will assume that this op
2166  # produces one output in this case.
2167  outputs = self.NextName(prefix=op_type)
2168  elif type(outputs) is int:
2169  # In this case, we will auto-fill the given number of outputs
2170  # with auto-generated names.
2171  outputs = [
2172  self.NextName(prefix=op_type, output_id=i)
2173  for i in range(outputs)]
2174  outputs = _RectifyInputOutput(outputs, net=self)
2175  op = CreateOperator(op_type, inputs, outputs, **kwargs)
2176  self._ExtendOps([op])
2177 
2178  workspace.operator_tracebacks[self.Name()][
2179  len(self._net.op) - 1] = _extract_stacktrace()
2180 
2181  if len(op.output) == 0:
2182  return
2183  elif len(op.output) == 1:
2184  return BlobReference(op.output[0], self)
2185  else:
2186  return tuple(BlobReference(o, self) for o in op.output)
2187 
2188  def __getattr__(self, op_type):
2189  if op_type.startswith('__'):
2190  raise AttributeError('Attribute {} not found.'.format(op_type))
2191  if not IsOperator(op_type) and not IsOperatorWithEngine(op_type, "CUDNN"):
2192  raise AttributeError(
2193  'Method ' + op_type + ' is not a registered operator.' +
2194  ' Did you mean: [' +
2195  ",".join(workspace.C.nearby_opnames(op_type)) + ']'
2196  )
2197  return lambda *args, **kwargs: self._CreateAndAddToSelf(
2198  op_type, *args, **kwargs)
2199 
2200  def __dir__(self):
2201  additional_methods = [
2202  op
2203  for op in _REGISTERED_OPERATORS
2204  if '_ENGINE_' not in op]
2205  return sorted(set(chain(
2206  dir(type(self)),
2207  viewkeys(self.__dict__),
2208  additional_methods
2209  )))
2210 
2211  def Python(
2212  self,
2213  f,
2214  grad_f=None,
2215  python_func_type=None,
2216  pass_workspace=False,
2217  grad_output_indices=None,
2218  grad_input_indices=None
2219  ):
2220  """
2221  Registers and returns a python operator.
2222 
2223  `f` and `grad_f` can be one of the following:
2224  - a function with signature (inputs, outputs), where inputs and
2225  outputs are a list of CPUTensor objects. This function will be
2226  called from C++ everytime the operator is executed.
2227  - a tuple (func, args, kwargs), here `func` is a callable, args is
2228  an argument list, and kwargs is a dict list. The call:
2229  f = func(*args, kwargs)
2230  will be performed locally at node initialization time, on all of
2231  the nodes of the job, returning `f`, a callable that will be used
2232  as the python operator function to be called during Net execution.
2233  This is to be used when using python operator in a distributed
2234  context, and allows to create and keep local python state across
2235  calls to the operator.
2236 
2237  `python_func_type` is a type of an object that constructed as
2238  python_func_type(f) and provides an implementation to forward and
2239  backward functions. Its useful in such a case where users needs
2240  a statefull PythonOp (ex: use autograd for computing grad_f).
2241 
2242  If `pass_workspace` is True, the signature is changed to
2243  (inputs, outputs, workspace) where `workspace` is the workspace the op
2244  is going to run on. This is potentially dangerous (as the op can
2245  manipulate the workspace directly), use on your own risk.
2246 
2247  If a gradient function is specified (`grad_f`), by default its inputs
2248  will be: (1) all inputs to `f`, (2) followed by all outputs of `f`, (3)
2249  and then all gradient outputs of `f`. The outputs of `grad_f` will be
2250  (by default) all gradient inputs to `f`. If a subset of the gradient
2251  outputs or gradient inputs is desired instead, then the subsets can be
2252  specified by providing `grad_output_indices` and/or `grad_input_indices`
2253  which identify the indices of `f`'s inputs and outputs which have
2254  gradients.
2255  """
2256  assert(IsOperator('Python'))
2257 
2258  def make_builder(t):
2259  if not isinstance(t, tuple):
2260  return ''
2261  assert len(t) == 3, 'Expected builder tuple (func, args, kwargs)'
2262  func, args, kwargs = t
2263  normalized = (func, tuple(args), dict(kwargs))
2264  return pickle.dumps(normalized)
2265 
2266  f_builder = make_builder(f)
2267  grad_f_builder = make_builder(grad_f)
2268 
2269  assert (not grad_f) or ((not f_builder) == (not grad_f_builder)), (
2270  'A tuple has to be passed to both f and grad_f or neither.')
2271 
2272  core_kwargs = {}
2273  if f_builder:
2274  core_kwargs['pickled_builder'] = f_builder
2275  core_kwargs['pickled_grad_builder'] = grad_f_builder
2276  core_kwargs['pass_workspace'] = pass_workspace
2277  else:
2278  core_kwargs['token'] = _RegisterPythonImpl(
2279  f, grad_f, python_func_type, pass_workspace=pass_workspace)
2280 
2281  grad_output_indices = grad_output_indices or []
2282  grad_input_indices = grad_input_indices or []
2283  return lambda *args, **kwargs: self._CreateAndAddToSelf(
2284  'Python',
2285  grad_output_indices=grad_output_indices,
2286  grad_input_indices=grad_input_indices,
2287  *args,
2288  **dict(chain(viewitems(kwargs), viewitems(core_kwargs)))
2289  )
2290 
2291  def is_external_input(self, blob):
2292  name = str(blob)
2293  return name in self._external_input_map
2294 
2295  def extend_ops(self, new_ops):
2296  return self._ExtendOps(new_ops)
2297 
2298 
2299 def remap_input(op, blob_name_remapping):
2300  new_list = [blob_name_remapping.get(b, b) for b in op.input]
2301  del op.input[:]
2302  op.input.extend(new_list)
2303 
2304 
2305 def copy_func_between_devices(src, dst):
2306  CPU = caffe2_pb2.CPU
2307  is_src_gpu = IsGPUDeviceType(src.device_type)
2308  is_dst_gpu = IsGPUDeviceType(dst.device_type)
2309 
2310  if src.device_type == CPU and dst.device_type == CPU:
2311  return None
2312 
2313  if is_src_gpu and is_dst_gpu:
2314  if src.device_id == dst.device_id:
2315  return None
2316  else:
2317  def fun(net, *args, **kw):
2318  with DeviceScope(dst):
2319  return net.Copy(*args, **kw)
2320  return fun
2321 
2322  if is_src_gpu and dst.device_type == CPU:
2323  def fun(net, *args, **kw):
2324  with DeviceScope(src):
2325  return net.CopyGPUToCPU(*args, **kw)
2326  return fun
2327 
2328  if src.device_type == CPU and is_dst_gpu:
2329  def fun(net, *args, **kw):
2330  with DeviceScope(dst):
2331  return net.CopyCPUToGPU(*args, **kw)
2332  return fun
2333 
2334  raise ValueError('Non-supported devices: %s and %s' % (src, dst))
2335 
2336 
2337 def device_equal(src, dst):
2338  '''
2339  We are using this fucntion instead of == operator because optional-value
2340  comparison between empty device_options and {device_type:0, device_id:0}
2341  returns not equal in some cases.
2342  '''
2343  return src.device_type == dst.device_type and src.device_id == dst.device_id
2344 
2345 
2346 def update_placeholder_op_output(op, blob_to_device):
2347  '''
2348  Placeholder ops (for e.g. Recv) always runs on CPU. So ensure their
2349  output blobs reside on CPU.
2350  '''
2351  outputs = []
2352  for output in op.output:
2353  if (output in blob_to_device and
2354  blob_to_device[output].device_type != caffe2_pb2.CPU):
2355  output += '_cpu'
2356  outputs.append(output)
2357  del op.output[:]
2358  op.output.extend(outputs)
2359 
2360 
2362  def __init__(self, blob, device):
2363  self.blob = blob
2364  self.device = device
2365 
2366  def __eq__(self, other):
2367  return self.blob == other.blob and self.device == other.device
2368 
2369  def __hash__(self):
2370  return hash(self.blob + str(self.device))
2371 
2372 
2373 def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
2374  placeHolderOps=None):
2375  '''
2376  Injecting Copy functions between device within a net. Users can provide
2377  a net with part of operators using different device_options. This method
2378  will automatically create a new net with Copy ops inserted in it.
2379 
2380  Inputs:
2381  blob_to_device: If not None, it is a map of blobs and their device locations.
2382  blob_remap: If not None, it is a map from a pair (blob, device) to
2383  the name of the blob in the given device. Blobs found in this
2384  map are assumed to be cached and don't need to be copied.
2385  Outputs:
2386  new_net: A new net with CopyCPUToGPU inserted with correct device option
2387 
2388  required_external_to_device:
2389  A mapping between unresolved external inputs and their
2390  required device options.
2391  Assumptions:
2392  1. every external inputs of this net is already in blob_to_device!
2393  2. if not, this function will use net device option
2394  3. InferOpBlobDevices might fail to get the correct inference for ops like
2395  EnsureCPUOutput that could take in input from multiple places.
2396  '''
2397  new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
2398  del new_net._net.op[:]
2399  if blob_to_device is None:
2400  blob_to_device = {}
2401  # remapping of input blobs for each op.
2402  if blob_remap is None:
2403  blob_remap = {}
2404  temp_remap = {}
2405  net_option = net._net.device_option or caffe2_pb2.DeviceOption()
2406 
2407  # if external_inputs have device remappings generated by previous nets,
2408  # then add those remappings as external inputs as well.
2409  all_remaps = defaultdict(list)
2410  for entry, mapped_blob in blob_remap.items():
2411  all_remaps[entry.blob].append(mapped_blob)
2412  mapped_external_inputs = []
2413  for input in new_net._net.external_input:
2414  mapped_external_inputs.extend(all_remaps.get(input) or [])
2415  new_net._net.external_input.extend(mapped_external_inputs)
2416 
2417  for op in net._net.op:
2418  temp_remap.clear()
2419  # Get where inputs and outputs should be. If it is a Placeholder
2420  # (i.e. fake) op, then set op's device as blob's devices.
2421  input_dev = None
2422  output_dev = None
2423  if placeHolderOps is not None and op.type in placeHolderOps:
2424  input_dev, output_dev = InferOpDeviceAsBlobDevices(op)
2425  else:
2426  input_dev, output_dev = InferOpBlobDevices(op)
2427 
2428  for dev, input in zip(input_dev, op.input):
2429  assert net.BlobIsDefined(input), \
2430  "input {} should be defined in the net.".format(input)
2431  if input not in blob_to_device:
2432  if net.is_external_input(input):
2433  blob_to_device[input] = net_option
2434  else:
2435  raise AttributeError(
2436  "No device information found for blob {}.".
2437  format(input)
2438  )
2439 
2440  if not device_equal(blob_to_device[input], dev):
2441  # reuse already moved input
2442  if (RemapEntry(input, dev) in blob_remap and
2443  blob_to_device[blob_remap[RemapEntry(input, dev)]] == dev):
2444  temp_remap[input] = blob_remap[RemapEntry(input, dev)]
2445  else:
2446  # need to make input on correct device.
2447  copy_func = copy_func_between_devices(
2448  blob_to_device[input], dev
2449  )
2450 
2451  def _gen_new_name(blob, device_option):
2452  CPU = caffe2_pb2.CPU
2453  if device_option.device_type == CPU:
2454  suffix = '_cpu'
2455  elif IsGPUDeviceType(device_option.device_type):
2456  suffix = '_gpu_' + str(device_option.device_id)
2457  else:
2458  raise RuntimeError(
2459  "Unknown device type: {}".
2460  format(device_option.device_type)
2461  )
2462  return blob + suffix
2463 
2464  new_name = _gen_new_name(input, dev)
2465  copy_func(new_net, input, new_name)
2466  blob_remap[RemapEntry(input, dev)] = new_name
2467  temp_remap[input] = new_name
2468  blob_to_device[new_name] = dev
2469 
2470  if placeHolderOps is not None and op.type in placeHolderOps:
2471  update_placeholder_op_output(op, blob_to_device)
2472 
2473  # Enforcing no reuse blob between operators. In-place blob usage in an
2474  # op is allowed. This is based on the assumption that in-place op has
2475  # same device info
2476  for dev, output in zip(output_dev, op.output):
2477  if output in blob_to_device and (
2478  output not in op.input and
2479  not device_equal(blob_to_device[output], dev)
2480  ):
2481  raise RuntimeError(
2482  "In-place blob: {} is not supported between operators "
2483  "with different device option previous:{} now: {}. "
2484  "Failed op:\n {}".format(
2485  output, blob_to_device[output], dev, op
2486  )
2487  )
2488  new_op = caffe2_pb2.OperatorDef()
2489  new_op.CopyFrom(op)
2490 
2491  new_list = [temp_remap.get(b, b) for b in new_op.input]
2492  del new_op.input[:]
2493  new_op.input.extend(new_list)
2494 
2495  # keep inplace blobs inplace
2496  original_inputs = list(op.input)
2497  for i, out in enumerate(new_op.output):
2498  try:
2499  input_idx = original_inputs.index(out)
2500  new_op.output[i] = new_op.input[input_idx]
2501  except ValueError:
2502  pass
2503 
2504  blob_to_device.update(
2505  {o: d for d, o in zip(output_dev, new_op.output)})
2506  new_net.extend_ops([new_op])
2507 
2508  return new_net, blob_to_device
2509 
2510 
2511 def InjectDeviceCopiesAmongNets(nets, blob_to_device_init=None):
2512  """
2513  Takes in a list of nets. They usually represent your whole execution graph.
2514  This function will insert cross device copy functions to all nets, and resolve
2515  inter-net external inputs dependencies. This method will insert Copy funcitons if
2516  external inputs of a net is produced on different device than it is required.
2517  Inputs:
2518  nets: a list of nets
2519  Outputs:
2520  new_nets: a list of new nets with device difference solved.
2521 
2522  Some notes from wyiming:
2523  1. You MUST pass nets in execution order. e.g. [train_init, train]
2524  """
2525  assert isinstance(nets, list), \
2526  "nets {} should be a list of nets.".format(str(nets))
2527  assert all(isinstance(net, Net) for net in nets), \
2528  "nets {} should be a list of nets.".format(str(nets))
2529  # A holistic blob to device mapping.
2530  blob_to_device = blob_to_device_init or {}
2531  blob_remap = {}
2532  new_nets = []
2533 
2534  for net in nets:
2535  new_net, blob_to_device = InjectCrossDeviceCopies(
2536  net,
2537  blob_to_device=blob_to_device,
2538  blob_remap=blob_remap,
2539  )
2540  new_nets.append(new_net)
2541 
2542  return new_nets, blob_to_device
2543 
2544 
2545 def InjectDeviceCopiesAmongNetsWithoutB2D(nets, blob_to_device_init=None):
2546  new_nets, _ = InjectDeviceCopiesAmongNets(nets, blob_to_device_init)
2547  return new_nets
2548 
2549 
2550 def get_net_name(netlike):
2551  if isinstance(netlike, Net):
2552  return netlike.Proto().name
2553  elif isinstance(netlike, caffe2_pb2.NetDef):
2554  return netlike.name
2555  else:
2556  return netlike
2557 
2558 
2559 def output_to_list(op_output):
2560  """
2561  Ensures that the output of an operator is a list.
2562  Use when an operator has a variable number of outputs, but a list of
2563  outputs is desired even when number of outputs is 1.
2564 
2565  Args:
2566  op_output: Either a BlobReferenece or an iterable of BlobReferences.
2567 
2568  Returns:
2569  A list of BlobReferences.
2570  """
2571  assert type(op_output) in (list, tuple, BlobReference)
2572  return (
2573  [op_output]
2574  if isinstance(op_output, BlobReference) else list(op_output))
2575 
2576 
2577 def _add_net_to_dict(net_dict, net):
2578  name = get_net_name(net)
2579  if name in net_dict:
2580  assert net_dict[name] is None or net == net_dict[name], (
2581  'Different nets with same name: ' + name)
2582  return False
2583  else:
2584  net_dict[name] = net if isinstance(net, Net) else None
2585  return True
2586 
2587 
2588 class ExecutionStep(object):
2589  _step_names_used = set()
2590 
2591  @staticmethod
2592  def _get_next_step_name(basename):
2593  name = basename
2594  next_idx = 1
2595  while name in ExecutionStep._step_names_used:
2596  name = basename + '_' + str(next_idx)
2597  next_idx += 1
2598  ExecutionStep._step_names_used |= set([name])
2599  return name
2600 
2601  def __init__(self, name, nets=None, num_iter=None):
2602  self._step = caffe2_pb2.ExecutionStep()
2603  self._step.name = name or ExecutionStep._get_next_step_name('step')
2604  self._net_dict = OrderedDict()
2605  self._is_used = False
2606  self._substeps = []
2607  if nets is not None:
2608  if type(nets) is Net:
2609  nets = [nets]
2610  for net in nets:
2611  if _add_net_to_dict(self._net_dict, net):
2612  self._step.network.extend([get_net_name(net)])
2613  if num_iter is not None:
2614  self._step.num_iter = num_iter
2615 
2616  def get_net(self, name):
2617  return self._net_dict[name]
2618 
2619  def Name(self):
2620  return self._step.name
2621 
2622  def __str__(self):
2623  return self._step.name
2624 
2625  def _assert_can_mutate(self):
2626  assert not self._is_used, (
2627  'Cannot mutate a step that has already been added to a plan/step.')
2628 
2629  def _notify_is_used(self):
2630  self._is_used = True
2631 
2632  def Proto(self):
2633  return self._step
2634 
2635  def HasNets(self):
2636  return self._step.network is not None and (
2637  len(self._step.network) > 0)
2638 
2639  def HasSubsteps(self):
2640  return self._step.substep is not None and (
2641  len(self._step.substep) > 0)
2642 
2643  def Nets(self):
2644  return list(viewvalues(self._net_dict))
2645 
2646  def Substeps(self):
2647  return self._substeps
2648 
2649  def SetIter(self, num_iter):
2650  self._assert_can_mutate()
2651  self._step.num_iter = num_iter
2652 
2653  def SetCreateWorkspace(self, create_workspace):
2654  self._assert_can_mutate()
2655  self._step.create_workspace = create_workspace
2656 
2657  def SetNumConcurrentInstances(self, num_concurrent_instances):
2658  self._assert_can_mutate()
2659  self._step.num_concurrent_instances = num_concurrent_instances
2660 
2661  def SetOnlyOnce(self, only_once):
2662  self._assert_can_mutate()
2663  self._step.only_once = only_once
2664 
2665  def SetShouldStopBlob(self, should_stop_blob):
2666  assert isinstance(should_stop_blob, BlobReference), (
2667  "expects BlobReference here, got {}".format(type(should_stop_blob)))
2668  self._assert_can_mutate()
2669  self._step.should_stop_blob = str(should_stop_blob)
2670 
2671  def RunEveryMillis(self, interval):
2672  """
2673  Run this step every interval millisecods, as long as its
2674  siblings are still running. It is guaranteed that, after all
2675  siblings finish, this step will run at least one.
2676 
2677  This property is ignored for top-level ExecutionSteps.
2678  """
2679  self._step.run_every_ms = interval
2680 
2681  def SetReportNet(self, report_net, report_interval):
2682  """ DEPRECATED. Use RunEveryMillis instead. """
2683  self._assert_can_mutate()
2684  _add_net_to_dict(self._net_dict, report_net)
2685  self._step.report_net = get_net_name(report_net)
2686  self._step.report_interval = report_interval
2687 
2688  def AddSubstep(self, substep):
2689  self._assert_can_mutate()
2690  assert not self.HasNets(), 'Cannot have both network and substeps.'
2691  if isinstance(substep, ExecutionStep):
2692  substep._notify_is_used()
2693  if not substep.HasNets() and not substep.HasSubsteps():
2694  return self
2695  for net in substep.Nets():
2696  _add_net_to_dict(self._net_dict, net)
2697  self._substeps.append(substep)
2698  proto = substep.Proto()
2699  else:
2700  proto = substep
2701  self._step.substep.add().CopyFrom(proto)
2702  return self
2703 
2704  def SetConcurrentSubsteps(self, concurrent_substeps):
2705  self._assert_can_mutate()
2706  assert not self.HasNets(), 'Cannot have both network and substeps.'
2707  self._step.concurrent_substeps = concurrent_substeps
2708 
2709  def AddNet(self, net):
2710  self._assert_can_mutate()
2711  assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
2712  assert isinstance(net, Net)
2713  _add_net_to_dict(self._net_dict, net)
2714  self._step.network.extend([get_net_name(net)])
2715  return self
2716 
2717  def get_all_attributes(self, name):
2718  """
2719  Return the list of all attributes under the given `name`, present in
2720  all of the nets used in this execution step and its children.
2721  """
2722  return [
2723  attr
2724  for net in viewvalues(self._net_dict)
2725  for attr in net.get_attributes(name)
2726  ]
2727 
2728  @classmethod
2729  def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict):
2730  """
2731  Create ExecutionStep from ExecutionStep protobuf recursively
2732  """
2733  assert isinstance(step_proto, caffe2_pb2.ExecutionStep)
2734  assert (len(step_proto.network) > 0 and len(step_proto.substep) == 0) or \
2735  (len(step_proto.network) == 0 and len(step_proto.substep) > 0)
2736 
2737  steps_or_nets = []
2738  if len(step_proto.substep) > 0:
2739  for substep_proto in step_proto.substep:
2740  steps_or_nets.append(ExecutionStep.create_from_proto(
2741  substep_proto, net_obj_dict, net_proto_dict))
2742  else:
2743  for net_name in step_proto.network:
2744  if net_name not in net_obj_dict:
2745  assert net_name in net_proto_dict
2746  net = Net(net_proto_dict[net_name])
2747  net_obj_dict[net_name] = net
2748  net = net_obj_dict[net_name]
2749  assert isinstance(net, Net)
2750  steps_or_nets.append(net)
2751 
2752  num_iter = step_proto.num_iter if step_proto.HasField('num_iter') else None
2753  concurrent_substeps = step_proto.concurrent_substeps if\
2754  step_proto.HasField('concurrent_substeps') else None
2755  should_stop_blob = BlobReference(step_proto.should_stop_blob) if\
2756  step_proto.HasField('should_stop_blob') else None
2757  only_once = step_proto.only_once if\
2758  step_proto.HasField('only_once') else None
2759  num_concurrent_instances = step_proto.num_concurrent_instances if\
2760  step_proto.HasField('num_concurrent_instances') else None
2761  create_workspace = step_proto.create_workspace if\
2762  step_proto.HasField('create_workspace') else None
2763  run_every_ms = step_proto.run_every_ms if\
2764  step_proto.HasField('run_every_ms') else None
2765 
2766  return execution_step(
2767  step_proto.name,
2768  steps_or_nets,
2769  num_iter=num_iter,
2770  report_net=None, # DEPRECATED
2771  report_interval=None, # DEPRECATED
2772  concurrent_substeps=concurrent_substeps,
2773  should_stop_blob=should_stop_blob,
2774  only_once=only_once,
2775  num_concurrent_instances=num_concurrent_instances,
2776  create_workspace=create_workspace,
2777  run_every_ms=run_every_ms)
2778 
2779 
2780 def add_nets_in_order(step, net_list):
2781  proto = step.Proto()
2782  for substep in step.Substeps():
2783  add_nets_in_order(substep, net_list)
2784  for net in proto.network:
2785  if net not in net_list:
2786  net_list.append(net)
2787  # FIXME(azzolini): This is actually wrong. Report nets should be
2788  # instantiated first since they may run before any substep is run.
2789  # However, curerntly, Reporter depends on this behavior.
2790  if proto.report_net and proto.report_net not in net_list:
2791  net_list.append(proto.report_net)
2792 
2793 
2794 class Plan(object):
2795 
2796  def __init__(self, name_or_step):
2797  self._plan = caffe2_pb2.PlanDef()
2798  self._net_dict = OrderedDict()
2799  self._steps = [] # A list of ExecutionStep
2800  if isinstance(name_or_step, ExecutionStep):
2801  self._plan.name = name_or_step.Name()
2802  self.AddStep(name_or_step)
2803  elif isinstance(name_or_step, basestring):
2804  self._plan.name = name_or_step
2805  else:
2806  raise ValueError('name_or_step must be a string or ExecutionStep')
2807 
2808  def __str__(self):
2809  return self._plan.name
2810 
2811  def Proto(self):
2812  return self._plan
2813 
2814  def AddNets(self, nets):
2815  for net in nets:
2816  if _add_net_to_dict(self._net_dict, net):
2817  assert isinstance(net, Net)
2818  self._plan.network.add().CopyFrom(net.Proto())
2819 
2820  def Nets(self):
2821  return list(viewvalues(self._net_dict))
2822 
2823  def AddStep(self, step):
2824  assert isinstance(step, ExecutionStep)
2825  step._notify_is_used()
2826  if not step.HasNets() and not step.HasSubsteps():
2827  return
2828  self._plan.execution_step.add().CopyFrom(step.Proto())
2829  self._steps.append(step)
2830  # nets need to be added to the plan in order of usage
2831  net_list = []
2832  add_nets_in_order(step, net_list)
2833  self.AddNets([step.get_net(n) for n in net_list])
2834 
2835  def Steps(self):
2836  return self._steps
2837 
2838  def get_all_attributes(self, name):
2839  """
2840  Return the list of all attributes under the given `name`, present in
2841  all of the nets used in this plan.
2842  """
2843  return [
2844  attr
2845  for net in viewvalues(self._net_dict)
2846  for attr in net.get_attributes(name)
2847  ]
2848 
2849  @classmethod
2850  def create_from_proto(cls, plan_proto):
2851  assert isinstance(plan_proto, caffe2_pb2.PlanDef)
2852  plan = Plan(plan_proto.name)
2853  plan._plan.CopyFrom(plan_proto)
2854  del plan._plan.network[:]
2855  del plan._plan.execution_step[:]
2856 
2857  net_obj_dict = {}
2858  net_proto_dict = {}
2859  for net_proto in plan_proto.network:
2860  assert net_proto.name not in net_proto_dict
2861  net_proto_dict[net_proto.name] = net_proto
2862 
2863  for step_proto in plan_proto.execution_step:
2864  step = ExecutionStep.create_from_proto(
2865  step_proto, net_obj_dict, net_proto_dict)
2866  plan.AddStep(step)
2867 
2868  return plan
2869 
2870 
2871 def to_execution_step(step_or_nets, default_name=None):
2872  from caffe2.python.net_builder import NetBuilder
2873  if isinstance(step_or_nets, ExecutionStep):
2874  return step_or_nets
2875 
2876  stop_blob = None
2877  if not default_name and hasattr(step_or_nets, 'name'):
2878  default_name = step_or_nets.name
2879  if isinstance(step_or_nets, NetBuilder):
2880  stop_blob = step_or_nets._stop_blob
2881  step_or_nets = step_or_nets.get()
2882  return execution_step(
2883  default_name, step_or_nets, should_stop_blob=stop_blob)
2884 
2885 
2886 def execution_step(default_name,
2887  steps_or_nets,
2888  num_iter=None,
2889  report_net=None,
2890  report_interval=None,
2891  concurrent_substeps=None,
2892  should_stop_blob=None,
2893  only_once=None,
2894  num_concurrent_instances=None,
2895  create_workspace=False,
2896  run_every_ms=None):
2897  """
2898  Helper for creating an ExecutionStep.
2899  - steps_or_nets can be:
2900  - None
2901  - Net
2902  - ExecutionStep
2903  - list<Net>
2904  - list<ExecutionStep>
2905  - should_stop_blob is either None or a scalar boolean blob.
2906  - This blob is checked AFTER every substeps/subnets.
2907  - If specified and true, then this step will return immediately.
2908  - Be sure to handle race conditions if setting from concurrent threads.
2909  - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
2910  """
2911  assert should_stop_blob is None or num_iter is None, (
2912  'Cannot set both should_stop_blob and num_iter.')
2913  if should_stop_blob is None and num_iter is None:
2914  num_iter = 1
2915 
2916  step = ExecutionStep(default_name)
2917  if should_stop_blob is not None:
2918  step.SetShouldStopBlob(should_stop_blob)
2919  if num_iter is not None:
2920  step.SetIter(num_iter)
2921  if only_once is not None:
2922  step.SetOnlyOnce(only_once)
2923  if concurrent_substeps is not None:
2924  step.SetConcurrentSubsteps(concurrent_substeps)
2925  if report_net is not None:
2926  assert report_interval is not None
2927  step.SetReportNet(report_net, report_interval)
2928  if num_concurrent_instances is not None:
2929  step.SetNumConcurrentInstances(num_concurrent_instances)
2930  if create_workspace:
2931  step.SetCreateWorkspace(True)
2932  if run_every_ms:
2933  step.RunEveryMillis(run_every_ms)
2934 
2935  if isinstance(steps_or_nets, ExecutionStep):
2936  step.AddSubstep(steps_or_nets)
2937  elif isinstance(steps_or_nets, Net):
2938  step.AddNet(steps_or_nets)
2939  elif isinstance(steps_or_nets, list):
2940  if all(isinstance(x, Net) for x in steps_or_nets):
2941  for x in steps_or_nets:
2942  step.AddNet(x)
2943  else:
2944  for x in steps_or_nets:
2945  step.AddSubstep(to_execution_step(x))
2946  elif steps_or_nets:
2947  raise ValueError(
2948  'steps_or_nets must be a step, a net, or a list of nets or steps.')
2949  return step
2950 
2951 
2952 def scoped_execution_step(name, *args, **kwargs):
2953  """Same as execution_step() except that the step name is scoped."""
2954  default_name = ScopedName(name) if name else name
2955  return execution_step(default_name, *args, **kwargs)
2956 
2957 
2958 def _extract_stacktrace():
2959  '''
2960  This function extracts stacktrace without file system access
2961  by purely using sys._getframe() and removes part that belongs to
2962  this file (core.py). We are not using inspect module because
2963  its just a wrapper on top of sys._getframe() whos
2964  logis is based on accessing source files on disk - exactly what
2965  we are trying to avoid here. Same stands for traceback module
2966 
2967  The reason for file system access avoidance is that
2968  if code is located on an NFS, file access might be slow
2969 
2970  Function returns a list of tuples (file_name, line_number, function)
2971  '''
2972 
2973  result = []
2974  # Ignore top 3 layers of stack: this function, _CreateAndAddToSelf, and
2975  # whatever calls _CreateAndAddToSelf (either __getattr__ or Python)
2976  frame = sys._getframe(3)
2977  # We just go down the frame stack in a loop
2978  while frame:
2979  # Its important to extract information from the frame here
2980  # as frame's current line most probably will change later.
2981  result.append((frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name))
2982  frame = frame.f_back
2983  return result
2984 
2985 
2986 SetPerOpEnginePref = C.set_per_op_engine_pref
2987 SetGlobalEnginePref = C.set_global_engine_pref
2988 SetEnginePref = C.set_engine_pref
2989 SetOpEnginePref = C.set_op_engine_pref
def AddStep(self, step)
Definition: core.py:2823
def RunAllOnIDEEP(self)
Definition: core.py:2150
def add_attribute(self, name, obj)
Definition: core.py:1509
def BuildGradientGenerators(self, fwd_op_idx, gradient_ops, g_output, g_input)
Definition: core.py:624
def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, kwargs)
Definition: core.py:2156
def _RecreateLookupTables(self)
Definition: core.py:1936
def Name(self)
Definition: core.py:1541
def get_attributes(self, name)
Definition: core.py:1516
def recover_input_record_by_prefix(self, prefix)
Definition: core.py:2055
def AddExternalOutput(self, outputs)
Definition: core.py:2000
device
Definition: core.py:2364
def NextBlob(self, prefix='unnamed')
Definition: core.py:1876
def get_all_attributes(self, name)
Definition: core.py:2717
def RegisterGradient(cls, op_type)
Definition: core.py:1060
def external_inputs(self)
Definition: core.py:2031
def SanityCheck(self, operators)
Definition: core.py:498
def DoGradientAccumulation(self, fwd_op_idx)
Definition: core.py:883
def _VerifyGradientGenerators(self, generator)
Definition: core.py:846
def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False)
Definition: core.py:1523
def _CreateAndAddToNet(self, op_type, inputs=None, args, kwargs)
Definition: core.py:245
def _AppendAutoGradGenerator(self, y, grad, autograd_op)
Definition: core.py:929
def __init__(self, name_or_proto)
Definition: core.py:1417
def Clone(self, name, blob_remap=None, op_id_mask=None, remap_funcs=None, keep_schema=True, update_external_list=False)
Definition: core.py:1633
def _MakeSparseSumOps(self, generators, out_base_name)
Definition: core.py:781
def UsedBlobNames(self)
Definition: core.py:1600
def AddGradientOperators(self, ys, skip=0)
Definition: core.py:1948
def GetBackwardPass(self, ys)
Definition: core.py:993
def Proto(self)
Definition: core.py:1791
def _MakeSumOps(self, input_name, input_version)
Definition: core.py:833
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False)
Definition: core.py:2138
def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False)
Definition: core.py:2122
def get_all_attributes(self, name)
Definition: core.py:2838
def _DisambiguateGradOpOutput(self, grad_op, idx, cnt)
Definition: core.py:730
def CheckGradientOperatorInput(self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs)
Definition: core.py:530
def Play(self, op)
Definition: core.py:508
def RunEveryMillis(self, interval)
Definition: core.py:2671
def NextName(self, prefix=None, output_id=None)
Definition: core.py:1892
def _ExtendOps(self, new_ops)
Definition: core.py:1911
def __getattr__(self, op_type)
Definition: core.py:256
def set_output_record(self, record)
Definition: core.py:2064
def _GetSumOpOutputName(self, generator, input_name)
Definition: core.py:696
def _SetSumOpsDeviceOption(self, sum_ops, generators)
Definition: core.py:717
def AppendSparseGenerators(self, sparse_generators)
Definition: core.py:600
def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict)
Definition: core.py:2729
def AddNets(self, nets)
Definition: core.py:2814
def __init__(self, name, net=None)
Definition: core.py:187
blob
Definition: core.py:2363
def BlobIsDefined(self, blob)
Definition: core.py:1578
def UsesBlob(self, blob)
Definition: core.py:1588
def _MakeDenseSumOps(self, generators, out_base_name)
Definition: core.py:747
def _GetGradientForOpCC(cls, op_def, g_output)
Definition: core.py:1070
def AddExternalInput(self, inputs)
Definition: core.py:1985
def NextScopedBlob(self, prefix='unnamed')
Definition: core.py:1867
def insert_op_at_idx(self, op, op_idx)
Definition: core.py:1795
def _GetInitGradients(self, ys)
Definition: core.py:940
def Python(self, f, grad_f=None, python_func_type=None, pass_workspace=False, grad_output_indices=None, grad_input_indices=None)
Definition: core.py:2219
def _CheckSumOpsConflict(self, out_base_name, g)
Definition: core.py:739
Definition: core.py:2361
def RunAllOnMKL(self)
Definition: core.py:2144
def ClonePartial(self, name, inputs, outputs, remap_funcs=None)
Definition: core.py:1726
def GetBlobRef(self, blob_name)
Definition: core.py:1614
def _GenerateGradientsForForwardOp(self, forward_op_idx, input_to_grad)
Definition: core.py:965
def _InvalidateLookupTables(self)
Definition: core.py:1933
def SetReportNet(self, report_net, report_interval)
Definition: core.py:2681
def set_input_record(self, input_record)
Definition: core.py:2038
def GetBackwardPass(cls, operators, ys, ys_generate_gradient=False)
Definition: core.py:1133
def recover_output_record_by_prefix(self, prefix)
Definition: core.py:2076
def external_outputs(self)
Definition: core.py:2035