from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import schema
from caffe2.python.helpers.arg_scope import get_current_scope
from caffe2.python.layers.layers import (
    get_categorical_limit, get_key, IdList, IdScoreList,
    LayerPsParam, ModelLayer,
)

import collections
import functools
import math
import numpy as np
import operator


def get_sparse_lookup_predictor_version(version):
    assert version in {'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'},\
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version


def get_sparse_lookup_trainer_version(version):
    assert version in {'fp32', 'fp16'},\
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version


def _is_id_list(input_record):
    return schema.equal_schemas(input_record, IdList)


def _is_id_score_list(input_record):
    return schema.equal_schemas(input_record,
                                IdScoreList,
                                check_field_types=False)


class SparseLookup(ModelLayer):
    _id_list_supported_reducers = [
        'LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'Sqrt', 'None']

    _id_score_list_supported_reducers = [
        'PositionWeighted', 'RecencyWeighted', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'None']

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', regularizer=None, **kwargs):

        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, list) or isinstance(inner_shape, tuple),\
            "Unexpected type for inner_shape, expected list or tuple, got {0}".\
            format(type(inner_shape))

        if reducer == "PositionWeighted":
            assert _is_id_score_list(self.input_record), (
                "PositionWeighted only supports IdScoreList, but got {}; "
                "please use the PositionWeighted layer to convert IdList "
                "to IdScoreList").format(repr(self.input_record))
            # The per-id scores act as externally provided weights.
            self.external_weights = input_record.values()
        elif reducer == "RecencyWeighted":
            assert _is_id_score_list(self.input_record), (
                "RecencyWeighted only supports IdScoreList.")
            self.external_weights = input_record.values()
        self.reducer = reducer

        input_dim = get_categorical_limit(input_record)
        assert input_dim > 0, (
            "{} should have categorical limit > 0, but got {}".format(
                get_key(input_record)(), input_dim))

        self.input_dim = input_dim
        self.shape = [input_dim] + inner_shape

        default_init_op = self._get_default_init_op()
        self.weight_init = weight_init or default_init_op

        if _is_id_list(self.input_record):
            sparse_key = self.input_record.items()
        elif _is_id_score_list(self.input_record):
            sparse_key = self.input_record.keys()
        else:
            raise NotImplementedError()

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None

        self.w = self.create_param(
            param_name='w',
            shape=self.shape,
            initializer=self.weight_init,
            optimizer=weight_optim,
            ps_param=LayerPsParam(
                sparse_key=sparse_key,
                average_length=avg_length),
            regularizer=regularizer,
        )

        # Row-wise scale/bias blob used by the (non-fused) 8-bit predictor
        # version; it is not trained.
        self.scale_bias_init = ('ConstantFill', {'value': 0.0})
        self.scale_bias = self.create_param(
            param_name='scale_bias',
            shape=[],
            initializer=self.scale_bias_init,
            optimizer=model.NoOptim,
        )

        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            self.get_next_blob_reference('output'),
        )
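
    # A minimal construction sketch (illustrative; `model` and `sparse_ids`
    # are assumed names, not defined in this file).  With a LayerModelHelper
    # whose input_feature_schema carries an IdList field `sparse_ids`, the
    # layer is normally created through the model helper rather than
    # instantiated directly:
    #
    #   embedded = model.SparseLookup(
    #       model.input_feature_schema.sparse_ids,  # IdList input record
    #       inner_shape=[64],                       # embedding width
    #       reducer='Sum',                          # pooled fp32 output [64]
    #   )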

    def get_memory_usage(self):
        return functools.reduce(operator.mul, self.shape) * 4
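
    # Worked example: with input_dim = 1,000,000 categories and
    # inner_shape = [64], self.shape is [1000000, 64], so the estimate is
    # 1,000,000 * 64 * 4 bytes = 256,000,000 bytes (~244 MiB).  The factor 4
    # assumes fp32 weights, even when an fp16 or 8-bit predictor version is
    # selected later.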

    def get_fp16_compatible_parameters(self):
        return [self.w]

    def support_8bit(self):
        # Row-wise 8-bit quantization needs a 2-D table with rows >= 8 wide.
        return len(self.shape) == 2 and self.shape[1] >= 8

    def get_8bits_compatible_parameters(self, fused=True):
        if not self.support_8bit():
            return []
        if fused:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w')
            return [RowwiseQuantized8BitsWeight(self.w)]
        else:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w, scale_bias')
            return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]

    def _get_default_init_op(self):
        scale = math.sqrt(1.0 / self.input_dim)

        cur_scope = get_current_scope()
        trainer_version = get_sparse_lookup_trainer_version(
            **cur_scope.get(get_sparse_lookup_trainer_version.__name__,
                            {'version': 'fp32'}))

        if trainer_version == 'fp32':
            default_weight_init = ('UniformFill', {'min': -scale, 'max': scale})
        elif trainer_version == 'fp16':
            default_weight_init = ("Float16UniformFill",
                                   {'min': -scale, 'max': scale})
        else:
            raise NotImplementedError(
                "Trainer version {} is not currently supported".format(
                    trainer_version))

        self.trainer_version = trainer_version
        return default_weight_init
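
    # For example, with input_dim = 10,000 the range is
    # scale = sqrt(1 / 10000) = 0.01, i.e. the table is initialized uniformly
    # in [-0.01, 0.01] (as fp32 or fp16 depending on the trainer version).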

    def _gather_wrapper(self, net, version, in_indices, out):
        # Gather the rows for in_indices and make sure the output is fp32,
        # since the follow-up operators expect float data.
        if version == 'fp32':
            return net.Gather([self.w, in_indices], out)
        elif version == 'fp16':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.HalfToFloat(gathered_w, out)
        elif version == 'uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            gathered_scale_bias = net.Gather(
                [self.scale_bias, in_indices],
                'gathered_scale_bias')
            return net.Rowwise8BitQuantizedToFloat(
                [gathered_w, gathered_scale_bias], out)
        elif version == 'fused_uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
        else:
            raise NotImplementedError(
                "Unsupported version of operators in SparseLookup "
                "layer: {0}".format(version))

    def _sparse_lengths_weighted_reducer(
            self, in_indices, weights, reducer,
            net, version, grad_on_weights=0):
        op_input = [
            self.w,
            weights,
            in_indices,
            self.input_record.lengths(),
        ]
        layer_name = 'SparseLengths' + reducer

        if version in ['fp32', 'fp16']:
            # SparseLengths* operators accept fp16 or fp32 embeddings and
            # produce fp32 pooled embeddings.  WeightedSum with fp16
            # embeddings needs the FP16 engine for correct gradient updates.
            if reducer == "WeightedSum" and version == "fp16":
                net.SparseLengthsWeightedSum(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                    engine='FP16',
                )
            else:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                )
        elif version == 'uint8rowwise':
            op_input.append(self.scale_bias)
            net.__getattr__(layer_name + '8BitsRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint8rowwise':
            net.__getattr__(layer_name + 'Fused8BitRowwise')(
                op_input, self.output_schema.field_blobs())
        else:
            raise NotImplementedError(
                "Unsupported version of operator in SparseLookup "
                "layer: {0}".format(version))

    def _add_ops_id_list(self, net, version):
        assert self.reducer in self._id_list_supported_reducers, (
            "Unsupported reducer: {} for ID_LIST".format(self.reducer))

        if self.reducer in ['Sum', 'Mean', 'WeightedSum', 'WeightedMean']:
            op_input = [
                self.w,
                self.input_record.items(),
                self.input_record.lengths(),
            ]

            # An IdList carries no per-id weights, so the weighted reducers
            # degenerate to their unweighted counterparts.
            if self.reducer == 'WeightedSum':
                self.reducer = 'Sum'
            elif self.reducer == 'WeightedMean':
                self.reducer = 'Mean'

            layer_name = 'SparseLengths' + self.reducer
            if version in ['fp32', 'fp16']:
                # SparseLengths* operators accept fp16 or fp32 embeddings and
                # produce fp32 pooled embeddings.
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                op_input.append(self.scale_bias)
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise NotImplementedError(
                    "Unsupported version of operator in SparseLookup "
                    "layer: {0}".format(version))

        elif self.reducer == 'Sqrt':
            sqrt_weight = net.LengthsToWeights(
                [self.input_record.lengths()],
                [net.NextScopedBlob('lengths_sqrt')],
                power=0.5,
            )
            self._sparse_lengths_weighted_reducer(
                self.input_record.items(),
                sqrt_weight,
                'WeightedSum', net, version)

        elif self.reducer == 'None':
            # No pooling: gather one embedding row per id.
            self._gather_wrapper(net, version, self.input_record.items(),
                                 self.output_schema.field_blobs())

        else:
            table_rows = self._gather_wrapper(
                net, version, self.input_record.items(), 'table_rows')
            segment_ids = net.LengthsToSegmentIds(
                self.input_record.lengths(),
                net.NextScopedBlob(self.input_record.lengths() + '_sid'))
            net.__getattr__('SortedSegmentRange' + self.reducer)(
                [table_rows, segment_ids],
                self.output_schema.field_blobs(),
            )

    def _add_ops_id_score_list(self, net, version):
        assert self.reducer in self._id_score_list_supported_reducers, (
            "Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer))

        if self.reducer in ['WeightedSum', 'WeightedMean']:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.input_record.values(),
                self.reducer, net, version)

        elif self.reducer in ['Sum', 'Mean']:
            op_input = [
                self.w,
                self.input_record.keys(),
                self.input_record.lengths(),
            ]

            layer_name = 'SparseLengths' + self.reducer

            if version in ['fp32', 'fp16']:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise NotImplementedError(
                    "Unsupported version of operator in SparseLookup "
                    "layer: {0}".format(version))

        elif self.reducer in ['PositionWeighted', 'RecencyWeighted']:
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.external_weights,
                'WeightedSum', net, version, grad_on_weights=1)

        elif self.reducer == 'None':
            # No pooling: gather one embedding row per key.
            self._gather_wrapper(net, version, self.input_record.keys(),
                                 self.output_schema.field_blobs())

        else:
            raise NotImplementedError(
                "Unsupported reducer {} for IdScoreList input; supported "
                "reducers are {}".format(
                    self.reducer, self._id_score_list_supported_reducers))

    def _add_ops(self, net, version='fp32'):
        if _is_id_list(self.input_record):
            self._add_ops_id_list(net, version=version)
        elif _is_id_score_list(self.input_record):
            self._add_ops_id_score_list(net, version=version)
        else:
            raise NotImplementedError(
                "Unsupported input type {0}".format(self.input_record))

    def add_train_ops(self, net):
        # Training uses the version chosen by _get_default_init_op().
        self._add_ops(net, self.trainer_version)

    def add_ops(self, net):
        cur_scope = get_current_scope()
        version = get_sparse_lookup_predictor_version(
            **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
                            {'version': 'fp32'}))

        # Fall back to fp32 if the table is too small for row-wise 8-bit
        # quantization.
        if not self.support_8bit() and version in {'uint8rowwise',
                                                    'fused_uint8rowwise'}:
            version = 'fp32'

        self._add_ops(net, version)