# Caffe2 - Python API
# A deep learning, cross platform ML framework
# sparse_lookup.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package sparse_lookup
17 # Module caffe2.python.layers.sparse_lookup
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python.helpers.arg_scope import get_current_scope
24 from caffe2.python import schema
25 from caffe2.python.layers.layers import (
26  get_categorical_limit,
27  get_key,
28  IdList,
29  IdScoreList,
30  LayerPsParam,
31  ModelLayer,
32 )
33 import collections
34 import functools
35 import math
36 import numpy as np
37 import operator
38 
39 
def get_sparse_lookup_predictor_version(version):
    """Validate a sparse_lookup predictor version string and return it.

    Accepted values are 'fp32', 'fp16', 'uint8rowwise' and
    'fused_uint8rowwise'; anything else trips the assertion.
    """
    supported_versions = {'fp32', 'fp16', 'uint8rowwise', 'fused_uint8rowwise'}
    assert version in supported_versions, \
        "Unexpected version of sparse_lookup layer {0}".format(version)
    return version
44 
45 
def _is_id_list(input_record):
    """Return True when *input_record* matches the IdList schema."""
    matches_id_list = schema.equal_schemas(input_record, IdList)
    return matches_id_list
48 
49 
def _is_id_score_list(input_record):
    """Return True when *input_record* matches the IdScoreList schema.

    Field types are deliberately not compared, only the structure.
    """
    matches_id_score_list = schema.equal_schemas(
        input_record, IdScoreList, check_field_types=False
    )
    return matches_id_score_list
54 
55 
class SparseLookup(ModelLayer):
    """Embedding-lookup layer for sparse (IdList / IdScoreList) features.

    Looks up rows of a learned embedding table for each sparse id and pools
    them with the configured reducer.
    """
    # Reducers accepted when the input record is an IdList.
    _id_list_supported_reducers = [
        'LogMeanExp', 'LogSumExp', 'Max', 'Mean', 'Sum',
        'WeightedSum', 'WeightedMean', 'Sqrt', 'None']

    # Reducers accepted when the input record is an IdScoreList.
    _id_score_list_supported_reducers = [
        'PositionWeighted', 'Mean', 'Sum', 'WeightedSum', 'WeightedMean', 'None']
63 
64  def __init__(self, model, input_record, inner_shape, reducer,
65  weight_init=None, weight_optim=None,
66  name='sparse_lookup', regularizer=None, **kwargs):
67 
68  super(SparseLookup, self).__init__(model, name, input_record, **kwargs)
69 
70  # TODO Add some asserts about input type
71  if isinstance(inner_shape, int):
72  inner_shape = [inner_shape]
73  assert isinstance(inner_shape, list) or isinstance(inner_shape, tuple),\
74  "Unexpected type for inner_shape, expected list or tuple, got {0}".\
75  format(type(inner_shape))
76 
77  if reducer == "PositionWeighted":
78  assert _is_id_score_list(self.input_record), (
79  "PositionWeighted only support IdScoreList, but got {} " +
80  "please use PositionWeighted layer to convert IdList " +
81  "to IdScoreList").format(repr(self.input_record))
82  self.external_weights = input_record.values()
83  self.reducer = reducer
84 
85  input_dim = get_categorical_limit(input_record)
86  assert input_dim > 0, (
87  "{} should have categorical limit > 0, but got {}".format(
88  get_key(input_record)(), input_dim))
89 
90  scale = math.sqrt(1.0 / input_dim)
91  self.shape = [input_dim] + inner_shape
92  self.weight_init = weight_init if weight_init else (
93  'UniformFill', {'min': -scale, 'max': scale})
94 
95  if _is_id_list(self.input_record):
96  sparse_key = self.input_record.items()
97  elif _is_id_score_list(self.input_record):
98  sparse_key = self.input_record.keys()
99  else:
100  raise NotImplementedError()
101 
102  if self.input_record.lengths.metadata:
103  avg_length = self.input_record.lengths.metadata.expected_value
104  else:
105  avg_length = None
106 
107  self.w = self.create_param(
108  param_name='w',
109  shape=self.shape,
110  initializer=self.weight_init,
111  optimizer=weight_optim,
112  ps_param=LayerPsParam(
113  sparse_key=sparse_key,
114  average_length=avg_length),
115  regularizer=regularizer
116  )
117 
118  self.scale_bias_init = ('ConstantFill', {'value': 0.0})
119 
120  self.scale_bias = self.create_param(
121  param_name='scale_bias',
122  shape=[],
123  initializer=self.scale_bias_init,
124  optimizer=model.NoOptim,
125  )
126 
128  (np.float32, inner_shape),
129  self.get_next_blob_reference('output'),
130  )
131 
132  def get_memory_usage(self):
133  return functools.reduce(operator.mul, self.shape) * 4
134 
135  def get_fp16_compatible_parameters(self):
136  return [self.w]
137 
138  def get_8bits_compatible_parameters(self, fused=True):
139  # Rowwise quantization makes sense only if shape it's 2D matrix with
140  # second dimension >= 8
141  if len(self.shape) != 2 or self.shape[1] < 8:
142  return []
143  if fused:
144  RowwiseQuantized8BitsWeight = collections.namedtuple(
145  'RowwiseQuantized8BitsWeight', 'w'
146  )
147  return [RowwiseQuantized8BitsWeight(self.w)]
148  else:
149  RowwiseQuantized8BitsWeight = collections.namedtuple(
150  'RowwiseQuantized8BitsWeight', 'w, scale_bias'
151  )
152  return [RowwiseQuantized8BitsWeight(self.w, self.scale_bias)]
153 
154  def _gather_wrapper(self, net, version, in_indices, out):
155  # Gather can work on all kinds of input data types, and output
156  # data with the same type. Convert the output of Gather to float,
157  # because the follow-up Ops expect fp32.
158  if version == 'fp32':
159  return net.Gather([self.w, in_indices], out)
160  elif version == 'fp16':
161  gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
162 
163  return net.HalfToFloat(gathered_w, out)
164  elif version == 'uint8rowwise':
165  gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
166  gathered_scale_bias = net.Gather(
167  [self.scale_bias, in_indices],
168  'gathered_scale_bias'
169  )
170 
171  return net.Rowwise8BitQuantizedToFloat(
172  [gathered_w, gathered_scale_bias], out)
173  elif version == 'fused_uint8rowwise':
174  gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
175  return net.Fused8BitRowwiseQuantizedToFloat(gathered_w, out)
176  else:
177  raise "Unsupported version of operators in SparseLookup " +\
178  "layer: {0}".format(version)
179 
180  def _sparse_lengths_weighted_reducer(
181  self, in_indices, weights, reducer,
182  net, version, grad_on_weights=0):
183  op_input = [
184  self.w,
185  weights,
186  in_indices,
187  self.input_record.lengths()
188  ]
189  layer_name = 'SparseLengths' + reducer
190 
191  if version in ['fp32', 'fp16']:
192  # SparseLengths* Ops with engine='fp16' will accept either
193  # fp16 or fp32 embedding matrix and output fp32 pooled embedding
194  net.__getattr__(layer_name)(
195  op_input,
196  self.output_schema.field_blobs(),
197  grad_on_weights=grad_on_weights,
198  engine='fp16',
199  )
200  elif version == 'uint8rowwise':
201  op_input.insert(len(op_input), self.scale_bias)
202  net.__getattr__(layer_name + '8BitsRowwise')(
203  op_input, self.output_schema.field_blobs())
204  elif version == 'fused_uint8rowwise':
205  net.__getattr__(layer_name + 'Fused8BitRowwise')(
206  op_input, self.output_schema.field_blobs())
207  else:
208  raise "Unsupported version of operator in SparseLookUp " +\
209  "layer: {0}".format(version)
210 
211  # deal with sparse features of id_list type
212  def _add_ops_id_list(self, net, version):
213  assert self.reducer in self._id_list_supported_reducers, (
214  "Unsupported reducer: {} for ID_LIST".format(self.reducer)
215  )
216  if self.reducer in ['Sum', 'Mean', 'WeightedSum', 'WeightedMean']:
217  op_input = [self.w,
218  self.input_record.items(),
219  self.input_record.lengths()]
220 
221  # For id list features, the behaviors of 'Sum' and
222  # 'WeightedSum' are identical, since we can regard the weight on each
223  # id as 1. Similarly, for 'Mean' and 'WeightedMean'.
224  if self.reducer == 'WeightedSum':
225  self.reducer = 'Sum'
226  elif self.reducer == 'WeightedMean':
227  self.reducer = 'Mean'
228 
229  layer_name = 'SparseLengths' + self.reducer
230  if version in ['fp32', 'fp16']:
231  # SparseLengths* Ops with engine='fp16' will accept either
232  # fp16 or fp32 embedding matrix and output fp32 pooled embedding
233  net.__getattr__(layer_name)(
234  op_input,
235  self.output_schema.field_blobs(),
236  engine='fp16',
237  )
238  elif version == 'uint8rowwise':
239  op_input.insert(len(op_input), self.scale_bias)
240  net.__getattr__(layer_name + '8BitsRowwise')(
241  op_input, self.output_schema.field_blobs())
242  elif version == 'fused_uint8rowwise':
243  net.__getattr__(layer_name + 'Fused8BitRowwise')(
244  op_input, self.output_schema.field_blobs())
245  else:
246  raise "Unsupported version of operator in SparseLookUp " +\
247  "layer: {0}".format(version)
248 
249  elif self.reducer == 'Sqrt':
250  sqrt_weight = net.LengthsToWeights(
251  [self.input_record.lengths()],
252  [net.NextScopedBlob('lengths_sqrt')],
253  power=0.5,
254  )
256  self.input_record.items(),
257  sqrt_weight,
258  'WeightedSum', net, version)
259 
260  elif self.reducer == 'None':
261  # Gather operator will gather the embedding for each id of
262  # each IdList.
263  self._gather_wrapper(net, version, self.input_record.items(),
264  self.output_schema.field_blobs())
265 
266  else:
267  table_rows = self._gather_wrapper(
268  net, version, self.input_record.items(), 'table_rows')
269 
270  segment_ids = net.LengthsToSegmentIds(
271  self.input_record.lengths(),
272  self.input_record.lengths() + '_sid')
273  net.__getattr__('SortedSegmentRange' + self.reducer)(
274  [table_rows, segment_ids],
275  self.output_schema.field_blobs(),
276  engine='fp16',
277  )
278 
279  # deal with sparse features of id_score_list type
280  def _add_ops_id_score_list(self, net, version):
281  assert self.reducer in self._id_score_list_supported_reducers, (
282  "Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer)
283  )
284  if self.reducer in ['WeightedSum', 'WeightedMean']:
286  self.input_record.keys(),
287  self.input_record.values(),
288  self.reducer, net, version)
289 
290  elif self.reducer in ['Sum', 'Mean']:
291  op_input = [self.w,
292  self.input_record.keys(),
293  self.input_record.lengths()]
294 
295  layer_name = 'SparseLengths' + self.reducer
296 
297  if version in ['fp32', 'fp16']:
298  net.__getattr__(layer_name)(
299  op_input,
300  self.output_schema.field_blobs(),
301  engine='fp16',
302  )
303  elif version == 'uint8rowwise':
304  net.__getattr__(layer_name + '8BitsRowwise')(
305  op_input, self.output_schema.field_blobs())
306  elif version == 'fused_uint8rowwise':
307  net.__getattr__(layer_name + 'Fused8BitRowwise')(
308  op_input, self.output_schema.field_blobs())
309  else:
310  raise "Unsupported version of operator in SparseLookUp " +\
311  "layer: {0}".format(version)
312 
313  elif self.reducer == 'PositionWeighted':
315  self.input_record.keys(),
316  self.external_weights,
317  'WeightedSum', net, version, grad_on_weights=1)
318 
319  elif self.reducer == 'None':
320  # Gather operator will gather the embedding for each id of
321  # each IdList.
322  self._gather_wrapper(net, version, self.input_record.keys(),
323  self.output_schema.field_blobs())
324  else:
325  raise "Only Sum, Mean, None are supported for IdScoreList input." +\
326  "Trying to create with {}".format(self.reducer)
327 
328  def add_ops(self, net):
329  cur_scope = get_current_scope()
330  version = get_sparse_lookup_predictor_version(
331  **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
332  {'version': 'fp32'}))
333 
334  if _is_id_list(self.input_record):
335  self._add_ops_id_list(net, version=version)
336  elif _is_id_score_list(self.input_record):
337  self._add_ops_id_score_list(net, version=version)
338  else:
339  raise "Unsupported input type {0}".format(self.input_record)
# NOTE: the two signatures below are documentation-extraction artifacts that
# duplicate methods defined above; kept only as commented-out text.
# def _sparse_lengths_weighted_reducer(self, in_indices, weights, reducer, net, version, grad_on_weights=0)
# def _gather_wrapper(self, net, version, in_indices, out)