Caffe2 - Python API
A deep learning, cross-platform ML framework
seq2seq_model_helper.py
1 ## @package seq2seq_model_helper
2 # Module caffe2.python.models.seq2seq.seq2seq_model_helper
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import scope
9 from caffe2.python.model_helper import ModelHelper
10 
11 
13 
def __init__(self, init_params=True, **kwargs):
    """Build the helper with a seq2seq-specific ``arg_scope``.

    Args:
        init_params: forwarded unchanged to the base-class constructor.
        **kwargs: ``use_cudnn`` (default True), ``cudnn_exhaustive_search``
            (default False) and ``ws_nbytes_limit`` are consumed here and
            stored in ``arg_scope`` (``ws_nbytes_limit`` only when truthy).
            ``order`` in ``arg_scope`` is always ``'NHWC'``. Every other
            keyword is forwarded to the base-class constructor.
    """
    arg_scope = {
        'use_cudnn': kwargs.pop('use_cudnn', True),
        'cudnn_exhaustive_search': kwargs.pop('cudnn_exhaustive_search', False),
        'order': 'NHWC',
    }
    # Always consume ws_nbytes_limit so a falsy value (0 / None) is not
    # leaked into the base-class constructor's kwargs; only a truthy
    # limit is recorded in arg_scope (same as before for truthy values).
    ws_nbytes_limit = kwargs.pop('ws_nbytes_limit', None)
    if ws_nbytes_limit:
        arg_scope['ws_nbytes_limit'] = ws_nbytes_limit

    super(Seq2SeqModelHelper, self).__init__(
        init_params=init_params,
        arg_scope=arg_scope,
        **kwargs
    )
    # Params registered through AddParam(..., trainable=False).
    self.non_trainable_params = []
29 
def AddParam(self, name, init=None, init_value=None, trainable=True):
    """Adds a parameter to the model's net and its initializer if needed.

    Args:
        name: name of the parameter blob.
        init: a tuple (<initialization_op_name>, <initialization_op_kwargs>).
            Required when ``self.init_params`` is true, unless
            ``init_value`` is given.
        init_value: int, float or str. Can be used instead of `init` as a
            simple constant initializer (a ``ConstantFill`` of shape [1]).
        trainable: bool, whether to compute gradient for this param or not.
            Trainable params are appended to ``self.params``, the others to
            ``self.non_trainable_params``.

    Returns:
        When ``self.init_params`` is true, whatever the initializer op call
        on ``self.param_init_net`` returns; otherwise the result of
        ``self.net.AddExternalInput(name)``.

    Raises:
        TypeError: if ``self.init_params`` is true and neither ``init`` nor
            ``init_value`` was provided.
    """
    if init_value is not None:
        assert init is None
        assert type(init_value) in [int, float, str]
        init = ('ConstantFill', dict(
            shape=[1],
            value=init_value,
        ))

    if self.init_params:
        if init is None:
            # Without this check the failure surfaces as an opaque
            # "'NoneType' object is not subscriptable" from init[0].
            raise TypeError(
                "AddParam({!r}) requires `init` or `init_value` when "
                "init_params is set".format(name)
            )
        # NOTE(review): __getattr__ is called explicitly, presumably so the
        # op name always resolves to an operator builder even if it collides
        # with a regular attribute of the net — confirm before changing.
        param = self.param_init_net.__getattr__(init[0])(
            [],
            name,
            **init[1]
        )
    else:
        param = self.net.AddExternalInput(name)

    if trainable:
        self.params.append(param)
    else:
        self.non_trainable_params.append(param)

    return param
62 
def GetNonTrainableParams(self, namescope=None):
    '''
    Return the non-trainable params belonging to a namescope.

    Args:
        namescope: scope to filter on. When None, the current scope from
            scope.CurrentNameScope() is used as-is; otherwise the scope
            separator is appended if it is not already the last character.

    Returns:
        A new list. If the resolved scope is the empty string this is a
        copy of all non-trainable params; otherwise only those whose
        GetNameScope() equals the resolved scope exactly.
    '''
    separator = scope._NAMESCOPE_SEPARATOR
    if namescope is None:
        target = scope.CurrentNameScope()
    elif namescope.endswith(separator):
        target = namescope
    else:
        target = namescope + separator

    if target == '':
        return list(self.non_trainable_params)

    return [
        param for param in self.non_trainable_params
        if param.GetNameScope() == target
    ]
80 
def GetAllParams(self, namescope=None):
    '''
    Return every param of the model for the given namescope.

    The result is a new list: the trainable params (GetParams), followed
    by the computed params (GetComputedParams), followed by the
    non-trainable params (GetNonTrainableParams), each queried with the
    same ``namescope`` and in that order.
    '''
    trainable = self.GetParams(namescope)
    computed = self.GetComputedParams(namescope)
    non_trainable = self.GetNonTrainableParams(namescope)
    return trainable + computed + non_trainable
def AddParam(self, name, init=None, init_value=None, trainable=True)
def GetComputedParams(self, namescope=None)
def GetParams(self, namescope=None, top_scope=False)