3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
13 get_categorical_limit,
19 logger = logging.getLogger(__name__)
23 def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
24 weight_optim=
None, name=
"bucket_weighted"):
25 super(BucketWeighted, self).__init__(model, name, input_record)
27 assert isinstance(input_record,
schema.List),
"Incorrect input type" 29 if bucket_boundaries
is not None:
30 self.
shape = len(bucket_boundaries) + 1
32 self.
shape = max_score
34 self.
shape = get_categorical_limit(input_record)
36 self.
bucket_w = self.create_param(param_name=
'bucket_w',
38 initializer=(
'ConstantFill', {
'value': 1.0}),
39 optimizer=weight_optim)
44 self.get_next_blob_reference(
"bucket_w_gather")))
47 self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})
49 def get_memory_usage(self):
52 def add_ops(self, net):
54 buckets = net.Bucketize(
55 self.input_record.values(),
60 buckets = self.input_record.values()
61 buckets_int = net.Cast(
64 to=core.DataType.INT32
68 self.output_schema.bucket_weights.field_blobs())