from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging

import numpy as np

from caffe2.python import schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    ModelLayer,
)
from caffe2.python.layers.tags import Tags
# Module-level logger, named after this module, used for the fallback
# warning emitted by PositionWeighted.__init__.
logger = logging.getLogger(__name__)
# NOTE(review): this class was recovered from a garbled excerpt in which the
# class header, the `else:` branch, the `logger.warning(` head, the
# `shape=` argument, the output_schema construction, the get_memory_usage
# body and the `net.Gather(` head were missing. Those pieces were restored
# to match upstream caffe2 `layers/position_weighted.py` -- confirm.
class PositionWeighted(ModelLayer):
    """Layer producing one learnable weight per position of a list input.

    A 1-D parameter ``pos_w`` (initialised to 1.0) is created with one entry
    per possible position. In ``add_ops`` it is gathered with an index
    sequence derived from the input's lengths, giving the
    ``position_weights`` output field.
    """

    def __init__(self, model, input_record, weight_optim=None,
                 name="position_weights"):
        """Build the layer.

        Args:
            model: the layer model this layer is registered with.
            input_record: a ``schema.List`` whose ``lengths`` blob drives
                the position indices.
            weight_optim: optimizer passed to ``create_param`` for
                ``pos_w``; may be None.
            name: layer name.

        Raises:
            AssertionError: if ``input_record`` is not a ``schema.List``.
        """
        super(PositionWeighted, self).__init__(model, name, input_record)

        assert isinstance(input_record, schema.List), "Incorrect input type"

        # Number of positions to allocate weights for: prefer the
        # categorical_limit recorded on the lengths' metadata ...
        length_metadata = input_record.lengths.metadata
        max_length = (length_metadata.categorical_limit
                      if length_metadata is not None else None)
        if max_length is not None:
            self.shape = max_length
        else:
            # ... otherwise fall back to the categorical limit derived
            # from the record itself, and warn because this is a weaker
            # bound for the number of positions.
            self.shape = get_categorical_limit(input_record)
            logger.warning(
                '{}: categorical_limit of lengths is not available, using '
                'categorical_limit of the keys: {}'.format(
                    str(input_record.lengths()), self.shape))

        # One weight per position, all starting at 1.0.
        self.pos_w = self.create_param(
            param_name='pos_w',
            shape=[self.shape, ],
            initializer=('ConstantFill', {'value': 1.0}),
            optimizer=weight_optim)

        self.output_schema = schema.Struct(
            ('position_weights',
             schema.Scalar((np.float32, self.shape),
                           self.get_next_blob_reference("pos_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})

    def get_memory_usage(self):
        """Return the size of the ``pos_w`` parameter (its number of entries)."""
        return self.shape

    def add_ops(self, net):
        """Append the ops that compute ``position_weights`` to ``net``."""
        # Build a position-index sequence from the lengths blob (per the
        # caffe2 LengthsRangeFill docs: 0..len-1 for each list, concatenated).
        inc_seq = net.LengthsRangeFill(
            [self.input_record.lengths()],
            self.input_record.lengths() + '_pos_w_seq'
        )

        # Look up the weight for every position index.
        net.Gather(
            [self.pos_w, inc_seq],
            self.output_schema.position_weights.field_blobs())