from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
from caffe2.python.modeling.parameter_info import ParameterInfo

import six


class Initializer(object):
    """
    This class abstracts out parameter creation. One can come up with a new
    Initializer in order to implement more complex parameter initialization
    logic.
    """

    def __init__(self, operator_name=None, **kwargs):
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        # Look up the configured fill operator on init_net, emit it, and
        # wrap the resulting blob in a ParameterInfo.
        param = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape, **self.operator_kwargs)
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
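
# A minimal usage sketch (not part of the original module): `model` below is
# an assumed model_helper.ModelHelper instance, and XavierFill is a standard
# Caffe2 fill operator.
#
#     from caffe2.python import model_helper
#
#     model = model_helper.ModelHelper(name="example")
#     init = Initializer("XavierFill")
#     # Emits XavierFill([], "fc_w", shape=[64, 32]) into param_init_net.
#     info = init.create_param("fc_w", model.param_init_net, shape=[64, 32])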


class ExternalInitializer(object):
    """
    This class is used in cases when the parameter should not be initialized
    by the initializer, but rather provided in the workspace when
    param_init_net is executed.

    The current version does not perform any real sanity checks on the
    parameter.
    """

    def create_param(self, param_name, init_net, shape):
        if isinstance(param_name, BlobReference):
            param = BlobReference(str(param_name), init_net)
        elif isinstance(param_name, six.string_types):
            param = ScopedBlobReference(param_name, init_net)
        else:
            raise TypeError("Unsupported type for param_name")
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
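
# A minimal usage sketch (not part of the original module):
# ExternalInitializer emits no fill operator, so the caller must feed the
# blob into the workspace before running the net; the numpy/workspace calls
# below are assumptions about the calling code.
#
#     import numpy as np
#     from caffe2.python import workspace
#
#     init = ExternalInitializer()
#     info = init.create_param("pretrained_w", model.param_init_net,
#                              shape=[64, 32])
#     workspace.FeedBlob("pretrained_w", np.zeros((64, 32), dtype=np.float32))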


class PseudoFP16Initializer(Initializer):
    """
    Used in cases when the parameter should be used at half (16-bit)
    precision for compute purposes (i.e. on the forward and backward pass)
    but needs to be stored and optimized at single (32-bit) precision so
    tiny gradients with small learning rates don't underflow FP16 precision.
    A 32-bit copy of the 16-bit blob is stored in the ParameterInfo.
    This is helpful for mixed-precision training; see
    https://arxiv.org/abs/1710.03740 for details.
    """

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        # Create the master fp32 copy.
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name + "_fp32", shape=shape,
            **self.operator_kwargs)
        # Cast down to the fp16 blob used for compute.
        param = init_net.FloatToHalf(
            param_fp32, param_name)

        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
            blob_copy={DataType.FLOAT: param_fp32}
        )
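
# A minimal usage sketch (not part of the original module): for a parameter
# named "fc_w", the primary blob "fc_w" is fp16 and the fp32 master copy
# lives under "fc_w_fp32"; GaussianFill is a standard Caffe2 fill operator.
#
#     init = PseudoFP16Initializer("GaussianFill", std=0.01)
#     info = init.create_param("fc_w", model.param_init_net, shape=[64, 32])
#     # info.blob      -> fc_w (fp16, used for compute)
#     # info.blob_copy -> {DataType.FLOAT: fc_w_fp32} (fp32 master)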


class ReversedPseudoFP16Initializer(Initializer):
    """
    Like PseudoFP16Initializer above, except the primary blob is taken to
    be the 32-bit precision parameter, and the 16-bit version of the blob
    is stored in blob_copy instead.
    """

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        # Create the master fp32 copy; here it is the primary blob.
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape,
            **self.operator_kwargs)
        # Cast down to the secondary fp16 copy.
        param_fp16 = init_net.FloatToHalf(
            param_fp32, param_name + "_fp16")

        return ParameterInfo(
            param_id=None,
            param=param_fp32,
            shape=shape,
            blob_copy={DataType.FLOAT16: param_fp16}
        )
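
# A minimal usage sketch (not part of the original module): the mapping is
# reversed relative to PseudoFP16Initializer, so the primary blob stays fp32.
#
#     init = ReversedPseudoFP16Initializer("GaussianFill", std=0.01)
#     info = init.create_param("fc_w", model.param_init_net, shape=[64, 32])
#     # info.blob      -> fc_w (fp32, primary)
#     # info.blob_copy -> {DataType.FLOAT16: fc_w_fp16} (fp16 copy)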


def update_initializer(initializer_class,
                       operator_name_and_kwargs,
                       default_operator_name_and_kwargs):
    """
    A helper function to convert from operator_name_and_kwargs to a new
    object of type initializer_class. This function serves two purposes:

    1. Support custom initialization operators being passed in.
    2. Allow the user to specify a custom Initializer without overwriting
       the default operators used for initialization.

    If initializer_class is None, creates a default initializer using the
    Initializer class and the operator_name_and_kwargs provided.

    If operator_name_and_kwargs is None, uses
    default_operator_name_and_kwargs.

    Returns an instantiated Initializer object.
    """
    def get_initializer_args():
        return (
            operator_name_and_kwargs or
            default_operator_name_and_kwargs
        )

    if initializer_class is not None:
        init = initializer_class(get_initializer_args()[0],
                                 **get_initializer_args()[1])
    else:
        init = Initializer(
            get_initializer_args()[0],
            **get_initializer_args()[1]
        )
    return init
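
# A minimal usage sketch (not part of the original module); the operator
# names and kwargs below are illustrative.
#
#     # No custom class or operator: falls back to the default operator.
#     init = update_initializer(None, None, ("ConstantFill", {}))
#
#     # Operator override, default Initializer class.
#     init = update_initializer(None, ("GaussianFill", {"std": 0.01}),
#                               ("ConstantFill", {}))
#
#     # Custom class, default operator.
#     init = update_initializer(PseudoFP16Initializer, None,
#                               ("XavierFill", {}))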