Caffe2 - Python API
A deep learning, cross-platform ML framework
model_helper.py
1 ## @package model_helper
2 # Module caffe2.python.model_helper
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, scope, workspace, helpers
9 from caffe2.python.modeling import parameter_info
10 from caffe2.python.modeling.parameter_sharing import (
11  parameter_sharing_context,
12 )
13 from caffe2.python.optimizer_context import (
14  OptimizerContext,
15  DEFAULT_OPTIM,
16 )
17 from caffe2.python.regularizer_context import RegularizerContext
18 
19 from future.utils import viewitems, viewkeys
20 from itertools import chain
21 
22 import logging
23 import six
24 
25 
26 # _known_working_ops are operators that do not need special care.
27 _known_working_ops = [
28  "Accuracy",
29  "Adam",
30  "Add",
31  "Adagrad",
32  "SparseAdagrad",
33  "Adadelta",
34  "SparseAdadelta",
35  "AveragedLoss",
36  "Cast",
37  "Checkpoint",
38  "ConstantFill",
39  "Copy",
40  "CopyGPUToCPU",
41  "CopyCPUToGPU",
42  "DequeueBlobs",
43  "EnsureCPUOutput",
44  "ExpandDims",
45  "Flatten",
46  "FlattenToVec",
47  "LabelCrossEntropy",
48  "LearningRate",
49  "MakeTwoClass",
50  "MatMul",
51  "NCCLAllreduce",
52  "NHWC2NCHW",
53  "PackSegments",
54  "Print",
55  "PRelu",
56  "ReduceFrontSum",
57  "Scale",
58  "ScatterWeightedSum",
59  "Sigmoid",
60  "SortedSegmentSum",
61  "Snapshot", # Note: snapshot is deprecated, use Checkpoint
62  "Softmax",
63  "SoftmaxWithLoss",
64  "SquaredL2Distance",
65  "Squeeze",
66  "StopGradient",
67  "Summarize",
68  "Tanh",
69  "Transpose",
70  "UnpackSegments",
71  "WeightedSum",
72  "YellowFin"
73 ]
74 
75 
76 class ModelHelper(object):
77  """A helper model so we can manage models more easily. It contains the net
78  definition and parameter storage. You can add an Operator yourself, e.g.
79 
80  model = model_helper.ModelHelper(name="train_net")
81  # init your weight and bias as w and b
82  w = model.param_init_net.XavierFill(...)
83  b = model.param_init_net.ConstantFill(...)
84  fc1 = model.FC([input, w, b], output, **kwargs)
85 
86  or you can use helper functions in brew module without manually
87  defining parameter initializations and operators.
88 
89  model = model_helper.ModelHelper(name="train_net")
90  fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)
91 
92  """
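# Example (added for illustration; not part of the original source): a minimal,
# runnable sketch of the brew-based flow described in the docstring above. The
# blob name "data" and its 16-wide shape are assumptions for this example.
#
#   import numpy as np
#   from caffe2.python import brew, model_helper, workspace
#
#   workspace.FeedBlob("data", np.random.rand(4, 16).astype(np.float32))
#   model = model_helper.ModelHelper(name="train_net")
#   brew.fc(model, "data", "fc1", dim_in=16, dim_out=8)
#   workspace.RunNetOnce(model.param_init_net)  # run parameter initialization
#   workspace.RunNetOnce(model.net)             # run the forward pass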
93 
94  def __init__(self, name=None, init_params=True, allow_not_known_ops=True,
95  skip_sparse_optim=False, param_model=None, arg_scope=None):
96  self.name = name or "model"
97  self.net = core.Net(self.name)
98 
99  if param_model is not None:
100  self.param_init_net = param_model.param_init_net
101  self.param_to_grad = param_model.param_to_grad
102  self.params = param_model.params
103  self._parameters_info = param_model._parameters_info
104  self._computed_params = param_model._computed_params
105  else:
106  self.param_init_net = core.Net(self.name + '_init')
107  self.param_to_grad = {}
108  self.params = []
109  self._parameters_info = {}
110  self._computed_params = []
111 
112  self._param_info_deprecated = []
113  self._devices = []
114  self.gradient_ops_added = False
115  self.init_params = init_params
116  self.allow_not_known_ops = allow_not_known_ops
117  self.skip_sparse_optim = skip_sparse_optim
118  self.weights = []
119  self.biases = []
120  self._arg_scope = {
121  'order': "NCHW",
122  'use_cudnn': True,
123  'cudnn_exhaustive_search': False,
124  }
125  if arg_scope is not None:
126  # Please note that a value of None is not acceptable. We do not check
127  # it here because MakeArgument already performs that check.
128  self._arg_scope.update(arg_scope)
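# Example (illustration only; not in the original source): arg_scope values
# passed to the constructor become the defaults that brew helper functions
# consult; the overrides shown here are assumptions for the example.
#
#   model = model_helper.ModelHelper(
#       name="nhwc_net",
#       arg_scope={"order": "NHWC", "use_cudnn": False},
#   )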
129 
130  @property
131  def arg_scope(self):
132  return self._arg_scope
133 
134  def get_name(self):
135  return self.name
136 
137  def _infer_param_shape(self, param):
138  for op in self.param_init_net.Proto().op:
139  if str(param) in op.output:
140  for arg in op.arg:
141  if arg.name == "shape":
142  return list(arg.ints)
143  return None
144 
145  def _update_param_info_deprecated(self):
146  assert len(self._param_info_deprecated) <= len(self.params)
147  for param in self.params[len(self._param_info_deprecated):]:
148  if not isinstance(param, core.BlobReference):
149  raise ValueError(
150  "Param %s must be a BlobReference!" % str(param))
151  self._param_info_deprecated.append(parameter_info.ParameterInfo(
152  param_id=len(self._param_info_deprecated),
153  param=param,
154  shape=self._infer_param_shape(param)))
155  for info in self._param_info_deprecated:
156  info.grad = self.param_to_grad.get(info.name)
157 
158  def _normalize_tags(self, tags):
159  tags = tags or []
160  return set(tags) if isinstance(tags, list) else set([tags])
161 
162  def create_param(self, param_name, shape, initializer, tags=None):
163  """
164  Creates parameter with a given name and initializer.
165 
166  If param_name is an instance of BlobReference - then this blob will be
167  used to store the parameter (no extra logic will affect its location).
168 
169  If param_name is an instance of a string type, then the final blob will
170  be created in the CurrentNameScope with respect to all parameter
171  sharing logic, i.e. 'resolved_name_scope/param_name'.
172 
173  Parameter sharing logic overrides CurrentNameScope according to the
174  rules that are specified through ParameterSharing contexts; all
175  ParameterSharing contexts are applied recursively until there are no
176  extra overrides present, and on each step the best match is
177  applied first.
178 
179  The following examples should clarify the way ParameterSharing logic
180  works:
181 
182  As an example if this function is called with parameter 'w':
183  a. Call from some scope 'global_scope' with no Parameter sharing:
184  'global_scope/w'
185  b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
186  'scope_a/w'
187  c. Call from scope 'scope_a', with override {'scope_a': ''}:
188  'scope_a/w'
189  d. Call from scope 'scope_b/shared', with overrides
190  {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
191  'scope_a/w'
192 e. Call from scope 'scope_b/unshared', with overrides
193  {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
194  'scope_a/unshared/w'
195  """
196  # ParameterSharing works only for case when param_name is instance of
197  # a string type. If param_name is a BlobReference - no attempt for
198  # ParameterSharing will be applied.
199  if isinstance(param_name, core.BlobReference):
200  param_name = str(param_name)
201  elif isinstance(param_name, six.string_types):
202  # The parameter name will be equal to the current NameScope,
203  # resolved with respect to the parameter sharing of the scopes.
204  param_name = parameter_sharing_context.get_parameter_name(
205  param_name)
206  else:
207  raise TypeError("Unsupported type for param_name")
208 
209  if param_name in self._parameters_info:
210  assert self._parameters_info[param_name].shape == shape
211  return self._parameters_info[param_name].blob
212 
213  param_info = initializer.create_param(
214  param_name=core.BlobReference(param_name),
215  init_net=self.param_init_net,
216  shape=shape,
217  )
218  optim_context = OptimizerContext.current()
219  for tag in self._normalize_tags(tags):
220  if optim_context.has_optimizer(tag):
221  # param_info will check optimizer has not been set
222  param_info.optimizer = optim_context.get_optimizer(tag)
223  if not param_info.optimizer and optim_context.has_optimizer(DEFAULT_OPTIM):
224  param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)
225 
226  reg_context = RegularizerContext.current()
227  param_info.regularizer = reg_context
228 
229  self._parameters_info[param_name] = param_info
230  # Add param to legacy structs as well, so all other functions for
231  # parameters still work.
232  self.AddParameter(param_info.blob, tags)
233  return param_info.blob
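# Example (illustrative sketch; not part of the original source): creating a
# parameter under a name scope with an explicit initializer. The Initializer
# import and the "XavierFill" operator choice are assumptions for the example.
#
#   from caffe2.python import core, model_helper
#   from caffe2.python.modeling.initializers import Initializer
#
#   model = model_helper.ModelHelper(name="param_example")
#   with core.NameScope("global_scope"):
#       w = model.create_param(
#           param_name="w",
#           shape=[16, 8],
#           initializer=Initializer("XavierFill"),
#       )
#   # With no ParameterSharing override active, w resolves to 'global_scope/w'.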
234 
235  def get_param_info(self, param):
236  assert isinstance(param, core.BlobReference), \
237  "Param {} is not a BlobReference".format(param)
238  return self._parameters_info.get(param, None)
239 
240  # This method is deprecated, use create_param method which
241  # also does parameter initialization when needed
242  def add_param_DEPRECATED(self, param, key=None, shape=None, length=None):
243  logging.warning("add_param method is DEPRECATED")
244  self._update_param_info_deprecated()
245  self.AddParameter(param)
246  if key is not None and self.net.input_record() is not None:
247  idx = self.net.input_record().field_blobs().index(key)
248  key = self.net.input_record().field_names()[idx]
249  shape = shape if shape is not None else self._infer_param_shape(param)
250  if not isinstance(param, core.BlobReference):
251  raise ValueError("Param %s must be a BlobReference!" % str(param))
252  self._param_info_deprecated.append(parameter_info.ParameterInfo(
253  param_id=len(self._param_info_deprecated),
254  param=param,
255  shape=shape,
256  key=key,
257  length=length,
258  ))
259  return self._param_info_deprecated[-1]
260 
261  def AddParameter(self, param, tags=None):
262  assert isinstance(param, core.BlobReference)
263  tags = self._normalize_tags(tags)
264  if parameter_info.ParameterTags.COMPUTED_PARAM in tags:
265  self._computed_params.append(param)
266  else:
267  self.params.append(param)
268 
269  if parameter_info.ParameterTags.WEIGHT in tags:
270  self.weights.append(param)
271  if parameter_info.ParameterTags.BIAS in tags:
272  self.biases.append(param)
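# Example (illustration only; not in the original source): registering an
# externally created blob as a WEIGHT parameter; assumes a ModelHelper
# instance `model` as in the earlier sketches.
#
#   from caffe2.python.modeling.parameter_info import ParameterTags
#
#   w = model.param_init_net.ConstantFill([], "w", shape=[8], value=1.0)
#   model.AddParameter(w, tags=ParameterTags.WEIGHT)
#   # w is now tracked in both model.params and model.weights.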
273 
274  @staticmethod
275  def _NormalizeNamescope(namescope):
276  if namescope is None:
277  return scope.CurrentNameScope()
278  elif namescope == '' or namescope.endswith(scope._NAMESCOPE_SEPARATOR):
279  return namescope
280  else:
281  return namescope + scope._NAMESCOPE_SEPARATOR
282 
283  def GetParams(self, namescope=None, top_scope=False):
284  '''
285  Returns the params in the current namescope
286  '''
287  namescope = ModelHelper._NormalizeNamescope(namescope)
288 
289  if namescope == '':
290  return self.params[:]
291  else:
292  return [p for p in self.params if
293  p.GetNameScope().startswith(namescope)]
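# Example (illustrative sketch; not part of the original source): retrieving
# parameters per name scope; the ModelHelper `model`, the "data" blob and the
# scope name "tower_0" are assumptions carried over from the earlier sketches.
#
#   with core.NameScope("tower_0"):
#       brew.fc(model, "data", "fc_t0", dim_in=16, dim_out=8)
#   tower_params = model.GetParams("tower_0")  # only 'tower_0/...' params
#   current_params = model.GetParams()         # params in the current scope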
294 
295  def Proto(self):
296  return self.net.Proto()
297 
298  def InitProto(self):
299  return self.param_init_net.Proto()
300 
301  def RunAllOnGPU(self, *args, **kwargs):
302  self.param_init_net.RunAllOnGPU(*args, **kwargs)
303  self.net.RunAllOnGPU(*args, **kwargs)
304 
305  def CreateDB(self, blob_out, db, db_type, **kwargs):
306  dbreader = self.param_init_net.CreateDB(
307  [], blob_out, db=db, db_type=db_type, **kwargs)
308  return dbreader
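# Example (illustration only; not in the original source): creating a DB
# reader and feeding minibatches through the raw net op. The file name
# "train.lmdb" and the output blob names are assumptions.
#
#   reader = model.CreateDB("train_reader", db="train.lmdb", db_type="lmdb")
#   data, label = model.net.TensorProtosDBInput(
#       [reader], ["data", "label"], batch_size=32)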
309 
310  def AddGradientOperators(self, *args, **kwargs):
311  if self.gradient_ops_added:
312  raise RuntimeError("You cannot run AddGradientOperators twice.")
313  self.Validate()
314 
315  self.gradient_ops_added = True
316  self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
317  self.param_to_grad = self.get_param_to_grad(self.params)
318 
319  # Populate ParameterInfo for all parameters if missing
320  # and add gradient blob information, so optimizers can use it.
321  for param, grad in self.param_to_grad.items():
322  param_info = self.get_param_info(param)
323  if param_info:
324  param_info.grad = grad
325  else:
326  self._parameters_info[param] = parameter_info.ParameterInfo(
327  param_id=None,
328  param=param,
329  grad=grad,
330  )
331 
332  return self.grad_map
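# Example (illustrative sketch; not part of the original source): building a
# small regression loss and adding gradient operators; assumes a ModelHelper
# `model` with "data" and "label" blobs already fed, as in earlier sketches.
#
#   pred = brew.fc(model, "data", "pred", dim_in=16, dim_out=1)
#   dist = model.net.SquaredL2Distance([pred, "label"], "dist")
#   loss = model.net.AveragedLoss(dist, "loss")
#   grad_map = model.AddGradientOperators([loss])
#   # model.param_to_grad now maps each parameter blob to its gradient blob.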
333 
334  def get_param_to_grad(self, params):
335  '''
336  Given a list of parameters, returns a dict from each parameter
337  to its corresponding gradient
338  '''
339 
340  param_to_grad = {}
341  if not self.gradient_ops_added:
342  raise RuntimeError("You need to run AddGradientOperators first.")
343  # We need to use empty namescope when creating the gradients
344  # to prevent duplicating the namescope prefix for gradient blobs.
345  for p in params:
346  if str(p) in self.grad_map:
347  param_to_grad[p] = self.grad_map[str(p)]
348  return param_to_grad
349 
350  def GetOptimizationParamInfo(self, params=None):
351  '''
352  Returns parameter info (including the gradient) for each given param.
353  If params is not specified, all parameters will be considered.
354  '''
355  if not self.gradient_ops_added:
356  raise RuntimeError("Need to call AddGradientOperators first")
357 
358  param_to_grad = self.param_to_grad
359  if params:
360  param_to_grad = self.get_param_to_grad(params)
361 
362  return [
363  self.get_param_info(param) for param, grad in viewitems(param_to_grad)
364  if (
365  not self.skip_sparse_optim or
366  not isinstance(grad, core.GradientSlice)
367  )
368  ]
369 
370  def _Validate(self):
371  '''
372  Check for duplicate params
373  '''
374  params_list = [str(p) for p in self.params]
375  params_set = set(params_list)
376 
377  dupes = []
378  if len(params_set) != len(params_list):
379  params_list = sorted(params_list)
380  for j, p in enumerate(params_list):
381  if j > 0 and params_list[j - 1] == p:
382  if p not in dupes:
383  dupes.append(p)
384 
385  return dupes
386 
387  def Validate(self):
388  dupes = self._Validate()
389  assert dupes == [], "Duplicate params: {}".format(dupes)
390 
391  def GetComputedParams(self, namescope=None):
392  '''
393  Returns the computed params in the current namescope. 'Computed params'
394  are parameters that are not optimized via gradient descent but are
395  computed directly from data, such as the running mean and variance
396  of Spatial Batch Normalization.
397  '''
398  namescope = ModelHelper._NormalizeNamescope(namescope)
399 
400  if namescope == '':
401  return self._computed_params[:]
402  else:
403  return [p for p in self._computed_params
404  if p.GetNameScope().startswith(namescope)]
405 
406  def GetAllParams(self, namescope=None):
407  return self.GetParams(namescope) + self.GetComputedParams(namescope)
408 
409  def TensorProtosDBInput(
410  self, unused_blob_in, blob_out, batch_size, db, db_type, **kwargs
411  ):
412  """TensorProtosDBInput."""
413  assert len(unused_blob_in) == 0, \
414  """You cannot pass reader to model_helper.TensorProtosDBInput.
415  Use model.net.TensorProtosDBInput instead to create the op."""
416 
417  return helpers.db_input.db_input(
418  self, blob_out, batch_size, db, db_type, **kwargs)
419 
420  def GetDevices(self):
421  assert len(self._devices) > 0, \
422  "Use data_parallel_model to run model on multiple GPUs."
423  return self._devices
424 
425  def __getattr__(self, op_type):
426  """Catch-all for all other operators, mostly those without params."""
427  if op_type.startswith('__'):
428  raise AttributeError(op_type)
429 
430  if not core.IsOperator(op_type):
431  raise AttributeError(
432  'Method ' + op_type + ' is not a registered operator.' +
433  ' Did you mean: [' +
434  ','.join(workspace.C.nearby_opnames(op_type)) + ']'
435  )
436  if op_type not in _known_working_ops:
437  if not self.allow_not_known_ops:
438  raise AttributeError(
439  "Operator {} is not known to be safe".format(op_type))
440 
441  logging.warning("You are creating an op that the ModelHelper "
442  "does not recognize: {}.".format(op_type))
443  return self.net.__getattr__(op_type)
444 
445  def __dir__(self):
446  return sorted(set(chain(
447  dir(type(self)),
448  viewkeys(self.__dict__),
449  _known_working_ops
450  )))
451 
452  def GetCompleteNet(self):
453  r""" Return param_init_net and net combined into a single Net.
454  Returns:
455  'core.Net' containing param_init_net and net
456  """
457  new_net = self.param_init_net.Clone(
458  self.name + "_complete_net", keep_schema=True)
459  # add init net info to debug info
460  for op in new_net.Proto().op:
461  op.debug_info = op.debug_info + "/param_init_net"
462  new_net.AppendNet(self.net)
463  # keep the execution optimization
464  if self.net.Proto().HasField("type"):
465  new_net.Proto().type = self.net.Proto().type
466  return new_net
467 
468  def ConstructInitTrainNetfromNet(self, net):
469  r""" Construct the init net and train net from a complete net.
470  Inputs:
471  net: 'core.Net' containing param_init_net and train net
472  """
473  param_op_mask = []
474  train_op_mask = []
475  for idx, op in enumerate(net.Proto().op):
476  if op.debug_info.endswith("/param_init_net"):
477  param_op_mask.append(idx)
478  else:
479  train_op_mask.append(idx)
480 
481  self.param_init_net = net.Clone(
482  net.Name() + "/generated_param_init_net",
483  keep_schema=True,
484  op_id_mask=param_op_mask,
485  update_external_list=True,
486  )
487  self.net = net.Clone(
488  net.Name() + "/generated_net",
489  keep_schema=True,
490  op_id_mask=train_op_mask,
491  update_external_list=True,
492  )
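# Example (illustration only; not in the original source): round-tripping a
# model through a single combined net, e.g. for serialization; assumes a
# ModelHelper `model` as in the earlier sketches.
#
#   complete_net = model.GetCompleteNet()
#   restored = model_helper.ModelHelper(name="restored")
#   restored.ConstructInitTrainNetfromNet(complete_net)
#   # restored.param_init_net and restored.net are rebuilt from complete_net.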
493 
494 
495 def ExtractPredictorNet(
496  net_proto,
497  input_blobs,
498  output_blobs,
499  device=None,
500  renames=None,
501  disabled_inputs=None,
502 ):
503  '''
504  Takes a model net for training and returns a net which can be
505  used for prediction. In particular, all gradient operators and
506  input operators are removed.
507  @param net_proto protobuf of the net you want to process (net.Proto())
508  @param input_blobs list/set of blob names that are the inputs of predictor
509  @param output_blobs list/set of blob names that are outputs of predictor
510  @param device optional device option that is assigned
511  @param renames dictionary of blob name to a new name (optional)
512  @param disabled_inputs optional set of blobs that are 'switched off'. This
513  will cause branches with those blobs as inputs to be removed
514  '''
515  predict_net = core.Net(net_proto.name + "_predict")
516  predict_proto = predict_net.Proto()
517 
518  orig_external_inputs = set(net_proto.external_input)
519  orig_external_outputs = set(net_proto.external_output)
520  input_blobs = {str(b) for b in input_blobs}
521  known_blobs = set(orig_external_inputs).union(input_blobs)
522  output_blobs = {str(b) for b in output_blobs}
523  external_inputs = set(input_blobs)
524  external_outputs = set(output_blobs)
525 
526  if renames is None:
527  renames = {}
528 
529  if disabled_inputs is not None:
530  known_blobs = known_blobs - set(disabled_inputs)
531 
532  ops = list(net_proto.op)
533 
534  # Find the range of ops that we should include
535  try:
536  first_op_with_input = min(
537  [
538  j for j in range(len(ops))
539  if input_blobs.intersection(ops[j].input) and ops[j].type !=
540  'StopGradient'
541  ]
542  )
543  except ValueError:
544  raise Exception("No ops with input={}".format(input_blobs))
545  try:
546  last_op_with_output = max(
547  [
548  j for j in range(len(ops))
549  if output_blobs.intersection(ops[j].output)
550  ]
551  )
552  except ValueError:
553  raise Exception("No ops with output={}".format(output_blobs))
554 
555  def validate_op(op):
556  # Check that the op does not have is_test = 0 set. This is a common
557  # pitfall with the SpatialBN op, at least.
558  for arg in op.arg:
559  if arg.name == "is_test" and arg.i == 0:
560  raise Exception(
561  "An operator had is_test=0, did you try to extract a " +
562  "predictor from a train model (instead of test model)?" +
563  " Op was: {}".format(str(op))
564  )
565 
566  def rename_list(proto_list):
567  # proto lists don't support assignments
568  new_list = proto_list[:]
569  for j, b in enumerate(new_list):
570  if b in renames:
571  new_list[j] = renames[b]
572 
573  del proto_list[:]
574  proto_list.extend(new_list)
575 
576  # Iterate through the ops and only include those whose inputs
577  # we can satisfy.
578  for op in ops[first_op_with_input:(last_op_with_output + 1)]:
579  if known_blobs.issuperset(op.input):
580 
581  # Special handling for recurrent nets
582  # TODO: when standard argument type for "nets" is introduced,
583  # this can be more general
584  if op.type == 'RecurrentNetwork':
585  for arg in op.arg:
586  if arg.name == 'backward_step_net':
587  arg.ClearField(str('n'))
588  elif arg.name == 'step_net':
589  for step_op in arg.n.op:
590  rename_list(step_op.input)
591  rename_list(step_op.output)
592  if device is not None:
593  step_op.device_option.device_type = device.device_type
594  step_op.device_option.device_id = device.device_id
595 
596  rename_list(arg.n.external_input)
597  rename_list(arg.n.external_output)
598 
599  # Add additional external inputs
600  external_inputs.update(
601  set(arg.n.external_input).intersection(
602  orig_external_inputs
603  )
604  )
605 
606  if device is not None:
607  op.device_option.device_type = device.device_type
608  op.device_option.device_id = device.device_id
609  validate_op(op)
610  predict_proto.op.extend([op])
611  known_blobs.update(op.output)
612  external_inputs.update(
613  set(op.input).intersection(orig_external_inputs)
614  )
615  external_outputs.update(
616  set(op.output).intersection(orig_external_outputs)
617  )
618 
619  else:
620  logging.debug(
621  "Op {} had unknown inputs: {}".format(
622  op.type, set(op.input).difference(known_blobs)
623  )
624  )
625 
626  # Predictor net's external inputs and outputs include only those
627  # that are part of this net.
628  predict_proto.external_input.extend(external_inputs)
629  predict_proto.external_output.extend(external_outputs)
630 
631  rename_list(predict_proto.external_input)
632  rename_list(predict_proto.external_output)
633 
634  renamed_input_blobs = []
635  for b in input_blobs:
636  if b in renames:
637  renamed_input_blobs.append(renames[b])
638  else:
639  renamed_input_blobs.append(b)
640 
641  for op in predict_proto.op:
642  rename_list(op.input)
643  rename_list(op.output)
644 
645  return predict_net, list(
646  set(predict_proto.external_input) - set(renamed_input_blobs)
647  )
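# Example (illustrative sketch; not part of the original source): extracting a
# deploy/predictor net from a trained model; "data" and "pred" are assumed
# blob names and `model` is a ModelHelper as in the earlier sketches.
#
#   predict_net, extra_inputs = ExtractPredictorNet(
#       net_proto=model.net.Proto(),
#       input_blobs=["data"],
#       output_blobs=["pred"],
#   )
#   workspace.CreateNet(predict_net, overwrite=True)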