Caffe2 - Python API
A deep learning, cross-platform ML framework
initializers.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
22 from caffe2.python.modeling.parameter_info import ParameterInfo
23 
24 import six
25 
26 
class Initializer(object):
    '''
    Abstraction over parameter creation. A new Initializer (for example a
    subclass) can be written to implement more complex parameter
    initialization logic.
    '''

    def __init__(self, operator_name=None, **kwargs):
        # Operator invoked on the init net to create the parameter, plus the
        # extra keyword arguments forwarded to that operator.
        self.operator_name, self.operator_kwargs = operator_name, kwargs

    def update(self, operator_name, kwargs):
        '''
        Set the operator name and its keyword arguments.

        Only permitted while no operator name has been set yet.

        Raises:
            Exception: if an operator name is already set.
        '''
        name_already_set = self.operator_name is not None
        if name_already_set:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name, self.operator_kwargs = operator_name, kwargs

    def create_param(self, param_name, init_net, shape):
        '''
        Create the parameter by running ``self.operator_name`` on init_net.

        Args:
            param_name: name (or reference) of the output blob.
            init_net: net on which the operator is invoked.
            shape: passed to the operator and recorded in the result.

        Returns:
            ParameterInfo with ``param_id=None`` wrapping the created blob.
        '''
        # __getattr__ is called explicitly rather than via getattr(); this
        # always hits the net's __getattr__ hook even if the operator name
        # collides with a regular attribute (presumably intentional).
        fill_op = init_net.__getattr__(self.operator_name)
        blob = fill_op([], param_name, shape=shape, **self.operator_kwargs)
        return ParameterInfo(param_id=None, param=blob, shape=shape)
51 
52 
class ExternalInitializer(object):
    '''
    Initializer for parameters that are not created by an initialization
    operator, but are instead expected to be provided in the workspace when
    param_init_net is executed.

    The current version performs no real sanity checks on the parameter.
    '''

    def create_param(self, param_name, init_net, shape):
        '''
        Wrap ``param_name`` in a blob reference bound to init_net.

        No operator is invoked on init_net.

        Args:
            param_name: a BlobReference or a string naming the blob.
            init_net: net the returned blob reference is bound to.
            shape: recorded in the returned ParameterInfo.

        Returns:
            ParameterInfo with ``param_id=None``.

        Raises:
            TypeError: if param_name is neither a BlobReference nor a string.
        '''
        if isinstance(param_name, BlobReference):
            # Re-create the reference from its exact string name.
            param = BlobReference(str(param_name), init_net)
        elif isinstance(param_name, six.string_types):
            param = ScopedBlobReference(param_name, init_net)
        else:
            # Raising a bare string is itself a TypeError ("exceptions must
            # derive from BaseException"), which hid this message; raise a
            # real exception of the same type instead.
            raise TypeError(
                "Unsupported type for param_name: {}".format(type(param_name))
            )
        # TODO(amalevich): Add operator that will check param in the workspace
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
75 
76 
78 
79  def update(self, operator_name, kwargs):
80  if self.operator_name is not None:
81  raise Exception("Operator name overwrites are not allowed")
82  self.operator_name = operator_name
83  self.operator_kwargs = kwargs
84 
85  def create_param(self, param_name, init_net, shape):
86  # create master fp32 copy
87  param_fp32 = init_net.__getattr__(self.operator_name)(
88  [], param_name + "_fp32", shape=shape,
89  **self.operator_kwargs)
90  # cast to fp16 copy
91  param = init_net.FloatToHalf(
92  param_fp32, param_name)
93 
94  return ParameterInfo(
95  param_id=None,
96  param=param,
97  shape=shape,
98  blob_copy={DataType.FLOAT: param_fp32}
99  )
100 
101 
103 
104  def update(self, operator_name, kwargs):
105  if self.operator_name is not None:
106  raise Exception("Operator name overwrites are not allowed")
107  self.operator_name = operator_name
108  self.operator_kwargs = kwargs
109 
110  def create_param(self, param_name, init_net, shape):
111  # create master fp32 copy
112  param_fp32 = init_net.__getattr__(self.operator_name)(
113  [], param_name, shape=shape,
114  **self.operator_kwargs)
115  # cast to fp16 copy
116  param_fp16 = init_net.FloatToHalf(
117  param_fp32, param_name + "_fp16")
118 
119  return ParameterInfo(
120  param_id=None,
121  param=param_fp32,
122  shape=shape,
123  blob_copy={DataType.FLOAT16: param_fp16}
124  )
125 
def update_initializer(initializer_class,
                       operator_name_and_kwargs,
                       default_operator_name_and_kwargs):
    '''
    Build an initializer object from an (operator_name, kwargs) pair.

    This helper serves two purposes:

    1. Support for custom initialization operators being passed in.
    2. Allow the user to specify a custom Initializer class without
       overwriting the default operator used for initialization.

    Args:
        initializer_class: class to instantiate. If None, the Initializer
            class is used.
        operator_name_and_kwargs: indexable pair ``(operator_name, kwargs)``
            where kwargs is a mapping. If falsy (e.g. None), the default pair
            is used instead.
        default_operator_name_and_kwargs: fallback ``(operator_name, kwargs)``
            pair.

    Returns:
        ``initializer_class(operator_name, **kwargs)``, or
        ``Initializer(operator_name, **kwargs)`` when initializer_class is
        None.

    Raises:
        TypeError: if neither pair is provided.
    '''
    # Resolve the pair once instead of re-evaluating the fallback on every
    # element access.
    name_and_kwargs = (
        operator_name_and_kwargs or default_operator_name_and_kwargs
    )
    if name_and_kwargs is None:
        # Previously surfaced as "'NoneType' object is not subscriptable";
        # keep the same exception type but say what is actually missing.
        raise TypeError(
            "update_initializer requires operator_name_and_kwargs or "
            "default_operator_name_and_kwargs to be provided"
        )
    init_cls = (
        initializer_class if initializer_class is not None else Initializer
    )
    return init_cls(name_and_kwargs[0], **name_and_kwargs[1])
158  return init