from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import chain
from six import binary_type, string_types, text_type

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.control_ops_grad import \
    gen_do_gradient, gen_if_gradient, gen_while_gradient, \
    disambiguate_grad_if_op_output
import caffe2.python._import_c_extension as C

import copy
import os
import pickle
import sys
import traceback

import numpy as np
# mac os specific message
if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
          'error warning you that malloc_zone_unregister() failed. This is '
          'not a caffe2 issue but is due to the homebrew leveldb having an '
          'incompatible memory allocator. It does not affect usage.')
# Convenience redirections to functions inside scope.
DeviceScope = scope.DeviceScope
NameScope = scope.NameScope


# Bring datatype enums to the main namespace
class DataType:
    pass


def _InitDataType():
    for name, value in caffe2_pb2.TensorProto.DataType.items():
        setattr(DataType, name, value)


_InitDataType()
def _GetRegisteredOperators():
    return set(workspace.RegisteredOperators())


_REGISTERED_OPERATORS = _GetRegisteredOperators()


def RefreshRegisteredOperators():
    global _REGISTERED_OPERATORS
    _REGISTERED_OPERATORS = _GetRegisteredOperators()
_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    _GLOBAL_INIT_ARGS.extend(args[1:])
    C.global_init(args)


def GetGlobalInitArgs():
    return _GLOBAL_INIT_ARGS[:]
def IsOperator(op_type):
    return IsOperatorWithEngine(op_type, engine='DEFAULT')


def IsOperatorWithEngine(op_type, engine):
    return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS


def IsGPUDeviceType(device_type):
    return device_type in {caffe2_pb2.CUDA, caffe2_pb2.HIP}
def DeviceOption(
    device_type,
    device_id=0,
    random_seed=None,
    node_name=None,
    numa_node_id=None,
    extra_info=None,
):
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.device_id = device_id
    if node_name is not None:
        option.node_name = node_name
    if random_seed is not None:
        option.random_seed = random_seed
    if numa_node_id is not None:
        assert device_type == caffe2_pb2.CPU
        option.numa_node_id = numa_node_id
    if extra_info is not None:
        option.extra_info.extend(extra_info)
    return option
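# A minimal usage sketch (device ids are hypothetical; CUDA availability is
# an assumption): DeviceOption builds the protobuf that DeviceScope and
# CreateOperator consume.
#
#     gpu_opt = DeviceOption(caffe2_pb2.CUDA, device_id=0)
#     cpu_opt = DeviceOption(caffe2_pb2.CPU, numa_node_id=0)
#     with DeviceScope(gpu_opt):
#         pass  # operators created here inherit gpu_opt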
def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
    if not opt1 or not opt2:
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    if not opt1.device_type or not opt2.device_type:
        # At least one option is for CPU, check if both are for CPU.
        return not opt1.device_type and not opt2.device_type
    return opt1.device_id == opt2.device_id
def InferBlobDevices(net):
    """
    Compute a mapping from blobs to devices by looking at the device option
    of the op that creates each blob.
    """
    mapping = {}
    for op in net.Proto().op:
        op_device = op.device_option
        if op_device is None:
            op_device = caffe2_pb2.DeviceOption(caffe2_pb2.CPU)
        for b in op.output:
            mapping[b] = op_device
    return mapping
def InferOpBlobDevicesAsDict(op):
    input_dev_list, output_dev_list = InferOpBlobDevices(op)
    input_dict = {
        op.input[i]: input_dev_list[i]
        for i in range(len(op.input))
    }
    output_dict = {
        op.output[i]: output_dev_list[i]
        for i in range(len(op.output))
    }
    return input_dict, output_dict
def InferOpBlobDevices(op):
    device_info = C.infer_op_input_output_device(op.SerializeToString())
    input_info = []
    output_info = []
    for dev_str in device_info[0]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        input_info.append(device_option)
    for dev_str in device_info[1]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        output_info.append(device_option)
    return input_info, output_info
def InferOpDeviceAsBlobDevices(op):
    op_dev = op.device_option if op.device_option else caffe2_pb2.DeviceOption()
    input_dev = [op_dev] * len(op.input)
    output_dev = [op_dev] * len(op.output)
    return input_dev, output_dev
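# Sketch of how the device-inference helpers compose (assumes the
# CopyCPUToGPU operator, whose input is inferred to live on CPU and whose
# output is inferred to live on the op's GPU):
#
#     op = CreateOperator("CopyCPUToGPU", ["x_cpu"], ["x_gpu"],
#                         device_option=DeviceOption(caffe2_pb2.CUDA, 0))
#     in_dict, out_dict = InferOpBlobDevicesAsDict(op)
#     # in_dict["x_cpu"].device_type == caffe2_pb2.CPU
#     # out_dict["x_gpu"].device_type == caffe2_pb2.CUDA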
GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])
180 """A wrapper around a blob in a net. 182 BlobReference gives us a way to refer to the network that the blob is 183 generated from. Note that blobs are, essentially, just strings in the 188 """Initializes a blob reference. 190 Note that this does not prepends the namescope. If needed, use 191 ScopedBlobReference() to prepend the existing namespace. 193 if isinstance(name, string_types):
195 elif isinstance(name, binary_type):
196 self.
_name = name.decode(
'utf-8')
198 self.
_name = str(name)
205 return hash(self.
_name)
    def __eq__(self, other):
        if isinstance(other, string_types):
            return self._name == other
        elif isinstance(other, binary_type):
            return self._name == other.decode('utf-8')
        elif isinstance(other, BlobReference):
            return self._name == other._name
        else:
            return False

    def __ne__(self, other):
        return not (self == other)

    def __str__(self):
        return self._name

    def __repr__(self):
        return 'BlobReference("{}")'.format(self._name)
    def __add__(self, other):
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add BlobReference to a non-string.')
        return BlobReference(self._name + other, self._from_net)

    def __radd__(self, other):
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add a non-string to BlobReference.')
        return BlobReference(other + self._name, self._from_net)
    def Net(self):
        return self._from_net

    def GetNameScope(self):
        return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]

    def GetUnscopedName(self):
        return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]
    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.
        """
        inputs = [] if inputs is None else inputs
        if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
            inputs = [inputs]
        # add self to the input list.
        inputs.insert(0, self)
        return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)
257 """A wrapper allowing one to initiate operators from a blob reference. 259 Example: for a blob reference b that comes from network n, doing 261 is equivalent to doing 264 if op_type.startswith(
'__'):
265 raise AttributeError(
'Attribute {} not found.'.format(op_type))
268 'You cannot use a blob reference that does not have a net ' 269 'source to create operators. Create the operator from an ' 270 'explicit net object.')
271 if not IsOperator(op_type):
273 'Method ' + op_type +
' is not a registered operator.' +
275 ",".join(workspace.C.nearby_opnames(op_type)) +
']' 278 op_type, *args, **kwargs)
    def __dir__(self):
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            additional_methods
        )))
def ScopedName(name):
    """prefix the name with the current scope."""
    if isinstance(name, binary_type):
        name = name.decode('ascii')
    return scope.CurrentNameScope() + name
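# A small sketch of how scoping composes (assumes an empty ambient scope):
#
#     with NameScope("foo"):
#         print(ScopedName("W"))   # -> "foo/W"
#     print(ScopedName("W"))       # -> "W"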
def ScopedBlobReference(name, *args, **kwargs):
    """Returns a blob reference with scope prefixed."""
    return BlobReference(ScopedName(name), *args, **kwargs)


def _RectifyInputOutput(blobs, net=None):
    """A helper function to rectify the input or output of the CreateOperator
    interface.
    """
    if isinstance(blobs, string_types) or isinstance(blobs, binary_type):
        # If blobs is a single string, prepend scope.CurrentNameScope()
        # and put it as a list.
        return [ScopedBlobReference(blobs, net=net)]
    elif type(blobs) is BlobReference:
        # If blob is a BlobReference, simply put it as a list.
        return [blobs]
    elif type(blobs) in (list, tuple):
        # If blob is a list, we go through it and type check.
        rectified = []
        for blob in blobs:
            if isinstance(blob, string_types) or isinstance(blob, binary_type):
                rectified.append(ScopedBlobReference(blob, net=net))
            elif type(blob) is BlobReference:
                rectified.append(blob)
            else:
                raise TypeError(
                    "I/O blob #{} of unsupported type: {} of type {}"
                    .format(len(rectified), str(blob), type(blob)))
        return rectified
    else:
        raise TypeError(
            "Unknown input/output type: %s of type %s." %
            (str(blobs), type(blobs))
        )
348 """A function wrapper that allows one to create operators based on the 349 operator type. The type should be a string corresponding to an operator 350 registered with Caffe2. 352 operator = caffe2_pb2.OperatorDef()
353 if (os.environ.get(
'CAFFE2_DEBUG')):
354 stack = traceback.format_stack()
355 operator.debug_info =
"".join(stack[:-1])
357 operator.type = operator_type
360 inputs = _RectifyInputOutput(inputs)
361 outputs = _RectifyInputOutput(outputs)
362 operator.input.extend([text_type(i)
for i
in inputs])
363 operator.output.extend([text_type(o)
for o
in outputs])
365 control_input = _RectifyInputOutput(control_input)
366 operator.control_input.extend([text_type(i)
for i
in control_input])
372 if device_option
is not None:
373 operator.device_option.CopyFrom(device_option)
374 elif scope.CurrentDeviceScope()
is not None:
375 operator.device_option.CopyFrom(scope.CurrentDeviceScope())
376 if engine
is not None:
377 operator.engine = engine
378 if debug_info
is not None:
379 operator.debug_info = debug_info
383 if 'random_seed' in kwargs:
384 operator.device_option.random_seed = kwargs[
'random_seed']
385 del kwargs[
'random_seed']
388 operator.arg.extend(arg)
390 for key, value
in viewitems(kwargs):
391 if value
is not None:
392 operator.arg.add().CopyFrom(utils.MakeArgument(key, value))
394 if workspace.IsImmediate():
395 workspace.RunOperatorImmediate(operator)
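# A minimal usage sketch (assumes the registered "Relu" operator; the blob
# names are hypothetical). The returned OperatorDef can be run directly or
# appended to a net's proto:
#
#     op = CreateOperator("Relu", ["X"], ["Y"])
#     workspace.FeedBlob("X", np.random.randn(2, 3).astype(np.float32))
#     workspace.RunOperatorOnce(op)
#     result = workspace.FetchBlob("Y")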
def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    if python_func_type:
        func = python_func_type(f)
        f = func.forward
        grad_f = func.backward
    else:
        if isinstance(f, tuple):
            f = f[0](*f[1], **f[2])
        if isinstance(grad_f, tuple):
            grad_f = grad_f[0](*grad_f[1], **grad_f[2])

    token = C.register_python_op(f, pass_workspace, '')
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token
def CreatePythonOperator(
    f, inputs,
    outputs,
    grad_f=None,
    pass_workspace=False,
    python_func_type=None,
    *args,
    **kwargs
):
    """
    `f` should have a signature (inputs, outputs)

    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can
    manipulate the workspace directly), so use it at your own risk.
    """
    kwargs["token"] = _RegisterPythonImpl(
        f, grad_f, python_func_type, pass_workspace=pass_workspace
    )
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)
def GetIndexFromGradientList(g_list, name):
    """A helper function to get the index from a gradient list, None if not
    matching."""
    for i, g in enumerate(g_list):
        if g == name:
            return i
        elif type(g) is GradientSlice:
            if (g.indices == name or g.values == name):
                return i
    return None
OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient'])
SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
    'grad_op_indices', 'idx_indices',
    'grad_op_values', 'idx_values',
    'gradient',
])


class IR(object):
    """A simple IR class to keep track of all intermediate representations used
    in the gradient computation.
    """

    def __init__(self, operators):
        # The IR class holds multiple metadata from the forward pass:
        # a) ssa: a list of [op, in_versions, out_versions] recording the
        #    input and output version of each operator, similar to a normal
        #    SSA form.
        # b) input_usages: for each blob and each of its versions, the
        #    positions of the ops that use it as input.
        # c) frontier: maps each blob to its current version.
        # d) gradient_frontier: maps each blob to the version that its
        #    gradient corresponds to.
        # e) gradient_generators: for each blob and version, the list of
        #    operators that generate its gradient, with the gradient name.
        self.ssa = []
        self.input_usages = defaultdict(lambda: defaultdict(list))
        self.frontier = defaultdict(int)
        self.gradient_frontier = {}
        self.gradient_generators = defaultdict(lambda: defaultdict(list))
        self.out_version_history = defaultdict(list)
        self.in_version_history = defaultdict(list)

        for op in operators:
            self.Play(op)

        self.SanityCheck(operators)
    def SanityCheck(self, operators):
        # Validate StopGradient usage by checking that StopGradient's output
        # is actually passed forward.
        for op in operators:
            if op.type == 'StopGradient':
                if op.output[0] not in self.input_usages:
                    raise ValueError("""StopGradient's output '{}' is orphan.
You typically want to specify same input and output for
StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
509 """"Adds an op to the current IR, and update the internal states to 510 reflect the blobs and versions after the execution of the op. 527 self.ssa.append(OpSSA(op, in_versions, out_versions))
    def CheckGradientOperatorInput(
            self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
        """Checks if the gradient operators can be correctly carried out."""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        original_index = GetIndexFromGradientList(g_output, grad_op_input)
        def versionMismatchInfoOut(name):
            s = "DEBUG HELP:\n"
            s += "Maybe you use same output blob twice for different ops?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.out_version_history[name]:
                s += "Version (out) {} <-- {}".format(vers, op)
                s += "\n"
            return s

        def versionMismatchInfoIn(name):
            s = "DEBUG HELP:\n"
            s += "Maybe the blob was overwritten by another op?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.in_version_history[name]:
                s += "version (in) {} <-- {}".format(vers, op)
                s += "\n"
            return s
        # If it is a gradient name, it should match the version of the
        # corresponding output.
        if original_index is not None:
            original_name = forward_op.output[original_index]
            if (out_versions[original_name] !=
                    self.gradient_frontier[original_name]):
                raise RuntimeError(
                    'Gradient name "%s" is expected to correspond '
                    'to version %d of "%s", but currently we have '
                    'version %d.\n\n' % (
                        grad_op_input, out_versions[original_name],
                        original_name,
                        self.gradient_frontier[original_name]) +
                    versionMismatchInfoOut(original_name))
        # If it is an output name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in out_versions:
            if self.frontier[grad_op_input] != out_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs output "%s" at version'
                    ' %d, but currently we have version %d.\n\n' % (
                        grad_op_input, out_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoOut(grad_op_input)
                )
        # If it is an input name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in in_versions:
            if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
                raise RuntimeError(
                    'Gradient operator needs input "%s" at version '
                    '%d, but currently we have version %d.\n\n' % (
                        grad_op_input, in_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoIn(grad_op_input)
                )
        # If it is none of the above, it should be a blob generated locally
        # by one of the previous gradient operators.
        else:
            if grad_op_input not in locally_generated_blobs:
                raise RuntimeError(
                    'Blob name "%s" not in the scope of operator: '
                    '%s\nand is not generated by any of the local '
                    'gradient operators.' % (grad_op_input, str(forward_op))
                )
    def AppendSparseGenerators(self, sparse_generators):
        # merge indices and values generators for sparse gradients
        for name, input_generators in viewitems(sparse_generators):
            for version, generators in viewitems(input_generators):
                if len(generators) == 1:
                    # either indices or values are generated (but not both)
                    generator = generators[0]
                else:
                    # both indices and values are generated
                    assert(len(generators) == 2)
                    op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0]
                    op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1]
                    assert(g1 == g2)
                    assert(op1_i is None or op2_i is None)
                    assert(op1_v is None or op2_v is None)
                    assert(idx1_i == 0 or idx2_i == 0)
                    assert(idx1_v == 0 or idx2_v == 0)
                    generator = SparseGradGenMeta(
                        op1_i or op2_i, idx1_i + idx2_i,
                        op1_v or op2_v, idx1_v + idx2_v,
                        g1)
                self.gradient_generators[name][version].append(generator)
    def BuildGradientGenerators(  # NOQA
            self, fwd_op_idx, gradient_ops, g_output, g_input):
        """Updates gradient_generators and gradient_frontier"""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        locally_generated_blobs = []
        sparse_generators = defaultdict(lambda: defaultdict(list))

        for grad_op in gradient_ops:
            # (1) check that inputs are valid
            for s in grad_op.input:
                self.CheckGradientOperatorInput(
                    s, g_output, fwd_op_idx, locally_generated_blobs)

            # (2) add outputs to the locally generated blobs. If an output
            # corresponds to the gradient of an input, we also record it to
            # gradient_generators
            locally_generated_blobs.extend([str(s) for s in grad_op.output])
            for i, output in enumerate(grad_op.output):
                input_index = GetIndexFromGradientList(g_input, output)
                if input_index is not None:
                    input_name = forward_op.input[input_index]
                    input_version = in_versions[input_name]
                    g = g_input[input_index]
                    if type(g) is GradientSlice:
                        # the output corresponds either to the indices or the
                        # values of the sparse gradient. In either case we
                        # create a (partial) SparseGradGenMeta. If necessary,
                        # we'll merge indices and values generators
                        # corresponding to the same gradient in step (3)
                        if g.indices == output:
                            m = SparseGradGenMeta(grad_op, i, None, 0, g)
                        else:
                            assert(g.values == output)
                            m = SparseGradGenMeta(None, 0, grad_op, i, g)
                        sparse_generators[input_name][input_version].append(m)
                    else:
                        self.gradient_generators[input_name][input_version] \
                            .append(GradGenMeta(grad_op, i, g))

        # (3) merge pairs of indices/values generators for sparse gradients
        self.AppendSparseGenerators(sparse_generators)

        # (4) for ops (e.g. Add, Sum, ...) whose gradient outputs are passed
        # through directly from inputs rather than computed by gradient ops,
        # create a GradGenMeta with None grad_op and idx so that
        # gradient_generators knows these gradients are 'passed' from other
        # inputs
        for input_index, g in enumerate(g_input):
            input_name = forward_op.input[input_index]
            input_version = in_versions[input_name]
            if g is None:
                continue
            if type(g) is GradientSlice:
                if str(g.indices) not in locally_generated_blobs and \
                        str(g.values) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        SparseGradGenMeta(None, 0, None, 0, g))
            else:
                if str(g) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        GradGenMeta(None, 0, g))

        # Finally, for the gradients specified in g_input, update the gradient
        # frontier to reflect the input versions that the gradients correspond
        # to.
        for i, g in enumerate(g_input):
            if g is not None:
                input_name = forward_op.input[i]
                input_version = in_versions[input_name]
                self.gradient_frontier[input_name] = input_version
    def _GetSumOpOutputName(self, generator, input_name):
        def remove_suffix(s, suffix):
            if s.endswith(suffix):
                return s[:-len(suffix)]
            return s

        for g in generator:
            if type(g) is GradGenMeta:
                grad_op, idx, _ = g
                if grad_op:
                    return grad_op.output[idx]
            else:
                assert(type(g) is SparseGradGenMeta)
                op_i, idx_i, op_v, idx_v, _ = g
                if op_i:
                    return remove_suffix(op_i.output[idx_i], '_indices')
                if op_v:
                    return remove_suffix(op_v.output[idx_v], '_values')

        return input_name + '_grad'

    def _SetSumOpsDeviceOption(self, sum_ops, generators):
        # we already checked that device options are consistent, so we can
        # just use the first one we find
        for generator in generators:
            grad_op = generator.grad_op if type(generator) is GradGenMeta \
                else generator.grad_op_values or generator.grad_op_indices
            if grad_op:
                if grad_op.HasField('device_option'):
                    for op in sum_ops:
                        op.device_option.CopyFrom(grad_op.device_option)
                        del op.device_option.extra_info[:]
                break
    def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
        new_grad_output = (
            '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
        if grad_op.type == "If":
            disambiguate_grad_if_op_output(grad_op, idx, new_grad_output)
        else:
            grad_op.output[idx] = new_grad_output
        return grad_op.output[idx], cnt + 1
    def _CheckSumOpsConflict(self, out_base_name, g):
        if str(out_base_name) == str(g):
            raise RuntimeError(
                'The gradient output of empty gradient op can not '
                'be the same as the normal name of the current '
                'input gradient.')

    def _MakeDenseSumOps(self, generators, out_base_name):
        sum_op_input = []
        cnt = 0

        assert len(generators) > 1

        first_grad_op = True
        for generator in generators:
            grad_op, idx, g = generator
            assert(type(g) is not GradientSlice)
            if grad_op:
                if first_grad_op:
                    first_grad_op = False
                    out = grad_op.output[idx]
                else:
                    out, cnt = self._DisambiguateGradOpOutput(grad_op, idx, cnt)
                sum_op_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g)
                sum_op_input.append(str(g))

        if out_base_name in sum_op_input:
            # Sum inplace mode works only for the first input, so do a swap
            idx = sum_op_input.index(out_base_name)
            sum_op_input[0], sum_op_input[idx] = (
                sum_op_input[idx], sum_op_input[0]
            )
        sum_ops = [CreateOperator(
            "Sum",
            [BlobReference(x) for x in sum_op_input],
            BlobReference(out_base_name))]
        return sum_ops, out_base_name
    def _MakeSparseSumOps(self, generators, out_base_name):
        indices_concat_input = []
        values_concat_input = []
        cnt_i = 0
        cnt_v = 0

        for generator in generators:
            assert(type(generator) is SparseGradGenMeta)
            op_i, idx_i, op_v, idx_v, g = generator
            if op_i:
                out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i)
                indices_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.indices)
                indices_concat_input.append(g.indices)
            if op_v:
                out, cnt_v = self._DisambiguateGradOpOutput(op_v, idx_v, cnt_v)
                values_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.values)
                values_concat_input.append(g.values)

        indices_concat_output = out_base_name + '_indices_concat'
        indices_concat_split = out_base_name + '_indices_concat_split'
        values_concat_output = out_base_name + '_values_concat'
        values_concat_split = out_base_name + '_values_concat_split'
        # Sum the given sparse representations by simply concatenating the
        # indices (resp. values) tensors together. We do not deduplicate
        # indices at this point; that is done as needed before the optimizer
        # is called.
        sum_ops = [
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in indices_concat_input],
                [BlobReference(x) for x in
                    [indices_concat_output, indices_concat_split]],
                axis=0
            ),
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in values_concat_input],
                [BlobReference(x) for x in
                    [values_concat_output, values_concat_split]],
                axis=0
            ),
        ]
        sum_op_output = GradientSlice(
            indices=indices_concat_output,
            values=values_concat_output,
        )
        return sum_ops, sum_op_output
    def _MakeSumOps(self, input_name, input_version):
        generators = self.gradient_generators[input_name][input_version]
        out_base_name = self._GetSumOpOutputName(generators, input_name)
        types = list(set(type(x) for x in generators))
        assert(len(types) == 1)
        if types[0] is GradGenMeta:
            sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
        else:
            assert(types[0] is SparseGradGenMeta)
            sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
        self._SetSumOpsDeviceOption(sum_ops, generators)
        return sum_ops, g
    def _VerifyGradientGenerators(self, generator):
        # (1) check that all gradients are of the same type. Aggregating a
        # mix of sparse and dense gradients is not supported yet.
        if len({type(g) for g in generator}) > 1:
            raise RuntimeError(
                'Automatic aggregation of a mix of sparse and dense gradients '
                'is not supported yet')

        # If, for all operators that used the blob, none or only one produced
        # the gradient, then no additional sum needs to be carried out.
        if len(generator) < 2:
            return False

        all_gradient_names = []
        all_device_options = []
        for g in generator:
            if type(g) is GradGenMeta:
                if g.grad_op:
                    all_gradient_names.append(g.gradient)
                    all_device_options.append(g.grad_op.device_option)
            else:
                assert(type(g) is SparseGradGenMeta)
                if g.grad_op_indices:
                    all_device_options.append(g.grad_op_indices.device_option)
                if g.grad_op_values:
                    all_device_options.append(g.grad_op_values.device_option)
                all_gradient_names.append(g.gradient.values)

        # Check if all grad op device options are the same.
        if len(all_device_options) >= 2 and not all(
                device_option_equal(d, all_device_options[0])
                for d in all_device_options[1:]):
            raise RuntimeError('Unexpected behavior: not all grad ops '
                               'have the same device option.')
        return True
884 """For each input name in the forward op, check if we will need to 885 add gradient accumulation. If so, do gradient accumulation and return 886 the list of gradient operators. 888 The criteria for doing gradient accumulation is: 889 (1) the specific input version has been used by multiple operators. 890 (2) the current fwd_op_idx is the first to use that input, i.e. in the 891 backward pass, is the last to optionally generate the gradient for 893 (3) For the operators that used the input, their gradient operators 894 have generated more than 1 gradient. 896 When accumulating operators, our current solution is to rename all the 897 created gradients with an internal intermediate name, and then add a 898 Sum() operator that adds up all the gradients. This may use more memory 899 due to intermediate storage, but is usually the fastest approach as one 900 can do one single sum for multiple intermediate gradients. 902 forward_op, in_versions, out_versions = self.
ssa[fwd_op_idx]
903 additional_sum_ops = []
905 for _i, input_name
in enumerate(set(forward_op.input)):
906 input_version = in_versions[input_name]
907 input_usage = self.
input_usages[input_name][input_version]
908 if (len(input_usage) <= 1
or fwd_op_idx != input_usage[0]):
915 except RuntimeError
as err:
917 "Gradients for param ''{}'' failed to verify: {}".format(
924 sum_ops, g = self.
_MakeSumOps(input_name, input_version)
925 additional_sum_ops.extend(sum_ops)
926 grad_map[input_name] = g
927 return additional_sum_ops, grad_map
    def _AppendAutoGradGenerator(self, y, grad, autograd_op):
        # The gradient here is not sparse, as it was generated by a
        # ConstantFill operator. Autogeneration for sparse gradients is not
        # supported.
        generator = GradGenMeta(
            autograd_op, 0 if autograd_op else None, str(grad))

        self.gradient_generators[str(y)][self.frontier[str(y)]].append(
            generator)
    def _GetInitGradients(self, ys):
        input_to_grad = {}
        gradient_ops = []

        for y, g in viewitems(ys):
            autograd_op = None
            if g is None:
                autograd_op = CreateOperator(
                    "ConstantFill", [y], [str(y) + "_autogen_grad"],
                    value=1.0)
                gradient_ops.append(autograd_op)
                g = autograd_op.output[0]
            # Since the C++ gradient registry does not have a notion of
            # NameScopes, we will convert all references to strings.
            input_to_grad[str(y)] = (
                GradientSlice(str(g[0]), str(g[1]))
                if isinstance(g, GradientSlice) else str(g))
            # Autogenerated gradients are assumed to be provided for the last
            # input version.
            if autograd_op is not None:
                self._AppendAutoGradGenerator(y, g, autograd_op)

        return input_to_grad, gradient_ops
    def _GenerateGradientsForForwardOp(
            self, forward_op_idx, input_to_grad):
        new_input_to_grad = {}
        gradient_ops = []
        forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
        g_output = list(
            input_to_grad.get(name, None) for name in forward_op.output)

        if not all(g is None for g in g_output) or (
                forward_op.type == "ZeroGradient"):
            gradient_ops, g_input = GradientRegistry.GetGradientForOp(
                forward_op, g_output)
            # Check if the gradient operators are legal, and update
            # gradient_generators and gradient_frontier.
            self.BuildGradientGenerators(
                forward_op_idx, gradient_ops, g_output, g_input)
            # Record the gradient map to all_input_to_grad.
            for name, grad in zip(forward_op.input, g_input):
                # Do not overwrite an existing gradient with a None unless the
                # input is also an output of the op, since we update the blob
                # version when a blob is the output of an operator.
                if grad is not None or \
                    name not in input_to_grad or \
                        name in list(forward_op.output):
                    new_input_to_grad[name] = grad

        return new_input_to_grad, gradient_ops
994 """Gets the backward pass that computes the derivatives of given blobs. 997 ys: a list or a dictionary specifying what blobs we want to compute 998 derivatives of. If the input is a list, we will automatically 999 generate their gradients with all-one values; if the input is a 1000 dictionary, for any dictionary entries that are not None, we will 1001 take the corresponding blobs as their gradients; for all those 1002 that are None, we will auto-fill them with 1. 1004 if isinstance(ys, list):
1005 ys = dict((y,
None)
for y
in ys)
1006 elif not isinstance(ys, dict):
1007 raise TypeError(
"ys should either be a list or a dict.")
1011 for y
in viewkeys(ys):
1023 for forward_op_idx
in reversed(range(len(self.
ssa))):
1025 forward_op_idx, all_input_to_grad)
1026 all_input_to_grad.update(input_to_grad)
1027 all_gradient_ops += gradient_ops
1035 all_input_to_grad.update(grad_map)
1036 all_gradient_ops += additional_sum_ops
1042 all_input_to_grad_out = {}
1043 for key, val
in viewitems(all_input_to_grad):
1045 if (isinstance(val, string_types)
or 1046 isinstance(val, binary_type)):
1052 return all_gradient_ops, all_input_to_grad_out
1056 """GradientRegistry holds the mapping from operators to their gradients.""" 1057 gradient_registry_ = {}
1061 """A decorator for registering gradient mappings.""" 1070 def _GetGradientForOpCC(cls, op_def, g_output):
        def from_untyped(grad):
            if grad is None:
                w = C.GradientWrapper()
                assert w.is_empty()
                return w
            try:
                (indices, values) = grad
                w = C.GradientWrapper()
                w.indices = indices
                w.values = values
                assert w.is_sparse()
                return w
            except ValueError:
                w = C.GradientWrapper()
                w.dense = grad
                assert w.is_dense()
                return w

        g_output = [from_untyped(grad) for grad in g_output]
        grad_defs_str, g_input = C.get_gradient_defs(
            op_def.SerializeToString(), g_output)

        def to_untyped(grad_wrapper):
            if grad_wrapper.is_empty():
                return None
            if grad_wrapper.is_sparse():
                return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
            assert grad_wrapper.is_dense()
            return grad_wrapper.dense

        g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
        grad_defs = []
        for grad_def_str in grad_defs_str:
            grad_def = caffe2_pb2.OperatorDef()
            grad_def.ParseFromString(grad_def_str)
            grad_defs.append(grad_def)
        return grad_defs, g_input
    @classmethod
    def GetGradientForOp(cls, op, g_output):
        try:
            gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
        except Exception as e:
            # Not supported in C++; try the Python registry instead.
            if op.type in cls.gradient_registry_:
                gradient_ops, g_input = cls.gradient_registry_[op.type](
                    op, g_output)
            else:
                raise Exception(
                    "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                    format(op.type, e, str(op))
                )

        if gradient_ops is None:
            return [], g_input
        if type(gradient_ops) is not list:
            gradient_ops = [gradient_ops]
        return gradient_ops, g_input
1134 """Gets the backward pass for the list of operators. 1137 operators: a list of operators constituting the forward pass. 1138 ys: a list or a dictionary specifying what blobs we want to compute 1139 derivatives of. If the input is a list, we will automatically 1140 generate their gradients with all-one values; if the input is a 1141 dictionary, for any dictionary entries that are not None, we'll 1142 take the corresponding blobs as their gradients; for all those 1143 that are None, we will auto-fill them with 1. 1145 gradient_ops: a list of gradient operators to run. 1146 all_input_to_grads: a map from input to their corresponding 1150 return ir.GetBackwardPass(ys)
1153 GradientRegistry.RegisterGradient(
'Do')(gen_do_gradient)
1154 GradientRegistry.RegisterGradient(
'If')(gen_if_gradient)
1155 GradientRegistry.RegisterGradient(
'While')(gen_while_gradient)
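# A minimal end-to-end sketch of the gradient machinery (blob names are
# hypothetical): given the forward ops and a blob to differentiate, the
# registry returns the gradient ops plus a map from inputs to gradient blobs.
#
#     net = Net("grad_example")
#     net.FC(["X", "W", "b"], "Y")
#     grad_ops, g_map = GradientRegistry.GetBackwardPass(net.Proto().op, ["Y"])
#     # g_map["W"] refers to the blob holding dY/dW ("Y"'s autogen grad is 1).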
def get_ssa(net, blob_versions=None):
    """
    Given a net, return a structure containing the version of each input and
    output blob used by each operator.

    Args:
        net:            either a Net or a NetDef
        blob_versions:  (optional) map with current version number for given
                        blob names. If not provided or blob not found, start
                        from version 0.
    Returns:
        Tuple (ssa, blob_versions)
        ssa:            list of tuples (versioned_inputs, versioned_outputs)
                        for each op in the net. A versioned input is a tuple
                        (blob_name, version).
        blob_versions:  updated map with latest version of each blob found in
                        the net.
    """
    proto = net.Proto() if isinstance(net, Net) else net
    assert isinstance(proto, caffe2_pb2.NetDef)
    if blob_versions is None:
        blob_versions = {}
    if isinstance(net, list):
        return [get_ssa(n, blob_versions) for n in net], blob_versions
    for i in proto.external_input:
        if i not in blob_versions:
            blob_versions[str(i)] = 0
    ssa = []
    for op in proto.op:
        if not proto.external_input:
            for i in op.input:
                if i not in blob_versions:
                    blob_versions[i] = 0
        inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
        for o in op.output:
            blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
        outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
        ssa.append((inputs, outputs))
    return ssa, blob_versions
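# A sketch of the SSA numbering (hypothetical two-op net that reuses "y"):
#
#     # op0: y = Relu(x);  op1: y = Scale(y)
#     # get_ssa(net) would yield roughly:
#     #   ssa == [([("x", 0)], [("y", 1)]),
#     #           ([("y", 1)], [("y", 2)])]
#     #   blob_versions == {"x": 0, "y": 2}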
def get_undefined_blobs(ssa):
    """
    Given a ssa in the format produced by get_ssa(), return a set of blobs that
    are used before they are defined, which corresponds to inputs at version 0.
    """
    undef_blobs = set()
    for inputs, _outputs in ssa:
        undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
    return undef_blobs
def get_output_producers(ssa):
    """
    Given a ssa in the format produced by get_ssa(), returns a map from
    versioned blob into the operator index that produces that version of
    the blob. A versioned blob is a tuple (blob_name, version).
    """
    producers = {}
    for i, (_inputs, outputs) in enumerate(ssa):
        for o in outputs:
            producers[o] = i
    return producers
def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
    """
    Given a ssa and blob_versions as produced by get_ssa(), returns the list
    of op indices that are necessary in order to generate the blobs in
    `outputs`, given blobs in `inputs`.
    Consider that the `inputs` are given in their latest version.
    """
    inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
    producers = get_output_producers(ssa)
    queue = [(str(o), blob_versions[str(o)]) for o in outputs]
    used_op_ids = set()
    while len(queue) > 0:
        o = queue.pop()
        if (o not in inputs_set) and (o in producers):
            op_id = producers[o]
            if op_id not in used_op_ids:
                used_op_ids |= {op_id}
                inputs, _ = ssa[op_id]
                queue.extend(inputs)
    return sorted(used_op_ids)
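# Sketch: for a chain x --op0--> y --op1--> z, asking for `z` given `x`
# returns both op ids, while asking for `y` given `x` returns only [0]:
#
#     ssa, versions = get_ssa(net)
#     ids = get_op_ids_in_path(ssa, versions, ["x"], ["z"])  # -> [0, 1]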
def recurrent_network_op_remap(op, prefix, blob_remap):
    """
    Parameters
    ----------
    op : Caffe2 operator (RecurrentNetworkOp or RecurrentNetworkGradientOp).
    prefix: this argument is not used in this function, just for legacy support.
    blob_remap : Dictionary that represents the map from old blob name to new.

    Updates blob names in arguments of RecurrentNetworkOp and
    RecurrentNetworkGradientOp to conform to cloned input and output of both
    operators and also makes sure names of locally generated blobs in arguments
    have the same prefix as the input and output of the operators.
    """

    def get_remapped_str(blob_str):
        if isinstance(blob_str, binary_type):
            blob_str = blob_str.decode('utf-8')
        return blob_remap.get(blob_str, blob_str).encode('utf-8')

    for argument in op.arg:
        if len(argument.strings) > 0:
            for i in range(len(argument.strings)):
                argument.strings[i] = get_remapped_str(argument.strings[i])
        elif argument.name == 'timestep':
            argument.s = get_remapped_str(argument.s)
        elif argument.name.endswith('step_net'):
            # argument is a proto
            remap_proto(argument, blob_remap)
def control_op_remap(op, prefix, blob_remap):
    net_arg_names = []
    if op.type == "If":
        net_arg_names = ['then_net', 'else_net']
    else:
        net_arg_names = ['loop_net', 'cond_net']
    for argument in op.arg:
        if argument.name in net_arg_names:
            assert argument.n, \
                "Expected non empty net in " + op.type + \
                "'s " + argument.name + " argument"
            subnet = Net(argument.n)
            remapped_subnet = subnet.Clone(
                name=(subnet._net.name if subnet._net.name else '') +
                '_remapped',
                blob_remap=blob_remap)
            argument.n.CopyFrom(remapped_subnet.Proto())


DEFAULT_REMAP_FUNCS = {
    'RecurrentNetwork': recurrent_network_op_remap,
    'RecurrentNetworkGradient': recurrent_network_op_remap,
    'If': control_op_remap,
    'While': control_op_remap,
}
def remap_proto(argument, blob_remap):
    subnet = Net(argument.n)

    cloned_sub_net = subnet.Clone(
        'cloned_sub_net',
        blob_remap,
    )

    argument.n.CopyFrom(cloned_sub_net.Proto())
def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
                       keep_schema=True):
    """
    Clone the given Net, binding its input schema to the given `inputs` record.
    Blob names defined by the net are prepended with the given `prefix`.

    Args:
        net:        the net to clone
        name:       the name of the new net
        prefix:     the prefix to prepend to local blobs
        blob_remap: (optional) dict with additional blob name remapping.
        inputs:     (optional) input record that will provide actual input
                    values for the cloned net. Must be compatible with the
                    net's input schema or be a strict superset of it
        keep_schema: by default (True), the original schema will be kept and
                     remapped accordingly. otherwise, the schema will be set as
                     inputs or left empty if inputs is not given.
    Returns:
        Tuple (cloned_net, blob_remap)
        clone_net:  the cloned Net
        blob_remap: a map from original blob names into remapped blob names
    """
    assert isinstance(net, Net)
    if blob_remap is None:
        blob_remap = {}
    if inputs is not None:
        original = net.input_record()
        assert original is not None
        diff = set(original.field_names()) - set(inputs.field_names())
        assert len(diff) == 0, (
            "Schemas don't match, extra fields {diff} found in the net {name}. "
            "original: {original}; inputs: {inputs}"
            .format(
                diff=diff, name=net.Name(), original=original.field_names(),
                inputs=inputs.field_names()
            )
        )
        original_mapping = dict(zip(original.field_names(),
                                    original.field_blobs()))
        for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
            if fn in original_mapping:
                blob_remap[str(original_mapping[fn])] = str(fb)
    proto = net.Proto()
    ssa, blob_versions = get_ssa(proto)
    undef_blobs = get_undefined_blobs(ssa)

    for blob in viewkeys(blob_versions):
        if blob in blob_remap:
            continue
        elif blob in undef_blobs:
            blob_remap[blob] = blob
        else:
            blob_remap[blob] = prefix + blob
    cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
    if not keep_schema and inputs:
        cloned_net.set_input_record(inputs)
    return cloned_net, blob_remap
def _get_blob_ref(blob_name_or_ref):
    return (
        blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference)
        else BlobReference(blob_name_or_ref)
    )
def _recover_record_by_prefix(names, prefix=''):
    """
    Tries to recover record by taking a subset of blob names with
    a given prefix name and interpreting them as schema column names
    """
    from caffe2.python import schema
    column_names = [name[len(prefix):] for name in names
                    if name.startswith(prefix)]
    if not column_names:
        return None
    return schema.from_column_list(
        column_names,
        col_blobs=[_get_blob_ref(prefix + name) for name in column_names])
class Net(object):
    _net_names_used = set()
    operator_registry_ = {}

    @staticmethod
    def current_prefix():
        from caffe2.python.net_builder import NetBuilder
        builder = NetBuilder.current(required=False)
        return builder.name if builder else ''

    @staticmethod
    def _get_next_net_name(basename):
        name = basename = '/'.join(
            x for x in [Net.current_prefix(), basename] if x
        )
        next_idx = 1
        while name in Net._net_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        Net._net_names_used |= set([name])
        return name
    def __init__(self, name_or_proto):
        """
        Create a Net.
        Args:
            name_or_proto:  If a NetDef is provided, clone it. Otherwise,
                            create an empty net with the given name.
        """
        self._input_record = None
        self._output_record = None
        # Register blobs so that different calls to NextBlob/NextScopedBlob
        # are guaranteed to return blobs with different names.
        self._registered_blob_names = set()
        self._recreate_lookup_tables = False
        self._op_outputs = set()
        self._external_input_map = set()
        self._attr_dict = defaultdict(list)
        if type(name_or_proto) is caffe2_pb2.NetDef:
            proto = name_or_proto
            # We are initializing the network from the given proto.
            self._net = caffe2_pb2.NetDef()
            self._net.CopyFrom(proto)

            existing_outputs = [list(op.output) for op in self._net.op]

            self._external_input_map.update(list(self._net.external_input))

            # Set the next name index properly.
            existing_names = set(
                sum(
                    [list(op.input) for op in self._net.op], []
                ) + sum(
                    existing_outputs, []
                )
            )
            for outs in existing_outputs:
                self._op_outputs.update(outs)

            prefix_len = len(self._net.name + '_blob_')
            autogen_indices = []
            for s in existing_names:
                if s.startswith(self._net.name + '_blob_'):
                    try:
                        autogen_indices.append(int(s[prefix_len:]))
                    except ValueError:
                        pass
            if len(autogen_indices):
                self._next_name_index = max(autogen_indices) + 1
            else:
                self._next_name_index = 0
            name = self._net.name
        else:
            name = name_or_proto
            self._net = caffe2_pb2.NetDef()
            self._next_name_index = 0

        # make sure that this net name hasn't been used before
        self._net.name = Net._get_next_net_name(name)
    def AppendNet(self, net, device_option=None):
        assert isinstance(net, Net)
        for i in net.Proto().external_input:
            if (
                i not in self.Proto().external_input and
                i not in self._op_outputs
            ):
                self.Proto().external_input.append(i)

        self.Proto().external_output.extend(
            [
                o for o in net.Proto().external_output
                if o not in self.Proto().external_output
            ]
        )
        ops = net.Proto().op
        if device_option is not None:
            ops = [copy.deepcopy(op) for op in ops]
            # a lazy `map` would be a no-op under Python 3, so use an
            # explicit loop to apply the device option
            for op in ops:
                op.device_option.CopyFrom(device_option)

        self._ExtendOps(ops)
        return self
    def LogInfo(self, *msg_or_blobs):
        for msg_or_blob in msg_or_blobs:
            if not isinstance(msg_or_blob, BlobReference):
                blob = self.GivenTensorStringFill(
                    [], self.NextName('log'),
                    shape=[], values=[msg_or_blob])
            else:
                blob = msg_or_blob
            self.Print(blob, [])
    def add_attribute(self, name, obj):
        """
        Add `obj` to the list of attributes in this net under the given `name`.
        Attributes are user-defined objects and have no pre-defined semantics.
        """
        self._attr_dict[name].append(obj)

    def get_attributes(self, name):
        """
        Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute`.
        """
        return self._attr_dict.get(name, [])
    def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
        """
        Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`.
        If seed_on_op_def is set, the op rand_seed=hash(str(op)).
        sequence_seed and seed_on_op_def cannot both be set to True.
        """
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot both be set to True.')
        for i, op in enumerate(self.Proto().op):
            if sequence_seed:
                curr_seed = seed + i
            elif seed_on_op_def:
                curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
            op.device_option.random_seed = curr_seed
    def Name(self):
        return self._net.name

    def __str__(self):
        return self.Name()
    def Const(self, array, blob_out=None, dtype=None):
        if isinstance(array, bool):
            return self.ConstantFill(
                [],
                blob_out or 1,
                dtype=DataType.BOOL,
                value=array)

        if dtype is None:
            array = np.array(array)
        else:
            array = np.array(array, dtype=dtype)

        def do_set(operator):
            return operator(
                [],
                blob_out or 1,
                shape=array.shape,
                values=array.flatten().tolist())

        if array.dtype == np.int32:
            return do_set(self.GivenTensorIntFill)
        elif array.dtype == np.int64:
            return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == np.str:
            return do_set(self.GivenTensorStringFill)
        elif array.dtype == np.bool:
            return do_set(self.GivenTensorBoolFill)
        else:
            return do_set(self.GivenTensorFill)
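    # A small usage sketch (hypothetical blob name): Const picks the right
    # GivenTensor*Fill op based on the numpy dtype of the converted array.
    #
    #     init_net = Net("init")
    #     w = init_net.Const(np.array([[1.0, 2.0]], dtype=np.float32),
    #                        blob_out="w")
    #     # emits GivenTensorFill with shape=[1, 2] and the flattened values.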
    def BlobIsDefined(self, blob):
        """
        Returns true if the given BlobReference is produced as output of
        an operator in this net, or if it is provided as an external input.
        """
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()
        name = str(blob)
        return (name in self._op_outputs) or (name in self._external_input_map)

    def UsesBlob(self, blob):
        """
        Returns true iff the given BlobReference is used by any operator
        of this net, or if it is one of the external inputs of the net.
        """
        blob_name = str(blob)
        for op in self._net.op:
            for input in op.input:
                if input == blob_name:
                    return True
        return blob_name in self._external_input_map
    def used_blob_names(self):
        """
        Returns a set of blob names used in the net
        """
        blob_names = set()
        for op in self._net.op:
            blob_names |= set(op.input)
            blob_names |= set(op.output)
        if self._net.external_input:
            blob_names |= set(self._net.external_input)
        if self._net.external_output:
            blob_names |= set(self._net.external_output)
        return blob_names
    def GetBlobRef(self, blob_name):
        """
        Given the name of a blob produced by this net, return a BlobReference
        to it. If the blob is not produced by any op in this net, raises
        KeyError.
        """
        blob_name = str(blob_name)
        if not self.BlobIsDefined(blob_name):
            raise KeyError('Net does not define blob %s' % blob_name)
        return BlobReference(blob_name, self)
    def Clone(
        self,
        name,
        blob_remap=None,
        op_id_mask=None,
        remap_funcs=None,
        keep_schema=True,
        update_external_list=False,
    ):
        """
        Clone this net.
        Args:
            name:        name of the cloned net
            blob_remap:  optional map with list of blob names to replace
            op_id_mask:  optional list of operator indices to include in
                         the cloned net. If not provided, all ops are included.
        """
        orig_remap_funcs = {} if remap_funcs is None else remap_funcs
        # by default we want to put RecurrentNetworkOp and
        # RecurrentNetworkGradientOp into remap_funcs, as these two operators
        # also take blobs and protos in their arguments.
        remap_funcs = DEFAULT_REMAP_FUNCS.copy()
        remap_funcs.update(orig_remap_funcs)
        proto = self._net
        new_proto = caffe2_pb2.NetDef()
        new_proto.CopyFrom(proto)
        new_proto.name = name

        if blob_remap is None:
            blob_remap = {}
        if op_id_mask is None:
            op_id_mask = list(range(0, len(proto.op)))

        def get_remapped_str(blob):
            blob_str = str(blob)
            return str(blob_remap.get(blob_str, blob_str))

        def remap_list(proto_list):
            new_list = [get_remapped_str(b) for b in proto_list]
            del proto_list[:]
            proto_list.extend(new_list)

        def remap_op(op):
            new_op = caffe2_pb2.OperatorDef()
            new_op.CopyFrom(op)
            remap_list(new_op.input)
            remap_list(new_op.output)
            if new_op.type in remap_funcs:
                remap_funcs[new_op.type](
                    new_op,
                    (name + '/') if name else '',
                    blob_remap,
                )
            return new_op

        del new_proto.op[:]
        new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
        remap_list(new_proto.external_input)
        remap_list(new_proto.external_output)
        new_net = Net(new_proto)

        if keep_schema:
            from caffe2.python import schema
            if self._input_record:
                new_net._input_record = schema.from_blob_list(
                    self._input_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._input_record.field_blobs()
                    ],
                )
            if self._output_record:
                new_net._output_record = schema.from_blob_list(
                    self._output_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._output_record.field_blobs()
                    ],
                )

        new_net._attr_dict.update(self._attr_dict)
        if update_external_list:
            # external input list
            existing_outputs = set()
            used_outputs = set()
            del new_net.Proto().external_input[:]
            del new_net.Proto().external_output[:]
            for op in new_net.Proto().op:
                for ib in op.input:
                    if ib not in existing_outputs:
                        new_net.Proto().external_input.extend([ib])
                    else:
                        used_outputs.add(ib)
                for ob in op.output:
                    existing_outputs.add(ob)
            # external output list
            for ob in existing_outputs:
                if ob not in used_outputs:
                    new_net.Proto().external_output.extend([ob])
        return new_net
    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone this net, including only ops that are necessary in order to
        compute `outputs` given `inputs`. Return references to the cloned
        outputs. Internal blobs (blobs that are produced and consumed inside
        the net but not used as outputs) will be remapped to avoid name
        conflict.

        Args:
            name:    the name of the cloned net
            inputs:  map where the keys correspond to BlobReferences in the
                     original net, and the values correspond to external inputs
                     in the partially cloned net. If `inputs` is a list, don't
                     remap input names.
            outputs: outputs to be produced by the cloned net.

        Returns:
            Tuple (new_net, new_outputs)
                new_net:     a new Net object.
                new_outputs: list of BlobReferences corresponding to the
                             outputs produced by new_net.
        """
        input_is_pair_list = isinstance(inputs, list) and all(
            isinstance(i, tuple) and len(i) == 2 for i in inputs)
        inputs = (
            inputs if isinstance(inputs, (dict, OrderedDict)) else
            OrderedDict(inputs) if input_is_pair_list else
            OrderedDict(zip(inputs, inputs)))
        for output in outputs:
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
        input_names = {str(k): str(v) for k, v in viewitems(inputs)}
        output_names = [str(o) for o in outputs]
        proto = self._net
        blob_versions = {str(i): 0 for i in inputs}
        ssa, blob_versions = get_ssa(proto, blob_versions)
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
        assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
        undef_blobs = get_undefined_blobs(sub_ssa) - set(viewkeys(input_names))
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            elif blob_name in undef_blobs:
                return blob_name
            else:
                return prefix + blob_name

        blob_mapping = {b: remap(b) for b in viewkeys(blob_versions)}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        new_in = [
            blob_mapping[i] for i in viewkeys(input_names)] + list(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        del new_net.Proto().external_input[:]
        new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(list(new_in))
        del new_net.Proto().external_output[:]
        new_net.Proto().external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]
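    # Sketch (hypothetical blob names): keep only the ops needed to compute
    # "loss" from "fc_out"; internal blobs get remapped under "tail/".
    #
    #     sub_net, (new_loss,) = net.ClonePartial(
    #         "tail", inputs={"fc_out": "fc_out"}, outputs=["loss"])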
    def Proto(self):
        self._InvalidateLookupTables()
        return self._net

    def insert_op_at_idx(self, op, op_idx):
        r"""Insert an operator at the given index and update the external
        blob lists.
        """
        temp_ops = self.Proto().op[op_idx:]
        del self.Proto().op[op_idx:]
        self.Proto().op.extend([op])
        self.Proto().op.extend(temp_ops)
        # extend the proto lists directly; extending the `external_outputs`
        # property would only mutate a temporary list
        self.Proto().external_output.extend(op.output)
        self.Proto().external_input.extend(op.input)
    def reroute_tensor(self, tensor, new_producer, can_modify=None):
        r"""Reroute tensor to new_producer, and feed the new tensor to the
        consumers (intersected with can_modify if provided).
        Inputs:
            tensor: str or blob_reference, the tensor to reroute
            new_producer: an op that takes in tensor and gives new_tensor
            can_modify: a list/set of operators that consume tensor and can be
            modified

        Returns:
            reroute_cnt: how many consumer ops have been changed

        Note: assumes no in-place blob in net
        """

        def _find_tensor_input_op(tensor):
            if tensor in self.external_inputs:
                op_idx = -1
            else:
                assert tensor in new_producer.input, \
                    "new producer {} is not taking in {}".format(
                        new_producer.type, tensor)
                # assuming that the net has no in-place blob
                op_idx = -2
                for index, op in enumerate(self.Proto().op):
                    if tensor in op.output:
                        op_idx = index
            return op_idx

        # the insertion point for new_producer is after the last op that
        # produces any of its inputs
        op_idx = max(_find_tensor_input_op(t) for t in new_producer.input)
        self.insert_op_at_idx(new_producer, op_idx + 1)
        new_tensor = new_producer.output[0]
        # modify external outputs
        if tensor in self.external_outputs:
            new_list = [new_tensor if b == tensor else str(b)
                        for b in self.external_outputs]
            del self.Proto().external_output[:]
            self.Proto().external_output.extend(new_list)
        # modify consumers
        reroute_cnt = 0
        if can_modify:
            for op in self.Proto().op:
                if op in can_modify:  # this is not necessarily true
                    remap_input(op, {tensor: new_tensor})
                    reroute_cnt = reroute_cnt + 1
        return reroute_cnt
    def PopulateProtoWithFileName(self):
        net_tb = workspace.operator_tracebacks.get(self.Name(), None)
        if net_tb is not None:
            for idx, op in enumerate(self.Proto().op):
                if idx in net_tb:
                    op.name = ':'.join(map(str, net_tb[idx][0]))
1868 """Return the blob that has not been defined or registered in the 1869 current net. It returns `ScopedBlobReference(prefix)`, if it's valid, 1870 otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`. Different calls 1871 is guaranteed to return blob with different names. 1873 output_blob_base = ScopedName(prefix)
1874 return self.
NextBlob(output_blob_base)
1877 """Return the blob that has not been defined or registered in the 1878 current net. It returns `BlobReference(prefix)`, if it's valid, 1879 otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls 1880 is guaranteed to return blob with different names.""" 1882 output_blob = output_blob_base
1886 output_blob = output_blob_base +
'_auto_' + str(index)
1889 self._registered_blob_names.add(str(output_blob))
1893 """Returns the next name to be used, if you do not want to explicitly 1894 name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]""" 1896 output_name_base = self._net.name +
'/' + prefix
1897 output_name = output_name_base
1898 if output_id
is not None:
1899 output_name +=
':' + str(output_id)
1901 while self.
BlobIsDefined(str(ScopedBlobReference(output_name))):
1902 output_name = output_name_base +
'_' + str(index)
1903 if output_id
is not None:
1904 output_name +=
':' + str(output_id)
1909 return str(output_name)
    def _ExtendOps(self, new_ops):
        self._net.op.extend(new_ops)
        for op in new_ops:
            self._op_outputs.update([text_type(o) for o in op.output])

    def _CheckLookupTables(self):
        """
        Called from unit tests to validate the internal lookup tables
        match the protobuf contents.
        """
        test_op_outputs = set()
        for op in self._net.op:
            for o in op.output:
                test_op_outputs.add(o)

        test_external_inp = set()
        for inp in self._net.external_input:
            test_external_inp.add(inp)

        assert test_op_outputs.difference(self._op_outputs) == set()
        assert test_external_inp.difference(self._external_input_map) == set()

    def _InvalidateLookupTables(self):
        self._recreate_lookup_tables = True

    def _RecreateLookupTables(self):
        self._op_outputs = set()
        for op in self._net.op:
            for o in op.output:
                self._op_outputs.add(o)

        self._external_input_map = set()
        for inp in self._net.external_input:
            self._external_input_map.add(inp)

        self._recreate_lookup_tables = False
1949 """Add the gradient for operators in the net. 1952 ys: a list or a dictionary specifying what blobs we want to compute 1953 derivatives of. If the input is a list, we will automatically 1954 generate their gradients with all-one values; if the input is a 1955 dictionary, for any dictionary entries that are not None, we will 1956 take the corresponding blobs as their gradients; for all those 1957 that are None, we will auto-fill them with 1. 1958 skip: skips the first n operators. This is provided mainly because a 1959 lot of nets may use the first few operators for data generation 1960 like stuff which really do not need to have gradients. 1963 returns a map from the blob name in the input network to a blob 1964 containing gradient or a GradientSlice in case of sparse gradient 1966 Currently, this is hard-coded for float operators if there are branches 1967 (i.e. a blob is used as input to multiple operators). This is because 1968 the gradient accumulation (Sum) is float only right now. 1971 grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
1972 self._net.op[skip:], ys)
1976 if workspace.IsImmediate():
1978 workspace.RunOperatorImmediate(op)
1980 return input_to_grad
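    # Typical training-net sketch (hypothetical blob names): append gradient
    # ops for everything reachable from the loss, then use the returned map.
    #
    #     loss = net.AveragedLoss("xent", "loss")
    #     grad_map = net.AddGradientOperators([loss])
    #     # grad_map["W"] -> BlobReference("W_grad"), or a GradientSlice for
    #     # sparse gradients (e.g. downstream of a Gather).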
    def AddArgument(self, arg_name, arg_value):
        self._net.arg.extend([utils.MakeArgument(arg_name, arg_value)])

    def AddExternalInput(self, *inputs):
        assert len(inputs) > 0
        refs = []
        for input in inputs:
            input_name = str(input)
            assert input_name not in self._external_input_map, (
                'Net already contains an input named %s' % input_name)
        for input in inputs:
            input_name = str(input)
            self._net.external_input.extend([input_name])
            self._external_input_map.update([input_name])
            refs.append(_get_blob_ref(input_name))

        return refs[0] if len(refs) == 1 else refs
    def AddExternalOutput(self, *outputs):
        for output in outputs:
            assert isinstance(output, BlobReference)
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
        for output in outputs:
            self.Proto().external_output.extend([str(output)])

    def AddScopedExternalInputs(self, *inputs):
        res = self.AddExternalInput(
            * [ScopedBlobReference(b) for b in inputs]
        )
        if not isinstance(res, list):
            res = [res]
        return res

    def AddScopedExternalOutputs(self, *outputs):
        return self.AddExternalOutput(
            * [ScopedBlobReference(b) for b in outputs]
        )
    # This returns a reference to the observer
    def AddObserver(self, observer_type):
        return C.add_observer_to_net(self._net.name, observer_type)

    def RemoveObserver(self, observer):
        C.remove_observer_from_net(self._net.name, observer)

    def NumObservers(self):
        return C.num_observers_on_net(self._net.name)

    @property
    def external_inputs(self):
        return [_get_blob_ref(x) for x in self._net.external_input]

    @property
    def external_outputs(self):
        return [_get_blob_ref(x) for x in self._net.external_output]
    def set_input_record(self, input_record):
        from caffe2.python import schema
        assert self._input_record is None or (input_record.has_blobs() and
            set(input_record.field_blobs()) ==
            set(self._input_record.field_blobs())), (
            'Input schema cannot be reset')
        if not input_record.has_blobs():
            with NameScope(self.Name()):
                self._input_record = schema.NewRecord(self, input_record)
        else:
            self._input_record = input_record

        for blob in self._input_record.field_blobs():
            if blob not in self.external_inputs:
                self.AddExternalInput(blob)
        return self._input_record
    def recover_input_record_by_prefix(self, prefix):
        """
        Tries to recover input record by taking a subset of external_inputs
        with a given prefix name and interpreting them as schema column names
        """
        record = _recover_record_by_prefix(self._net.external_input, prefix)
        if record:
            self.set_input_record(record)
    def set_output_record(self, record):
        assert self._output_record is None or (record.has_blobs() and
            set(record.field_blobs()) ==
            set(self._output_record.field_blobs())), (
            'Output schema cannot be reset')
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        self._output_record = record
    def recover_output_record_by_prefix(self, prefix):
        """
        Tries to recover output record by taking a subset of external_outputs
        with a given prefix name and interpreting them as schema column names
        """
        record = _recover_record_by_prefix(self._net.external_output, prefix)
        if record:
            self.set_output_record(record)
    def AppendOutputRecordField(self, field_name, record):
        from caffe2.python import schema
        assert self._output_record is not None, (
            'Tried to append to missing output record'
        )
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        self._output_record = self._output_record + schema.Struct(
            (field_name, record)
        )
    def input_record(self):
        return self._input_record

    def output_record(self):
        return self._output_record

    def AddExternalInputs(self, *inputs):
        return self.AddExternalInput(*inputs)

    def AddExternalOutputs(self, *outputs):
        self.AddExternalOutput(*outputs)
    def DeduplicateGradientSlices(self, g, aggregator='sum'):
        assert isinstance(g, GradientSlice)
        unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
        if aggregator.lower() == 'sum':
            new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
        elif aggregator.lower() == 'mean':
            new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
        else:
            raise ValueError('{} is not supported'.format(aggregator))
        return GradientSlice(indices=unique, values=new_g)
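    # Sketch: a sparse GradientSlice may carry repeated indices; this
    # collapses them (here by summation) before a sparse parameter update.
    #
    #     dedup = net.DeduplicateGradientSlices(grad_map["emb"])
    #     # dedup.indices holds unique ids; dedup.values the aggregated rows.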
    @staticmethod
    def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        net.device_option.CopyFrom(device_option)
        if use_cudnn:
            for op in net.op:
                op.engine = "CUDNN"
        # Move RecurrentNetwork operators to the GPU as well
        for op in net.op:
            if op.type != "RecurrentNetwork":
                continue
            for arg in op.arg:
                if arg.name == "step_net":
                    Net._RunAllOnGPU(arg.n, gpu_id, use_cudnn)

    def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
        """A convenient function to run everything on the GPU."""
        self._RunAllOnGPU(self._net, gpu_id, use_cudnn)

    def RunAllOnMKL(self):
        """A convenient function to run everything using MKLDNN."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.MKLDNN
        self._net.device_option.CopyFrom(device_option)

    def RunAllOnIDEEP(self):
        """A convenient function to run everything using IDEEP."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.IDEEP
        self._net.device_option.CopyFrom(device_option)
    def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
        """A helper function to create an operator and add it to self.
        """
        inputs = _RectifyInputOutput(inputs)
        for input in inputs:
            if not self.BlobIsDefined(input):
                assert input.Net() != self
                self.AddExternalInput(input)
        if outputs is None:
            # If we do not specify an output, we will assume that this op
            # produces one output in this case.
            outputs = self.NextName(prefix=op_type)
        elif type(outputs) is int:
            # In this case, we will auto-fill the given number of outputs
            # with auto-generated names.
            outputs = [
                self.NextName(prefix=op_type, output_id=i)
                for i in range(outputs)]
        outputs = _RectifyInputOutput(outputs, net=self)
        op = CreateOperator(op_type, inputs, outputs, **kwargs)
        self._ExtendOps([op])

        workspace.operator_tracebacks[self.Name()][
            len(self._net.op) - 1] = _extract_stacktrace()

        if len(op.output) == 0:
            return
        elif len(op.output) == 1:
            return BlobReference(op.output[0], self)
        else:
            return tuple(BlobReference(o, self) for o in op.output)
    def __getattr__(self, op_type):
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if not IsOperator(op_type) and not IsOperatorWithEngine(op_type, "CUDNN"):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            op_type, *args, **kwargs)

    def __dir__(self):
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op]
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            additional_methods
        )))
    def Python(
        self,
        f,
        grad_f=None,
        python_func_type=None,
        pass_workspace=False,
        grad_output_indices=None,
        grad_input_indices=None
    ):
        """
        Registers and returns a python operator.

        `f` and `grad_f` can be one of the following:
            - a function with signature (inputs, outputs), where inputs and
              outputs are a list of CPUTensor objects. This function will be
              called from C++ every time the operator is executed.
            - a tuple (func, args, kwargs), where `func` is a callable, args is
              an argument list, and kwargs is a dict list. The call:
                  f = func(*args, kwargs)
              will be performed locally at node initialization time, on all of
              the nodes of the job, returning `f`, a callable that will be used
              as the python operator function to be called during Net execution.
              This is to be used when using python operator in a distributed
              context, and allows to create and keep local python state across
              calls to the operator.

        `python_func_type` is a type of an object that is constructed as
        python_func_type(f) and provides an implementation of the forward and
        backward functions. It is useful in cases where the user needs a
        stateful PythonOp (e.g. using autograd for computing grad_f).

        If `pass_workspace` is True, the signature is changed to
        (inputs, outputs, workspace) where `workspace` is the workspace the op
        is going to run on. This is potentially dangerous (as the op can
        manipulate the workspace directly), so use it at your own risk.

        If a gradient function is specified (`grad_f`), by default its inputs
        will be: (1) all inputs to `f`, (2) followed by all outputs of `f`, (3)
        and then all gradient outputs of `f`. The outputs of `grad_f` will be
        (by default) all gradient inputs to `f`. If a subset of the gradient
        outputs or gradient inputs is desired instead, then the subsets can be
        specified by providing `grad_output_indices` and/or `grad_input_indices`
        which identify the indices of `f`'s inputs and outputs which have
        gradients.
        """
        assert(IsOperator('Python'))

        def make_builder(t):
            if not isinstance(t, tuple):
                return ''
            assert len(t) == 3, 'Expected builder tuple (func, args, kwargs)'
            func, args, kwargs = t
            normalized = (func, tuple(args), dict(kwargs))
            return pickle.dumps(normalized)

        f_builder = make_builder(f)
        grad_f_builder = make_builder(grad_f)

        assert (not grad_f) or ((not f_builder) == (not grad_f_builder)), (
            'A tuple has to be passed to both f and grad_f or neither.')

        core_kwargs = {}
        if f_builder:
            core_kwargs['pickled_builder'] = f_builder
            core_kwargs['pickled_grad_builder'] = grad_f_builder
            core_kwargs['pass_workspace'] = pass_workspace
        else:
            core_kwargs['token'] = _RegisterPythonImpl(
                f, grad_f, python_func_type, pass_workspace=pass_workspace)

        grad_output_indices = grad_output_indices or []
        grad_input_indices = grad_input_indices or []
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            'Python',
            grad_output_indices=grad_output_indices,
            grad_input_indices=grad_input_indices,
            *args,
            **dict(chain(viewitems(kwargs), viewitems(core_kwargs)))
        )
    def is_external_input(self, blob):
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()

        name = str(blob)
        return name in self._external_input_map

    def extend_ops(self, new_ops):
        return self._ExtendOps(new_ops)


def remap_input(op, blob_name_remapping):
    new_list = [blob_name_remapping.get(b, b) for b in op.input]
    del op.input[:]
    op.input.extend(new_list)
def copy_func_between_devices(src, dst):
    CPU = caffe2_pb2.CPU
    is_src_gpu = IsGPUDeviceType(src.device_type)
    is_dst_gpu = IsGPUDeviceType(dst.device_type)

    if src.device_type == CPU and dst.device_type == CPU:
        return None

    if is_src_gpu and is_dst_gpu:
        if src.device_id == dst.device_id:
            return None
        else:
            def fun(net, *args, **kw):
                with DeviceScope(dst):
                    return net.Copy(*args, **kw)
            return fun

    if is_src_gpu and dst.device_type == CPU:
        def fun(net, *args, **kw):
            with DeviceScope(src):
                return net.CopyGPUToCPU(*args, **kw)
        return fun

    if src.device_type == CPU and is_dst_gpu:
        def fun(net, *args, **kw):
            with DeviceScope(dst):
                return net.CopyCPUToGPU(*args, **kw)
        return fun

    raise ValueError('Non-supported devices: %s and %s' % (src, dst))
def device_equal(src, dst):
    '''
    We are using this function instead of the == operator because
    optional-value comparison between empty device_options and
    {device_type:0, device_id:0} returns not equal in some cases.
    '''
    return src.device_type == dst.device_type and src.device_id == dst.device_id


def update_placeholder_op_output(op, blob_to_device):
    '''
    Placeholder ops (e.g. Recv) always run on CPU. So ensure their output
    blobs reside on CPU.
    '''
    outputs = []
    for output in op.output:
        if (output in blob_to_device and
                blob_to_device[output].device_type != caffe2_pb2.CPU):
            output += '_cpu'
        outputs.append(output)
    del op.output[:]
    op.output.extend(outputs)

class RemapEntry:
    def __init__(self, blob, device):
        self.blob = blob
        self.device = device

    def __eq__(self, other):
        return self.blob == other.blob and self.device == other.device

    def __hash__(self):
        return hash(self.blob + str(self.device))

def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
                            placeHolderOps=None):
    '''
    Injecting Copy functions between devices within a net. Users can provide
    a net with part of its operators using different device_options. This
    method will automatically create a new net with Copy ops inserted in it.

    Inputs:
      blob_to_device: If not None, it is a map of blobs and their device
                      locations.
      blob_remap: If not None, it is a map from a pair (blob, device) to
                  the name of the blob in the given device. Blobs found in
                  this map are assumed to be cached and don't need to be
                  copied.
    Outputs:
      new_net: A new net with CopyCPUToGPU inserted with correct device option

      required_external_to_device:
               A mapping between unresolved external inputs and their
               required device options.
    Assumptions:
      1. every external input of this net is already in blob_to_device!
      2. if not, this function will use net device option
      3. InferOpBlobDevices might fail to get the correct inference for ops
         like EnsureCPUOutput that could take in input from multiple places.
    '''
    new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
    del new_net._net.op[:]
    if blob_to_device is None:
        blob_to_device = {}
    # remapping of input blobs for each op.
    if blob_remap is None:
        blob_remap = {}
    temp_remap = {}
    net_option = net._net.device_option or caffe2_pb2.DeviceOption()

    # if external inputs have device remappings generated by previous nets,
    # then add those remapped blobs as external inputs as well.
    all_remaps = defaultdict(list)
    for entry, mapped_blob in blob_remap.items():
        all_remaps[entry.blob].append(mapped_blob)
    mapped_external_inputs = []
    for input in new_net._net.external_input:
        mapped_external_inputs.extend(all_remaps.get(input) or [])
    new_net._net.external_input.extend(mapped_external_inputs)

    for op in net._net.op:
        temp_remap = {}
        # Get where inputs and outputs should be. If it is a placeholder
        # (i.e. fake) op, then set the op's device as the blobs' devices.
        if placeHolderOps is not None and op.type in placeHolderOps:
            input_dev, output_dev = InferOpDeviceAsBlobDevices(op)
        else:
            input_dev, output_dev = InferOpBlobDevices(op)

        for dev, input in zip(input_dev, op.input):
            assert net.BlobIsDefined(input), \
                "input {} should be defined in the net.".format(input)
            if input not in blob_to_device:
                if net.is_external_input(input):
                    blob_to_device[input] = net_option
                else:
                    raise AttributeError(
                        "No device information found for blob {}.".
                        format(input)
                    )

            if not device_equal(blob_to_device[input], dev):
                # reuse already moved input
                if (RemapEntry(input, dev) in blob_remap and
                        blob_to_device[blob_remap[RemapEntry(input, dev)]] == dev):
                    temp_remap[input] = blob_remap[RemapEntry(input, dev)]
                else:
                    # need to make input on the correct device.
                    copy_func = copy_func_between_devices(
                        blob_to_device[input], dev
                    )

                    def _gen_new_name(blob, device_option):
                        CPU = caffe2_pb2.CPU
                        if device_option.device_type == CPU:
                            suffix = '_cpu'
                        elif IsGPUDeviceType(device_option.device_type):
                            suffix = '_gpu_' + str(device_option.device_id)
                        else:
                            raise RuntimeError(
                                "Unknown device type: {}".
                                format(device_option.device_type)
                            )
                        return blob + suffix

                    new_name = _gen_new_name(input, dev)
                    copy_func(new_net, input, new_name)
                    blob_remap[RemapEntry(input, dev)] = new_name
                    temp_remap[input] = new_name
                    blob_to_device[new_name] = dev

        if placeHolderOps is not None and op.type in placeHolderOps:
            update_placeholder_op_output(op, blob_to_device)

        # Enforcing no reuse blob between operators. In-place blob usage in an
        # op is allowed. This is based on the assumption that an in-place op
        # has the same device option.
        for dev, output in zip(output_dev, op.output):
            if output in blob_to_device and (
                output not in op.input and
                not device_equal(blob_to_device[output], dev)
            ):
                raise RuntimeError(
                    "In-place blob: {} is not supported between operators "
                    "with different device option previous:{} now: {}. "
                    "Failed op:\n {}".format(
                        output, blob_to_device[output], dev, op
                    )
                )
        new_op = caffe2_pb2.OperatorDef()
        new_op.CopyFrom(op)

        new_list = [temp_remap.get(b, b) for b in new_op.input]
        del new_op.input[:]
        new_op.input.extend(new_list)

        # keep inplace blobs inplace
        original_inputs = list(op.input)
        for i, out in enumerate(new_op.output):
            try:
                input_idx = original_inputs.index(out)
                new_op.output[i] = new_op.input[input_idx]
            except ValueError:
                pass

        blob_to_device.update(
            {o: d for d, o in zip(output_dev, new_op.output)})
        new_net.extend_ops([new_op])

    return new_net, blob_to_device
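
# Illustrative sketch (not part of the original source): a net whose ops use
# different device options; the injected clone gains the needed Copy ops and
# renamed blobs (e.g. 'data' -> 'data_gpu_0'). Function, net, and blob names
# are hypothetical.
def _example_inject_cross_device_copies():
    net = Net('mixed_devices')
    with DeviceScope(DeviceOption(caffe2_pb2.CPU)):
        net.ConstantFill([], 'data', shape=[4], value=1.0)
    with DeviceScope(DeviceOption(caffe2_pb2.CUDA, 0)):
        net.Relu('data', 'relu_out')  # consumes the CPU blob on GPU 0
    new_net, blob_to_device = InjectCrossDeviceCopies(net)
    return new_net, blob_to_device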

def InjectDeviceCopiesAmongNets(nets, blob_to_device_init=None):
    """
    Takes in a list of nets. They usually represent your whole execution
    graph. This function will insert cross-device copy functions into all
    nets, and resolve inter-net external input dependencies. This method will
    insert Copy functions if an external input of a net is produced on a
    different device than the one it is required on.

    Inputs:
      nets: a list of nets
    Outputs:
      new_nets: a list of new nets with device differences resolved.

    Some notes from wyiming:
      1. You MUST pass nets in execution order. e.g. [train_init, train]
    """
    assert isinstance(nets, list), \
        "nets {} should be a list of nets.".format(str(nets))
    assert all(isinstance(net, Net) for net in nets), \
        "nets {} should be a list of nets.".format(str(nets))
    # A holistic blob to device mapping.
    blob_to_device = blob_to_device_init or {}
    blob_remap = {}
    new_nets = []

    for net in nets:
        new_net, blob_to_device = InjectCrossDeviceCopies(
            net,
            blob_to_device=blob_to_device,
            blob_remap=blob_remap,
        )
        new_nets.append(new_net)

    return new_nets, blob_to_device
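
# Illustrative usage (not in the original source): nets must be passed in
# execution order so that blobs produced by earlier nets resolve the external
# inputs of later ones. `train_init` and `train` are hypothetical nets:
#
#     new_nets, blob_to_device = InjectDeviceCopiesAmongNets([train_init, train])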

def InjectDeviceCopiesAmongNetsWithoutB2D(nets, blob_to_device_init=None):
    new_nets, _ = InjectDeviceCopiesAmongNets(nets, blob_to_device_init)
    return new_nets

def get_net_name(netlike):
    if isinstance(netlike, Net):
        return netlike.Proto().name
    elif isinstance(netlike, caffe2_pb2.NetDef):
        return netlike.name
    else:
        return netlike

def output_to_list(op_output):
    """
    Ensures that the output of an operator is a list.
    Use when an operator has a variable number of outputs, but a list of
    outputs is desired even when number of outputs is 1.

    Args:
        op_output: Either a BlobReference or an iterable of BlobReferences.

    Returns:
        A list of BlobReferences.
    """
    assert type(op_output) in (list, tuple, BlobReference)
    return (
        [op_output]
        if isinstance(op_output, BlobReference) else list(op_output))
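
# Illustrative usage (not in the original source): normalizing single and
# multiple outputs to a list; `net` and the blob names are hypothetical:
#
#     outs = output_to_list(net.Relu('x', 'y'))                  # [y]
#     outs = output_to_list(net.Split('x', ['a', 'b'], axis=0))  # [a, b]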

def _add_net_to_dict(net_dict, net):
    name = get_net_name(net)
    if name in net_dict:
        assert net_dict[name] is None or net == net_dict[name], (
            'Different nets with same name: ' + name)
        return False
    else:
        net_dict[name] = net if isinstance(net, Net) else None
        return True


class ExecutionStep(object):
    _step_names_used = set()
    @staticmethod
    def _get_next_step_name(basename):
        name = basename
        next_idx = 1
        while name in ExecutionStep._step_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        ExecutionStep._step_names_used |= set([name])
        return name
    def __init__(self, name, nets=None, num_iter=None):
        self._step = caffe2_pb2.ExecutionStep()
        self._step.name = name or ExecutionStep._get_next_step_name('step')
        self._net_dict = OrderedDict()
        self._is_used = False
        self._substeps = []
        if nets is not None:
            if type(nets) is Net:
                nets = [nets]
            for net in nets:
                if _add_net_to_dict(self._net_dict, net):
                    self._step.network.extend([get_net_name(net)])
        if num_iter is not None:
            self._step.num_iter = num_iter
    def get_net(self, name):
        return self._net_dict[name]

    def Name(self):
        return self._step.name

    def __str__(self):
        return self._step.name
    def _assert_can_mutate(self):
        assert not self._is_used, (
            'Cannot mutate a step that has already been added to a plan/step.')

    def _notify_is_used(self):
        self._is_used = True

    def Proto(self):
        return self._step

    def HasNets(self):
        return self._step.network is not None and (
            len(self._step.network) > 0)

    def HasSubsteps(self):
        return self._step.substep is not None and (
            len(self._step.substep) > 0)

    def Nets(self):
        return list(viewvalues(self._net_dict))

    def Substeps(self):
        return self._substeps
    def SetIter(self, num_iter):
        self._assert_can_mutate()
        self._step.num_iter = num_iter

    def SetCreateWorkspace(self, create_workspace):
        self._assert_can_mutate()
        self._step.create_workspace = create_workspace

    def SetNumConcurrentInstances(self, num_concurrent_instances):
        self._assert_can_mutate()
        self._step.num_concurrent_instances = num_concurrent_instances

    def SetOnlyOnce(self, only_once):
        self._assert_can_mutate()
        self._step.only_once = only_once

    def SetShouldStopBlob(self, should_stop_blob):
        assert isinstance(should_stop_blob, BlobReference), (
            "expects BlobReference here, got {}".format(type(should_stop_blob)))
        self._assert_can_mutate()
        self._step.should_stop_blob = str(should_stop_blob)
    def RunEveryMillis(self, interval):
        """
        Run this step every interval milliseconds, as long as its
        siblings are still running. It is guaranteed that, after all
        siblings finish, this step will run at least once.

        This property is ignored for top-level ExecutionSteps.
        """
        self._step.run_every_ms = interval

    def SetReportNet(self, report_net, report_interval):
        """ DEPRECATED. Use RunEveryMillis instead. """
        self._assert_can_mutate()
        _add_net_to_dict(self._net_dict, report_net)
        self._step.report_net = get_net_name(report_net)
        self._step.report_interval = report_interval
    def AddSubstep(self, substep):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        if isinstance(substep, ExecutionStep):
            substep._notify_is_used()
            if not substep.HasNets() and not substep.HasSubsteps():
                return self
            for net in substep.Nets():
                _add_net_to_dict(self._net_dict, net)
            self._substeps.append(substep)
            proto = substep.Proto()
        else:
            proto = substep
        self._step.substep.add().CopyFrom(proto)
        return self
    def SetConcurrentSubsteps(self, concurrent_substeps):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        self._step.concurrent_substeps = concurrent_substeps

    def AddNet(self, net):
        self._assert_can_mutate()
        assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
        assert isinstance(net, Net)
        _add_net_to_dict(self._net_dict, net)
        self._step.network.extend([get_net_name(net)])
        return self
    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this execution step and its children.
        """
        return [
            attr
            for net in viewvalues(self._net_dict)
            for attr in net.get_attributes(name)
        ]
    @classmethod
    def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict):
        """
        Create ExecutionStep from ExecutionStep protobuf recursively
        """
        assert isinstance(step_proto, caffe2_pb2.ExecutionStep)
        assert (len(step_proto.network) > 0 and len(step_proto.substep) == 0) or \
            (len(step_proto.network) == 0 and len(step_proto.substep) > 0)

        steps_or_nets = []
        if len(step_proto.substep) > 0:
            for substep_proto in step_proto.substep:
                steps_or_nets.append(ExecutionStep.create_from_proto(
                    substep_proto, net_obj_dict, net_proto_dict))
        else:
            for net_name in step_proto.network:
                if net_name not in net_obj_dict:
                    assert net_name in net_proto_dict
                    net = Net(net_proto_dict[net_name])
                    net_obj_dict[net_name] = net
                net = net_obj_dict[net_name]
                assert isinstance(net, Net)
                steps_or_nets.append(net)

        num_iter = step_proto.num_iter if\
            step_proto.HasField('num_iter') else None
        concurrent_substeps = step_proto.concurrent_substeps if\
            step_proto.HasField('concurrent_substeps') else None
        should_stop_blob = BlobReference(step_proto.should_stop_blob) if\
            step_proto.HasField('should_stop_blob') else None
        only_once = step_proto.only_once if\
            step_proto.HasField('only_once') else None
        num_concurrent_instances = step_proto.num_concurrent_instances if\
            step_proto.HasField('num_concurrent_instances') else None
        create_workspace = step_proto.create_workspace if\
            step_proto.HasField('create_workspace') else None
        run_every_ms = step_proto.run_every_ms if\
            step_proto.HasField('run_every_ms') else None

        return execution_step(
            step_proto.name,
            steps_or_nets,
            num_iter=num_iter,
            report_net=None,        # DEPRECATED
            report_interval=None,   # DEPRECATED
            concurrent_substeps=concurrent_substeps,
            should_stop_blob=should_stop_blob,
            only_once=only_once,
            num_concurrent_instances=num_concurrent_instances,
            create_workspace=create_workspace,
            run_every_ms=run_every_ms)

def add_nets_in_order(step, net_list):
    proto = step.Proto()
    for substep in step.Substeps():
        add_nets_in_order(substep, net_list)
    for net in proto.network:
        if net not in net_list:
            net_list.append(net)
    # FIXME(azzolini): This is actually wrong. Report nets should be
    # instantiated first since they may run before any substep is run.
    # However, currently, Reporter depends on this behavior.
    if proto.report_net and proto.report_net not in net_list:
        net_list.append(proto.report_net)

class Plan(object):

    def __init__(self, name_or_step):
        self._plan = caffe2_pb2.PlanDef()
        self._net_dict = OrderedDict()
        self._steps = []    # A list of ExecutionStep
        if isinstance(name_or_step, ExecutionStep):
            self._plan.name = name_or_step.Name()
            self.AddStep(name_or_step)
        elif isinstance(name_or_step, basestring):
            self._plan.name = name_or_step
        else:
            raise ValueError('name_or_step must be a string or ExecutionStep')
    def __str__(self):
        return self._plan.name

    def Proto(self):
        return self._plan
    def AddNets(self, nets):
        for net in nets:
            if _add_net_to_dict(self._net_dict, net):
                assert isinstance(net, Net)
                self._plan.network.add().CopyFrom(net.Proto())

    def Nets(self):
        return list(viewvalues(self._net_dict))
    def AddStep(self, step):
        assert isinstance(step, ExecutionStep)
        step._notify_is_used()
        if not step.HasNets() and not step.HasSubsteps():
            return
        self._plan.execution_step.add().CopyFrom(step.Proto())
        self._steps.append(step)
        # nets need to be added to the plan in order of usage
        net_list = []
        add_nets_in_order(step, net_list)
        self.AddNets([step.get_net(n) for n in net_list])
    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this plan.
        """
        return [
            attr
            for net in viewvalues(self._net_dict)
            for attr in net.get_attributes(name)
        ]
    @classmethod
    def create_from_proto(cls, plan_proto):
        assert isinstance(plan_proto, caffe2_pb2.PlanDef)
        plan = Plan(plan_proto.name)
        plan._plan.CopyFrom(plan_proto)
        del plan._plan.network[:]
        del plan._plan.execution_step[:]

        net_obj_dict = {}
        net_proto_dict = {}
        for net_proto in plan_proto.network:
            assert net_proto.name not in net_proto_dict
            net_proto_dict[net_proto.name] = net_proto

        for step_proto in plan_proto.execution_step:
            step = ExecutionStep.create_from_proto(
                step_proto, net_obj_dict, net_proto_dict)
            plan.AddStep(step)

        return plan
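
# Illustrative sketch (not part of the original source): build a small Plan
# and round-trip it through its protobuf. The function, net, and step names
# are hypothetical.
def _example_plan_roundtrip():
    net = Net('body')
    net.ConstantFill([], 'ones', shape=[2], value=1.0)
    plan = Plan('example_plan')
    plan.AddStep(execution_step('main', net, num_iter=2))
    # create_from_proto reconstructs the step/net hierarchy from PlanDef.
    restored = Plan.create_from_proto(plan.Proto())
    return restored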

def to_execution_step(step_or_nets, default_name=None):
    from caffe2.python.net_builder import NetBuilder
    if isinstance(step_or_nets, ExecutionStep):
        return step_or_nets

    stop_blob = None
    if not default_name and hasattr(step_or_nets, 'name'):
        default_name = step_or_nets.name
    if isinstance(step_or_nets, NetBuilder):
        stop_blob = step_or_nets._stop_blob
        step_or_nets = step_or_nets.get()
    return execution_step(
        default_name, step_or_nets, should_stop_blob=stop_blob)

def execution_step(default_name,
                   steps_or_nets,
                   num_iter=None,
                   report_net=None,
                   report_interval=None,
                   concurrent_substeps=None,
                   should_stop_blob=None,
                   only_once=None,
                   num_concurrent_instances=None,
                   create_workspace=False,
                   run_every_ms=None):
    """
    Helper for creating an ExecutionStep.
    - steps_or_nets can be:
      - None
      - Net
      - ExecutionStep
      - list<Net>
      - list<ExecutionStep>
    - should_stop_blob is either None or a scalar boolean blob.
      - This blob is checked AFTER every substep/subnet.
      - If specified and true, then this step will return immediately.
      - Be sure to handle race conditions if setting from concurrent threads.
    - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
    """
    assert should_stop_blob is None or num_iter is None, (
        'Cannot set both should_stop_blob and num_iter.')
    if should_stop_blob is None and num_iter is None:
        num_iter = 1

    step = ExecutionStep(default_name)
    if should_stop_blob is not None:
        step.SetShouldStopBlob(should_stop_blob)
    if num_iter is not None:
        step.SetIter(num_iter)
    if only_once is not None:
        step.SetOnlyOnce(only_once)
    if concurrent_substeps is not None:
        step.SetConcurrentSubsteps(concurrent_substeps)
    if report_net is not None:
        assert report_interval is not None
        step.SetReportNet(report_net, report_interval)
    if num_concurrent_instances is not None:
        step.SetNumConcurrentInstances(num_concurrent_instances)
    if create_workspace:
        step.SetCreateWorkspace(True)
    if run_every_ms:
        step.RunEveryMillis(run_every_ms)

    if isinstance(steps_or_nets, ExecutionStep):
        step.AddSubstep(steps_or_nets)
    elif isinstance(steps_or_nets, Net):
        step.AddNet(steps_or_nets)
    elif isinstance(steps_or_nets, list):
        if all(isinstance(x, Net) for x in steps_or_nets):
            for x in steps_or_nets:
                step.AddNet(x)
        else:
            for x in steps_or_nets:
                step.AddSubstep(to_execution_step(x))
    elif steps_or_nets:
        raise ValueError(
            'steps_or_nets must be a step, a net, or a list of nets or steps.')
    return step
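
# Illustrative sketch (not part of the original source): composing steps. A
# list of Nets becomes subnets of a single step, while mixing in
# ExecutionSteps nests substeps instead. Net and step names are hypothetical.
def _example_execution_step():
    init_net = Net('init')
    init_net.ConstantFill([], 'w', shape=[2], value=0.0)
    train_net = Net('train')
    train_net.Add(['w', 'w'], 'w_doubled')
    # Runs init once, then train for 10 iterations.
    return execution_step('outer', [
        execution_step('init_step', init_net, only_once=True),
        execution_step('train_step', train_net, num_iter=10),
    ])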

def scoped_execution_step(name, *args, **kwargs):
    """Same as execution_step() except that the step name is scoped."""
    default_name = ScopedName(name) if name else name
    return execution_step(default_name, *args, **kwargs)
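
# Illustrative note (not in the original source): under a NameScope the step
# name receives the scope prefix, e.g. 'trainer/step' instead of 'step'.
# `my_net` is a hypothetical net:
#
#     with NameScope('trainer'):
#         step = scoped_execution_step('step', my_net)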

def _extract_stacktrace():
    '''
    This function extracts the stacktrace without file system access,
    purely via sys._getframe(), and removes the part that belongs to
    this file (core.py). We are not using the inspect module because
    it is just a wrapper on top of sys._getframe() whose logic is based
    on accessing source files on disk - exactly what we are trying to
    avoid here. The same holds for the traceback module.

    The reason for avoiding file system access is that
    if the code is located on an NFS, file access might be slow.

    Function returns a list of tuples (file_name, line_number, function)
    '''
    result = []
    # Skip the innermost frames, which belong to core.py itself.
    frame = sys._getframe(3)
    while frame:
        result.append((frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name))
        frame = frame.f_back
    return result

SetPerOpEnginePref = C.set_per_op_engine_pref
SetGlobalEnginePref = C.set_global_engine_pref
SetEnginePref = C.set_engine_pref
SetOpEnginePref = C.set_op_engine_pref