3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def __init__(self, init_params=True, **kwargs):
    """Build the model helper, splitting cuDNN-related kwargs into an arg_scope.

    Args:
        init_params: bool, forwarded to the base ModelHelper; when False,
            parameters are declared as external inputs instead of being
            initialized (see AddParam).
        **kwargs: remaining keyword args forwarded to the base ModelHelper.
            `use_cudnn`, `cudnn_exhaustive_search` and `ws_nbytes_limit`
            are popped out and routed through `arg_scope` instead.
    """
    arg_scope = {
        # Default to cuDNN-backed ops; callers may opt out per model.
        'use_cudnn': kwargs.pop('use_cudnn', True),
        'cudnn_exhaustive_search': kwargs.pop('cudnn_exhaustive_search', False),
    }
    # Only forward the workspace byte limit when explicitly provided
    # (a falsy/absent value leaves the framework default in place).
    if kwargs.get('ws_nbytes_limit', None):
        arg_scope['ws_nbytes_limit'] = kwargs.pop('ws_nbytes_limit')

    super(Seq2SeqModelHelper, self).__init__(
        init_params=init_params,
        arg_scope=arg_scope,
        **kwargs
    )
    # Params tracked separately from self.params so no gradient is computed
    # for them (populated by AddParam(..., trainable=False)).
    self.non_trainable_params = []
def AddParam(self, name, init=None, init_value=None, trainable=True):
    """Add a parameter to the model's net and its initializer if needed.

    Args:
        name: str, blob name for the parameter.
        init: a tuple (<initialization_op_name>, <initialization_op_kwargs>).
        init_value: int, float or str. Can be used instead of `init` as a
            simple constant initializer.
        trainable: bool, whether to compute a gradient for this param or not.

    Returns:
        The parameter blob (initializer op output when `self.init_params`
        is set, otherwise an external input of `self.net`).
    """
    if init_value is not None:
        # `init` and `init_value` are mutually exclusive ways to initialize.
        assert init is None
        assert type(init_value) in [int, float, str]
        init = ('ConstantFill', dict(
            shape=[1],
            value=init_value,
        ))

    if self.init_params:
        # Dynamically dispatch to the initializer op named in init[0].
        param = self.param_init_net.__getattr__(init[0])(
            [],
            name,
            **init[1]
        )
    else:
        # Not initializing here: the param is fed in from the outside.
        param = self.net.AddExternalInput(name)

    if trainable:
        self.params.append(param)
    else:
        self.non_trainable_params.append(param)

    return param
def GetNonTrainableParams(self, namescope=None):
    """Return the non-trainable params in the given namescope.

    Args:
        namescope: str or None. When None, the current name scope is used;
            otherwise the given scope (a trailing separator is appended if
            missing) filters the params by their name scope.

    Returns:
        A list of non-trainable parameter blobs (a copy, never the
        internal list itself).
    """
    if namescope is None:
        namescope = scope.CurrentNameScope()
    else:
        # Normalize so the comparison below matches GetNameScope() output,
        # which always carries the trailing separator.
        if not namescope.endswith(scope._NAMESCOPE_SEPARATOR):
            namescope += scope._NAMESCOPE_SEPARATOR

    if namescope == '':
        # Root scope: every non-trainable param matches.
        return self.non_trainable_params[:]
    else:
        return [
            p for p in self.non_trainable_params
            if p.GetNameScope() == namescope
        ]
def GetAllParams(self, namescope=None):
    """Return all params in `namescope`: trainable, computed and non-trainable.

    Args:
        namescope: str or None, forwarded to each underlying getter
            (None means the current name scope).

    Returns:
        A single list concatenating GetParams, GetComputedParams and
        GetNonTrainableParams results, in that order.
    """
    return (
        self.GetParams(namescope) +
        self.GetComputedParams(namescope) +
        self.GetNonTrainableParams(namescope)
    )