# Caffe2 - Python API
# A deep learning, cross platform ML framework
# sparse_lookup.py
1 ## @package sparse_lookup
2 # Module caffe2.python.layers.sparse_lookup
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python.helpers.arg_scope import get_current_scope
9 from caffe2.python import schema
10 from caffe2.python.layers.layers import (
11  get_categorical_limit,
12  get_key,
13  IdList,
14  IdScoreList,
15  LayerPsParam,
16  ModelLayer,
17 )
18 import collections
19 import functools
20 import math
21 import numpy as np
22 import operator
23 
24 
def get_sparse_lookup_predictor_version(version):
    """Validate *version* for the predictor-side sparse lookup and return it.

    Accepted values: 'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'.
    Raises AssertionError for anything else.
    """
    supported = {'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'}
    assert version in supported, \
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version
29 
30 
def get_sparse_lookup_trainer_version(version):
    """Validate *version* for the trainer-side sparse lookup and return it.

    Only 'fp32' and 'fp16' embeddings can be trained; raises AssertionError
    for any other value.
    """
    supported = {'fp32', 'fp16'}
    assert version in supported, \
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version
35 
36 
def _is_id_list(input_record):
    """Return True iff *input_record* has the IdList schema."""
    return schema.equal_schemas(input_record, IdList)
39 
40 
def _is_id_score_list(input_record):
    """Return True iff *input_record* has the IdScoreList schema.

    Field types are not compared, so e.g. float vs double scores both match.
    """
    return schema.equal_schemas(
        input_record, IdScoreList, check_field_types=False)
45 
46 
class SparseLookup(ModelLayer):
    """Embedding-lookup layer for sparse categorical features.

    Owns a trainable lookup table ``w`` of shape
    ``[categorical_limit] + inner_shape`` and emits, at net-construction
    time, the Caffe2 ops that gather the embedding rows for an IdList or
    IdScoreList input record and pool them per example with ``reducer``.
    The table may be materialized as fp32, fp16, or rowwise 8-bit
    quantized (plain or fused) depending on the trainer/predictor version
    picked up from the current arg scope.
    """

    # Reducers usable with IdList inputs.
    _id_list_supported_reducers = [
        'LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'Sqrt', 'None']

    # Reducers usable with IdScoreList inputs.
    _id_score_list_supported_reducers = [
        'PositionWeighted', 'RecencyWeighted', 'Mean', 'Sum', 'WeightedSum',
        'WeightedMean', 'None'
    ]

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', regularizer=None, **kwargs):
        """
        Args:
            model: the LayerModelHelper this layer belongs to.
            input_record: sparse input schema (IdList or IdScoreList).
            inner_shape: int or list/tuple - embedding dimension(s).
            reducer: pooling reducer; see *_supported_reducers above.
            weight_init: optional (op_name, kwargs) initializer for ``w``;
                defaults to a uniform init scaled by 1/sqrt(input_dim).
            weight_optim: optimizer for ``w``.
            name: layer name.
            regularizer: optional regularizer applied to ``w``.
        """
        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        # TODO Add some asserts about input type
        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, (list, tuple)), \
            "Unexpected type for inner_shape, expected list or tuple, got {0}".\
            format(type(inner_shape))

        if reducer == "PositionWeighted":
            assert _is_id_score_list(self.input_record), (
                "PositionWeighted only support IdScoreList, but got {} " +
                "please use PositionWeighted layer to convert IdList " +
                "to IdScoreList").format(repr(self.input_record))
            # Per-id weights come from the scores of the input record.
            self.external_weights = input_record.values()

        elif reducer == "RecencyWeighted":
            assert _is_id_score_list(self.input_record), (
                "RecencyWeighted only supports IdScoreList.")
            self.external_weights = input_record.values()
        self.reducer = reducer

        input_dim = get_categorical_limit(input_record)
        assert input_dim > 0, (
            "{} should have categorical limit > 0, but got {}".format(
                get_key(input_record)(), input_dim))

        self.input_dim = input_dim
        # list(...) so a tuple inner_shape (allowed by the assert above)
        # concatenates instead of raising TypeError.
        self.shape = [input_dim] + list(inner_shape)

        default_init_op = self._get_default_init_op()

        self.weight_init = weight_init or default_init_op

        # The sparse ids double as the parameter-server sharding key.
        if _is_id_list(self.input_record):
            sparse_key = self.input_record.items()
        elif _is_id_score_list(self.input_record):
            sparse_key = self.input_record.keys()
        else:
            raise NotImplementedError()

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None

        self.w = self.create_param(
            param_name='w',
            shape=self.shape,
            initializer=self.weight_init,
            optimizer=weight_optim,
            ps_param=LayerPsParam(
                sparse_key=sparse_key,
                average_length=avg_length),
            regularizer=regularizer
        )

        # Per-row (scale, bias) pair, used only by the non-fused
        # 8-bit rowwise quantized version.
        self.scale_bias_init = ('ConstantFill', {'value': 0.0})

        self.scale_bias = self.create_param(
            param_name='scale_bias',
            shape=[],
            initializer=self.scale_bias_init,
            optimizer=model.NoOptim,
        )

        # Pooled output is always fp32, one embedding per example.
        # (This assignment was truncated in the scraped source; restored.)
        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            self.get_next_blob_reference('output'),
        )

    def get_memory_usage(self):
        """Approximate size of the lookup table in bytes (fp32 elements)."""
        return functools.reduce(operator.mul, self.shape) * 4

    def get_fp16_compatible_parameters(self):
        """Parameters that may be stored in half precision."""
        return [self.w]

    def support_8bit(self):
        """Whether rowwise 8-bit quantization is applicable to ``w``."""
        # Rowwise quantization makes sense only if shape it's 2D matrix with
        # second dimension >= 8
        if len(self.shape) != 2 or self.shape[1] < 8:
            return False
        return True

    def get_8bits_compatible_parameters(self, fused=True):
        """Parameters eligible for rowwise 8-bit quantization.

        With ``fused=True`` the scale/bias live inside the fused blob, so
        only ``w`` is reported; otherwise ``scale_bias`` is included.
        Returns [] when quantization does not apply (see support_8bit).
        """
        if not self.support_8bit():
            return []
        if fused:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w'
            )
            return [RowwiseQuantized8BitsWeight(self.w)]
        else:
            RowwiseQuantized8BitsWeight = collections.namedtuple(
                'RowwiseQuantized8BitsWeight', 'w, scale_bias'
            )
            return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]

    def _get_default_init_op(self):
        """Choose the default weight initializer for the trainer version
        found in the current arg scope, recording it in
        ``self.trainer_version`` as a side effect."""
        scale = math.sqrt(1.0 / self.input_dim)

        cur_scope = get_current_scope()
        trainer_version = get_sparse_lookup_trainer_version(
            **cur_scope.get(get_sparse_lookup_trainer_version.__name__,
                            {'version': 'fp32'}))

        if trainer_version == 'fp32':
            default_weight_init = ('UniformFill', {'min': -scale, 'max': scale})
        elif trainer_version == 'fp16':
            default_weight_init = ("Float16UniformFill", {'min': -scale, 'max': scale})
        else:
            raise NotImplementedError(
                "Train version {} is not currently supported".format(trainer_version)
            )

        self.trainer_version = trainer_version

        return default_weight_init

    def _gather_wrapper(self, net, version, in_indices, out):
        """Gather embedding rows for *in_indices* into *out* as fp32,
        dequantizing/upcasting as needed for the given *version*."""
        # Gather can work on all kinds of input data types, and output
        # data with the same type. Convert the output of Gather to float,
        # because the follow-up Ops expect fp32.
        if version == 'fp32':
            return net.Gather([self.w, in_indices], out)
        elif version == 'fp16':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')

            return net.HalfToFloat(gathered_w, out)
        elif version == 'uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            gathered_scale_bias = net.Gather(
                [self.scale_bias, in_indices],
                'gathered_scale_bias'
            )

            return net.Rowwise8BitQuantizedToFloat(
                [gathered_w, gathered_scale_bias], out)
        elif version == 'fused_uint8rowwise':
            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
            return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
        else:
            # Raising a bare string is a TypeError in Python 3; raise a
            # real exception type instead.
            raise NotImplementedError(
                "Unsupported version of operators in SparseLookup "
                "layer: {0}".format(version))

    def _sparse_lengths_weighted_reducer(
            self, in_indices, weights, reducer,
            net, version, grad_on_weights=0):
        """Emit a SparseLengths<reducer> op pooling weighted embedding rows.

        Args:
            in_indices: blob of ids into the lookup table.
            weights: per-id weight blob.
            reducer: 'WeightedSum' or 'WeightedMean'.
            net: net to add the op to.
            version: embedding storage version (fp32/fp16/8-bit variants).
            grad_on_weights: whether gradients also flow to *weights*.
        """
        op_input = [
            self.w,
            weights,
            in_indices,
            self.input_record.lengths()
        ]
        layer_name = 'SparseLengths' + reducer

        if version in ['fp32', 'fp16']:
            # SparseLengths* Ops will accept either fp16 or fp32 embedding
            # matrix and output fp32 pooled embedding
            # A special case here is that we need FP16 engine for
            # SparseLengthsWeightedSum when FP16 embeedings are used for
            # correct backward updates
            if reducer == "WeightedSum" and version == "fp16":
                net.SparseLengthsWeightedSum(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                    engine='FP16',
                )
            else:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                    grad_on_weights=grad_on_weights,
                )
        elif version == 'uint8rowwise':
            op_input.insert(len(op_input), self.scale_bias)
            net.__getattr__(layer_name + '8BitsRowwise')(
                op_input, self.output_schema.field_blobs())
        elif version == 'fused_uint8rowwise':
            net.__getattr__(layer_name + 'Fused8BitRowwise')(
                op_input, self.output_schema.field_blobs())
        else:
            raise NotImplementedError(
                "Unsupported version of operator in SparseLookUp "
                "layer: {0}".format(version))

    # deal with sparse features of id_list type
    def _add_ops_id_list(self, net, version):
        assert self.reducer in self._id_list_supported_reducers, (
            "Unsupported reducer: {} for ID_LIST".format(self.reducer)
        )
        if self.reducer in ['Sum', 'Mean', 'WeightedSum', 'WeightedMean']:
            op_input = [self.w,
                        self.input_record.items(),
                        self.input_record.lengths()]

            # For id list features, the behaviors of 'Sum' and
            # 'WeightedSum' are identical, since we can regard the weight on each
            # id as 1. Similarly, for 'Mean' and 'WeightedMean'.
            if self.reducer == 'WeightedSum':
                self.reducer = 'Sum'
            elif self.reducer == 'WeightedMean':
                self.reducer = 'Mean'

            layer_name = 'SparseLengths' + self.reducer
            if version in ['fp32', 'fp16']:
                # SparseLengths* Ops will accept either fp16 or fp32 embedding
                # matrix and output fp32 pooled embedding
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                op_input.insert(len(op_input), self.scale_bias)
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise NotImplementedError(
                    "Unsupported version of operator in SparseLookUp "
                    "layer: {0}".format(version))

        elif self.reducer == 'Sqrt':
            # 'Sqrt' pools with per-example weight lengths**-0.5, i.e. a
            # WeightedSum with sqrt-normalized weights.
            sqrt_weight = net.LengthsToWeights(
                [self.input_record.lengths()],
                [net.NextScopedBlob('lengths_sqrt')],
                power=0.5,
            )
            # (This call line was truncated in the scraped source; restored.)
            self._sparse_lengths_weighted_reducer(
                self.input_record.items(),
                sqrt_weight,
                'WeightedSum', net, version)

        elif self.reducer == 'None':
            # Gather operator will gather the embedding for each id of
            # each IdList.
            self._gather_wrapper(net, version, self.input_record.items(),
                                 self.output_schema.field_blobs())

        else:
            table_rows = self._gather_wrapper(
                net, version, self.input_record.items(), 'table_rows')

            segment_ids = net.LengthsToSegmentIds(
                self.input_record.lengths(),
                net.NextScopedBlob(self.input_record.lengths() + '_sid'))
            net.__getattr__('SortedSegmentRange' + self.reducer)(
                [table_rows, segment_ids],
                self.output_schema.field_blobs(),
            )

    # deal with sparse features of id_score_list type
    def _add_ops_id_score_list(self, net, version):
        assert self.reducer in self._id_score_list_supported_reducers, (
            "Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer)
        )
        if self.reducer in ['WeightedSum', 'WeightedMean']:
            # (This call line was truncated in the scraped source; restored.)
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.input_record.values(),
                self.reducer, net, version)

        elif self.reducer in ['Sum', 'Mean']:
            op_input = [self.w,
                        self.input_record.keys(),
                        self.input_record.lengths()]

            layer_name = 'SparseLengths' + self.reducer

            if version in ['fp32', 'fp16']:
                net.__getattr__(layer_name)(
                    op_input,
                    self.output_schema.field_blobs(),
                )
            elif version == 'uint8rowwise':
                net.__getattr__(layer_name + '8BitsRowwise')(
                    op_input, self.output_schema.field_blobs())
            elif version == 'fused_uint8rowwise':
                net.__getattr__(layer_name + 'Fused8BitRowwise')(
                    op_input, self.output_schema.field_blobs())
            else:
                raise NotImplementedError(
                    "Unsupported version of operator in SparseLookUp "
                    "layer: {0}".format(version))

        elif self.reducer in ['PositionWeighted', 'RecencyWeighted']:
            # (This call line was truncated in the scraped source; restored.)
            self._sparse_lengths_weighted_reducer(
                self.input_record.keys(),
                self.external_weights,
                'WeightedSum', net, version, grad_on_weights=1)

        elif self.reducer == 'None':
            # Gather operator will gather the embedding for each id of
            # each IdList.
            self._gather_wrapper(net, version, self.input_record.keys(),
                                 self.output_schema.field_blobs())
        else:
            raise NotImplementedError(
                "Only Sum, Mean, None are supported for IdScoreList input. "
                "Trying to create with {}".format(self.reducer))

    def _add_ops(self, net, version='fp32'):
        """Dispatch op emission based on the input record's schema."""
        if _is_id_list(self.input_record):
            self._add_ops_id_list(net, version=version)
        elif _is_id_score_list(self.input_record):
            self._add_ops_id_score_list(net, version=version)
        else:
            raise NotImplementedError(
                "Unsupported input type {0}".format(self.input_record))

    def add_train_ops(self, net):
        """Emit training-time ops, using the trainer storage version."""
        self._add_ops(net, self.trainer_version)

    def add_ops(self, net):
        """Emit prediction-time ops, using the predictor storage version
        from the current arg scope (falling back to fp32 when quantization
        is not applicable to this table)."""
        cur_scope = get_current_scope()
        version = get_sparse_lookup_predictor_version(
            **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
                            {'version': 'fp32'}))

        # TODO(amalevich): Layer should not be responsible for decision about
        # quantization.
        if not self.support_8bit() and version in {'uint8rowwise',
                                                   'fused_uint8rowwise'}:
            version = 'fp32'

        self._add_ops(net, version)
# NOTE: the three signature lines below were documentation-extraction
# artifacts (duplicates of methods defined above); commented out so the
# module parses.
# def _sparse_lengths_weighted_reducer(self, in_indices, weights, reducer, net, version, grad_on_weights=0)
# def _add_ops(self, net, version='fp32')
# def _gather_wrapper(self, net, version, in_indices, out)