from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
from caffe2.python.modeling.parameter_info import ParameterInfo

import six


class Initializer(object):
    '''
    This class abstracts out parameter creation. One can come up with a new
    Initializer in order to implement more complex parameter initialization
    logic.
    '''

    def __init__(self, operator_name=None, **kwargs):
        self.operator_name = operator_name
        self.operator_kwargs = kwargs
    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs
    def create_param(self, param_name, init_net, shape):
        param = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape, **self.operator_kwargs)
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
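
# A minimal usage sketch (not part of the original module): an Initializer
# configured with a Caffe2 fill operator emits that operator into the init
# net when create_param is called. The operator name "XavierFill" and the
# blob name "fc_w" below are illustrative assumptions.
#
#   from caffe2.python import core
#   init_net = core.Net("init")
#   initializer = Initializer(operator_name="XavierFill")
#   param_info = initializer.create_param("fc_w", init_net, shape=[10, 5])
#   # init_net now contains a single XavierFill op producing the "fc_w" blob.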


class ExternalInitializer(object):
    '''
    This class is used in cases when the parameter should not be initialized
    by the initializer, but rather provided in the workspace when
    param_init_net is executed.

    Current version is not doing any real sanity checks to the parameter.
    '''

    def create_param(self, param_name, init_net, shape):
        if isinstance(param_name, BlobReference):
            param = BlobReference(str(param_name), init_net)
        elif isinstance(param_name, six.string_types):
            param = ScopedBlobReference(param_name, init_net)
        else:
            raise TypeError("Unsupported type for param_name")
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
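
# A usage sketch (assumed, for illustration): ExternalInitializer emits no
# fill op, so the caller must feed the blob into the workspace before
# param_init_net runs. The blob name "pretrained_w" is hypothetical.
#
#   import numpy as np
#   from caffe2.python import core, workspace
#   init_net = core.Net("init")
#   info = ExternalInitializer().create_param("pretrained_w", init_net, [10, 5])
#   workspace.FeedBlob(str(info.blob), np.zeros((10, 5), dtype=np.float32))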


class PseudoFP16Initializer(Initializer):
    '''
    Used in cases when the parameter should be used at half (16-bit) precision
    for compute purposes (i.e. on the forward and backward pass) but
    needs to be stored and optimized at single (32-bit) precision so tiny
    gradients with small learning rates don't underflow FP16 precision.
    A 32-bit copy of the 16-bit blob is stored in the ParameterInfo.
    This is helpful for mixed-precision training, see
    https://arxiv.org/abs/1710.03740 for details.
    '''

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs
    def create_param(self, param_name, init_net, shape):
        # Create the master fp32 copy of the parameter.
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name + "_fp32", shape=shape,
            **self.operator_kwargs)
        # Cast it to the fp16 blob that is used for compute.
        param = init_net.FloatToHalf(
            param_fp32, param_name)

        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
            blob_copy={DataType.FLOAT: param_fp32}
        )
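
# Sketch of what this produces (names assumed): for a parameter "w", the init
# net gets a fill op producing "w_fp32" followed by a FloatToHalf op producing
# "w". The fp16 blob "w" is the primary parameter used on the forward and
# backward pass; the fp32 master copy is kept in blob_copy for the optimizer.
#
#   initializer = PseudoFP16Initializer(operator_name="GaussianFill", std=0.01)
#   info = initializer.create_param("w", init_net, shape=[256, 128])
#   # info.blob      -> "w"      (fp16, compute)
#   # info.blob_copy -> {DataType.FLOAT: "w_fp32"}  (fp32 master weights)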


class ReversePseudoFP16Initializer(Initializer):
    '''
    Like PseudoFP16Initializer above, except the primary blob is taken to
    be the 32-bit precision parameter, and the 16-bit version of the blob
    is stored in blob_copy instead.
    '''

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs
    def create_param(self, param_name, init_net, shape):
        # Create the primary fp32 parameter.
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape,
            **self.operator_kwargs)
        # Keep a secondary fp16 copy in blob_copy.
        param_fp16 = init_net.FloatToHalf(
            param_fp32, param_name + "_fp16")

        return ParameterInfo(
            param_id=None,
            param=param_fp32,
            shape=shape,
            blob_copy={DataType.FLOAT16: param_fp16}
        )
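
# Mirror-image sketch (names assumed): here the fp32 blob "w" is the primary
# parameter and the fp16 copy lives in blob_copy, for cases where the
# optimizer should see fp32 while selected ops consume the fp16 copy.
#
#   initializer = ReversePseudoFP16Initializer(operator_name="GaussianFill")
#   info = initializer.create_param("w", init_net, shape=[256, 128])
#   # info.blob      -> "w"      (fp32 primary)
#   # info.blob_copy -> {DataType.FLOAT16: "w_fp16"}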


def update_initializer(initializer_class,
                       operator_name_and_kwargs,
                       default_operator_name_and_kwargs):
    '''
    A helper function to convert from operator_name_and_kwargs to a new
    object of type initializer_class. This function serves two purposes:

    1. Support for custom initialization operators being passed in
    2. Allow the user to specify a custom Initializer without overwriting
       the default operators used for initialization

    If initializer_class is None, creates a default initializer using
    the Initializer class and the operator_name_and_kwargs provided.

    If operator_name_and_kwargs is None, uses default_operator_name_and_kwargs.

    Returns an instantiated Initializer object.
    '''
    def get_initializer_args():
        return (
            operator_name_and_kwargs or
            default_operator_name_and_kwargs
        )
    if initializer_class is not None:
        init = initializer_class(get_initializer_args()[0],
                                 **get_initializer_args()[1])
    else:
        init = Initializer(
            get_initializer_args()[0],
            **get_initializer_args()[1]
        )
    return init
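
# Usage sketch (values assumed): a caller supplies a custom fill operator but
# no Initializer subclass, so the plain Initializer class is instantiated with
# the caller's operator name and kwargs; with a custom class but no custom
# operator, the defaults are used instead.
#
#   default_args = ("ConstantFill", {"value": 0.0})
#   custom_args = ("UniformFill", {"min": -0.1, "max": 0.1})
#   init = update_initializer(None, custom_args, default_args)
#   # -> Initializer("UniformFill", min=-0.1, max=0.1)
#   init = update_initializer(PseudoFP16Initializer, None, default_args)
#   # -> PseudoFP16Initializer("ConstantFill", value=0.0)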