# Caffe2 - Python API
# A deep learning, cross platform ML framework
# seq2seq_model_helper.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package seq2seq_model_helper
# Module caffe2.python.models.seq2seq.seq2seq_model_helper
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import scope
from caffe2.python.model_helper import ModelHelper


class Seq2SeqModelHelper(ModelHelper):
    """ModelHelper specialization used by seq2seq models.

    On top of ModelHelper's trainable and computed params, this helper
    tracks a separate list of *non-trainable* params (created via
    ``AddParam(..., trainable=False)``) and exposes them through
    ``GetNonTrainableParams`` / ``GetAllParams``.
    """

    def __init__(self, init_params=True, **kwargs):
        """Builds the cuDNN arg_scope and delegates to ModelHelper.

        Args:
            init_params: bool, forwarded to ModelHelper; when False, params
                are declared as external inputs instead of being initialized.
            **kwargs: remaining keyword args are forwarded to ModelHelper;
                'use_cudnn', 'cudnn_exhaustive_search' and 'ws_nbytes_limit'
                are consumed here and moved into arg_scope.
        """
        arg_scope = {
            'use_cudnn': kwargs.pop('use_cudnn', True),
            'cudnn_exhaustive_search': kwargs.pop('cudnn_exhaustive_search', False),
            'order': 'NHWC',
        }
        # NOTE(review): truthiness check means a falsy ws_nbytes_limit
        # (e.g. 0) is NOT moved into arg_scope and stays in kwargs, where it
        # is forwarded to ModelHelper below — confirm that is intended.
        if kwargs.get('ws_nbytes_limit', None):
            arg_scope['ws_nbytes_limit'] = kwargs.pop('ws_nbytes_limit')

        super(Seq2SeqModelHelper, self).__init__(
            init_params=init_params,
            arg_scope=arg_scope,
            **kwargs
        )
        # Params excluded from gradient computation; see AddParam(trainable=False).
        self.non_trainable_params = []

    def AddParam(self, name, init=None, init_value=None, trainable=True):
        """Adds a parameter to the model's net and its initializer if needed

        Args:
            init: a tuple (<initialization_op_name>, <initialization_op_kwargs>)
            init_value: int, float or str. Can be used instead of `init` as a
                simple constant initializer
            trainable: bool, whether to compute gradient for this param or not
        """
        if init_value is not None:
            # `init` and `init_value` are mutually exclusive.
            assert init is None
            assert type(init_value) in [int, float, str]
            init = ('ConstantFill', dict(
                shape=[1],
                value=init_value,
            ))

        if self.init_params:
            # Look up the initializer op on param_init_net by name, e.g.
            # param_init_net.ConstantFill([], name, **kwargs).
            param = getattr(self.param_init_net, init[0])(
                [],
                name,
                **init[1]
            )
        else:
            # Without param initialization, the param must be fed externally.
            param = self.net.AddExternalInput(name)

        if trainable:
            self.params.append(param)
        else:
            self.non_trainable_params.append(param)

        return param

    def GetNonTrainableParams(self, namescope=None):
        '''
        Returns the non-trainable params in the given namescope
        (current namescope when None).
        '''
        if namescope is None:
            namescope = scope.CurrentNameScope()
        else:
            # Normalize so the exact-match comparison below works.
            if not namescope.endswith(scope._NAMESCOPE_SEPARATOR):
                namescope += scope._NAMESCOPE_SEPARATOR

        if namescope == '':
            # Root scope: return a shallow copy so callers can't mutate our list.
            return self.non_trainable_params[:]
        else:
            return [
                p for p in self.non_trainable_params
                if p.GetNameScope() == namescope
            ]

    def GetAllParams(self, namescope=None):
        '''
        Returns trainable, computed and non-trainable params combined,
        all filtered by the given namescope.
        '''
        return (
            self.GetParams(namescope) +
            self.GetComputedParams(namescope) +
            self.GetNonTrainableParams(namescope)
        )