Caffe2 - Python API
A deep learning, cross-platform ML framework
constant_weight.py
1 # @package constant_weight
2 # Module caffe2.fb.python.layers.constant_weight
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import ModelLayer
10 import numpy as np
11 
12 
class ConstantWeight(ModelLayer):
    """Combine the field blobs of a record into one sum with fixed weights.

    Each field blob of ``input_record`` is paired with a constant scalar
    weight blob, and a single ``WeightedSum`` op produces a float32 scalar
    output blob.

    Args:
        model: The model the layer belongs to; each weight is registered on
            it via ``add_global_constant``.
        input_record: Schema record whose ``field_blobs()`` are combined.
        weights: Optional sequence with exactly one weight per field blob.
            When omitted, every blob gets weight ``1 / num_fields``.
        name: Layer name; also used as the prefix of the weight constant
            names (``'<name>_weight_<i>'``).
        **kwargs: Forwarded unchanged to ``ModelLayer.__init__``.

    Raises:
        AssertionError: if ``weights`` does not contain exactly one entry per
            field blob (note: like any ``assert``, skipped under ``python -O``).
    """

    def __init__(
        self,
        model,
        input_record,
        weights=None,
        name='constant_weight',
        **kwargs
    ):
        super(ConstantWeight,
              self).__init__(model, name, input_record, **kwargs)
        # A single float32 scalar output blob holds the weighted sum.
        # NOTE(review): this statement's opening line was lost in the source
        # extraction; reconstructed from its surviving continuation lines.
        self.output_schema = schema.Scalar(
            np.float32, self.get_next_blob_reference('constant_weight')
        )
        self.data = self.input_record.field_blobs()
        self.num = len(self.data)
        if weights is None:
            # Equal weights 1/n. Assuming WeightedSum computes
            # sum_i(w_i * X_i), this yields the element-wise mean.
            # A comprehension (not `[1. / n] * n`) avoids evaluating 1/0
            # when the record has no fields.
            weights = [1. / self.num for _ in range(self.num)]
        assert len(weights) == self.num, (
            'ConstantWeight expects one weight per input field: '
            'got %d weights for %d fields' % (len(weights), self.num)
        )
        # Register each weight as a named global constant blob so it can be
        # passed to WeightedSum next to its data blob.
        self.weights = [
            self.model.add_global_constant(
                '%s_weight_%d' % (self.name, i), float(weights[i])
            ) for i in range(self.num)
        ]

    def add_ops(self, net):
        """Append the single WeightedSum op that computes the output to *net*.

        The op inputs are interleaved as ``[X_0, w_0, X_1, w_1, ...]``,
        pairing each data blob with its constant weight blob.
        """
        interleaved = [
            blob
            for data_and_weight in zip(self.data, self.weights)
            for blob in data_and_weight
        ]
        # Calling the schema presumably returns its field blobs (here the
        # single output blob) — confirm against schema.Scalar.__call__.
        net.WeightedSum(interleaved, self.output_schema())
def get_next_blob_reference(self, name)
Definition: layers.py:349