Caffe2 - Python API
A deep learning, cross platform ML framework
position_weighted.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package position_weighted
17 # Module caffe2.python.layers.position_weighted
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import logging
24 import numpy as np
25 
26 from caffe2.python import schema
27 from caffe2.python.layers.layers import (
28  get_categorical_limit,
29  ModelLayer,
30 )
31 
32 from caffe2.python.layers.tags import Tags
33 
34 logger = logging.getLogger(__name__)
35 
36 
class PositionWeighted(ModelLayer):
    """Layer that learns a per-position weight for sparse list features.

    Given a `schema.List` input record, it creates a 1-D parameter blob
    ``pos_w`` of length ``max_length`` (taken from the lengths metadata's
    ``categorical_limit`` when available, otherwise from the keys' limit)
    and, in `add_ops`, gathers one weight per element position within each
    list so that downstream pooling can weight values by position.
    """

    def __init__(self, model, input_record, weight_optim=None,
                 name="position_weights"):
        """Create the position-weight parameter and declare the output schema.

        Args:
            model: the model helper this layer is added to.
            input_record: a `schema.List` describing the sparse feature.
            weight_optim: optional optimizer for the ``pos_w`` parameter.
            name: layer name (defaults to "position_weights").
        """
        super(PositionWeighted, self).__init__(model, name, input_record)

        assert isinstance(input_record, schema.List), "Incorrect input type"
        # Prefer the categorical_limit recorded on the lengths blob; it bounds
        # the longest list and hence the number of distinct positions.
        length_metadata = input_record.lengths.metadata
        max_length = (length_metadata.categorical_limit
                      if length_metadata is not None else None)
        if max_length is not None:
            self.shape = max_length
        else:
            # Fall back to the keys' categorical limit and warn, since it may
            # over-estimate the number of positions actually needed.
            self.shape = get_categorical_limit(input_record)
            logger.warning(
                '{}: categorical_limit of lengths is not available, using '
                'categorical_limit of the keys: {}'.format(
                    str(input_record.lengths()), self.shape))

        # One learnable weight per position, initialized to 1.0 (no-op
        # weighting) so training starts from plain unweighted pooling.
        self.pos_w = self.create_param(param_name='pos_w',
                                       shape=[self.shape, ],
                                       initializer=('ConstantFill', {'value': 1.0}),
                                       optimizer=weight_optim)

        # BUG FIX: the `self.output_schema = schema.Struct(` opener was
        # missing, leaving the Struct(...) expression dangling and the layer
        # without an output schema. Restored here.
        self.output_schema = schema.Struct(
            ('position_weights',
                schema.Scalar((np.float32, self.shape),
                              self.get_next_blob_reference("pos_w_gather")))
        )

        # Mark so the trainer treats gradients for pos_w as sparse and
        # fetches them from the parameter server.
        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})
        self.tags.update({Tags.GRADIENT_FROM_PS})

    def get_memory_usage(self):
        """Return the parameter size (one float per position)."""
        return self.shape

    def add_ops(self, net):
        """Emit ops producing per-element position weights.

        LengthsRangeFill turns lengths [2, 3] into positions [0, 1, 0, 1, 2];
        Gather then looks up the learned weight for each position.
        """
        inc_seq = net.LengthsRangeFill(
            [self.input_record.lengths()],
            self.input_record.lengths() + '_pos_w_seq'
        )

        net.Gather(
            [self.pos_w, inc_seq],
            self.output_schema.position_weights.field_blobs())