3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
23 def __init__(self, model, input_record, seed=0, modulo=None,
24 use_hashing=
True, name=
'sparse_feature_hash', **kwargs):
25 super(SparseFeatureHash, self).__init__(model, name, input_record, **kwargs)
29 if schema.equal_schemas(input_record, IdList):
32 categorical_limit=self.
modulo,
33 feature_specs=input_record.items.metadata.feature_specs,
34 expected_value=input_record.items.metadata.expected_value
36 with core.NameScope(name):
38 self.output_schema.items.set_metadata(metadata)
40 elif schema.equal_schemas(input_record, IdScoreList):
43 categorical_limit=self.
modulo,
44 feature_specs=input_record.keys.metadata.feature_specs,
45 expected_value=input_record.keys.metadata.expected_value
47 with core.NameScope(name):
49 self.output_schema.keys.set_metadata(metadata)
52 assert False,
"Input type must be one of (IdList, IdScoreList)" 54 assert self.
modulo >= 1,
'Unexpected modulo: {}'.format(self.
modulo)
55 if input_record.lengths.metadata:
56 self.output_schema.lengths.set_metadata(input_record.lengths.metadata)
61 self.tags.update([Tags.CPU_ONLY])
63 def extract_hash_size(self, metadata):
64 if metadata.feature_specs
and metadata.feature_specs.desired_hash_size:
65 return metadata.feature_specs.desired_hash_size
66 elif metadata.categorical_limit
is not None:
67 return metadata.categorical_limit
69 assert False,
"desired_hash_size or categorical_limit must be set" 71 def add_ops(self, net):
73 self.input_record.lengths(),
74 self.output_schema.lengths()
77 input_blob = self.input_record.items()
78 output_blob = self.output_schema.items()
80 input_blob = self.input_record.keys()
81 output_blob = self.output_schema.keys()
83 self.input_record.values(),
84 self.output_schema.values()
87 raise NotImplementedError()
91 input_blob, output_blob, seed=self.
seed, modulo=self.
modulo 95 input_blob, output_blob, divisor=self.
modulo, sign_follow_divisor=
True
def extract_hash_size(self, metadata)