Caffe2 - Python API
A deep learning, cross-platform ML framework
position_weighted.py
1 ## @package position_weighted
2 # Module caffe2.python.layers.position_weighted
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import logging
9 import numpy as np
10 
11 from caffe2.python import schema
12 from caffe2.python.layers.layers import (
13  get_categorical_limit,
14  ModelLayer,
15 )
16 
17 from caffe2.python.layers.tags import Tags
18 
# Module-level logger. PositionWeighted writes a warning here when the
# lengths field of its input has no categorical_limit.
logger = logging.getLogger(name=__name__)
20 
21 
class PositionWeighted(ModelLayer):
    """Position-weighted layer over a ``schema.List`` input record.

    The layer owns a trainable parameter ``pos_w`` with ``self.shape``
    entries, filled with 1.0 by ``ConstantFill``. ``add_ops`` runs
    ``LengthsRangeFill`` on the input's lengths. It then ``Gather``s
    ``pos_w`` at the resulting indices into ``position_weights``, the single
    float32 output field.

    Per the Caffe2 operator docs, ``LengthsRangeFill`` maps lengths such as
    ``[2, 3]`` to ``[0, 1, 0, 1, 2]``. So each element receives the weight
    for its position within its own list.
    """

    def __init__(self, model, input_record, weight_optim=None,
                 name="position_weights"):
        """Create ``pos_w`` and declare the output schema.

        Args:
            model: model the layer belongs to; forwarded to ``ModelLayer``.
            input_record: a ``schema.List``; this layer reads its
                ``lengths`` field.
            weight_optim: optimizer for ``pos_w``, passed unchanged to
                ``create_param`` (``None`` by default).
            name: layer name, ``"position_weights"`` by default.

        Raises:
            AssertionError: if ``input_record`` is not a ``schema.List``.
        """
        super(PositionWeighted, self).__init__(model, name, input_record)

        # NOTE(review): ``assert`` matches the surrounding layers code, but
        # the check is skipped entirely under ``python -O``.
        assert isinstance(input_record, schema.List), "Incorrect input type"

        # Size ``pos_w`` from the categorical_limit recorded on the lengths
        # field when present. Otherwise fall back to get_categorical_limit()
        # on the whole record, and warn that the fallback was used.
        length_metadata = input_record.lengths.metadata
        max_length = (length_metadata.categorical_limit
                      if length_metadata is not None else None)
        if max_length is not None:
            self.shape = max_length
        else:
            self.shape = get_categorical_limit(input_record)
            # %-style args let logging defer formatting; the rendered text
            # is identical to the previous str.format() version.
            logger.warning(
                '%s: categorical_limit of lengths is not available, using '
                'categorical_limit of the keys: %s',
                str(input_record.lengths()), self.shape)

        self.pos_w = self.create_param(
            param_name='pos_w',
            shape=[self.shape],
            initializer=('ConstantFill', {'value': 1.0}),
            optimizer=weight_optim)

        # Single float32 output field; add_ops fills its blob via Gather.
        self.output_schema = schema.Struct(
            ('position_weights',
             schema.Scalar((np.float32, self.shape),
                           self.get_next_blob_reference("pos_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})

    def get_memory_usage(self):
        """Return ``self.shape``, the number of entries in ``pos_w``."""
        return self.shape

    def add_ops(self, net):
        """Add the ops that compute ``position_weights`` to ``net``."""
        lengths = self.input_record.lengths()
        # Position index of every element within its list (see class doc).
        inc_seq = net.LengthsRangeFill([lengths], lengths + '_pos_w_seq')
        # NOTE(review): a list longer than ``self.shape`` would produce an
        # index past the end of ``pos_w``; nothing in this layer guards it.
        net.Gather(
            [self.pos_w, inc_seq],
            self.output_schema.position_weights.field_blobs())