Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup.cc
#include "caffe2/perfkernels/embedding_lookup.h"

#include <cstring> // for memset

#include "caffe2/core/types.h"
#include "caffe2/perfkernels/common.h"
#include "caffe2/perfkernels/typed_axpy.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

/**
 * Base implementation does runtime dispatch for each segment of reduction.
 * @return false if there is an out-of-bound error
 */
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
static bool EmbeddingLookupGenericSlow(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const int* lengths,
    const float* weights, // optional, can be null for sum reducer
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out) {
  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
    // Each output block accumulates lengths[m] rows of the embedding table.
    memset(out, 0, sizeof(OutType) * block_size);
    EigenVectorArrayMap<OutType> out_vector(out, block_size);
    if (current + lengths[m] > index_size) {
      return false;
    }
    for (int i = 0; i < lengths[m]; ++i) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
#ifdef __GNUC__
      // Prefetch the next row while the current one is being accumulated.
      if (current + 1 < index_size) {
        __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
      }
#endif // __GNUC__

      float w = 1.f, b = 0.f;
      if (weights) {
        w = weights[IS_WEIGHT_POSITIONAL ? i : current];
      }
      if (scale_bias) {
        b = w * scale_bias[2 * indices[current] + 1];
        w = w * scale_bias[2 * indices[current]];
      }

      TypedAxpy<InType, OutType>(
          block_size, w, input + block_size * indices[current], out);

      if (scale_bias) {
        out_vector = out_vector + b;
      }

      ++current;
    }
    if (normalize_by_lengths && lengths[m]) {
      // hack: context is not really used
      math::Scale<float, OutType, CPUContext>(
          block_size, 1.f / lengths[m], out, out, nullptr);
    }
    out += block_size;
  }
  return current == index_size;
}
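
// Note on the uint8 path above: with per-row quantization parameters
// scale = scale_bias[2 * idx] and bias = scale_bias[2 * idx + 1], each lookup
// contributes w * (scale * x[j] + bias) to out[j]. The loop fuses this into a
// single TypedAxpy with coefficient w * scale, followed by adding w * bias to
// the whole block, which is algebraically equivalent.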

// Proxy back to the generic implementation. For each type combination, this
// macro emits a __base wrapper around EmbeddingLookupGenericSlow, declares the
// matching __avx2_fma kernel (defined in a separately compiled AVX2 file),
// defines a dispatcher that picks between them at runtime via
// AVX2_FMA_DO/BASE_DO, and specializes EmbeddingLookup<> to re-walk the
// indices and report a descriptive error when the fast path returns false.
#define EMBEDDING_SPECIALIZATION( \
    IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \
  bool \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base( \
          const int64_t block_size, \
          const int64_t output_size, \
          const int64_t index_size, \
          const int64_t data_size, \
          const InType* input, \
          const IndexType* indices, \
          const int* lengths, \
          const float* weights, \
          const float* scale_bias, \
          bool normalize_by_lengths, \
          OutType* out) { \
    return EmbeddingLookupGenericSlow< \
        IndexType, \
        InType, \
        OutType, \
        IS_WEIGHT_POSITIONAL>( \
        block_size, \
        output_size, \
        index_size, \
        data_size, \
        input, \
        indices, \
        lengths, \
        weights, \
        scale_bias, \
        normalize_by_lengths, \
        out); \
  } \
  decltype( \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
  bool \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
          const int64_t block_size, \
          const int64_t output_size, \
          const int64_t index_size, \
          const int64_t data_size, \
          const InType* input, \
          const IndexType* indices, \
          const int* lengths, \
          const float* weights, \
          const float* scale_bias, \
          bool normalize_by_lengths, \
          OutType* out) { \
    if (std::is_same<InType, uint8_t>::value) { \
      CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \
    } else { \
      CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \
    } \
    AVX2_FMA_DO( \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
        block_size, \
        output_size, \
        index_size, \
        data_size, \
        input, \
        indices, \
        lengths, \
        weights, \
        scale_bias, \
        normalize_by_lengths, \
        out); \
    BASE_DO( \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
        block_size, \
        output_size, \
        index_size, \
        data_size, \
        input, \
        indices, \
        lengths, \
        weights, \
        scale_bias, \
        normalize_by_lengths, \
        out); \
  } \
  template <> \
  void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t data_size, \
      const InType* input, \
      const IndexType* indices, \
      const int* lengths, \
      const float* weights, \
      const float* scale_bias, \
      bool normalize_by_lengths, \
      OutType* out) { \
    bool success = \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
            block_size, \
            output_size, \
            index_size, \
            data_size, \
            input, \
            indices, \
            lengths, \
            weights, \
            scale_bias, \
            normalize_by_lengths, \
            out); \
    if (success) { \
      return; \
    } \
    int64_t current = 0; \
    for (int m = 0; m < output_size; ++m) { \
      for (int i = 0; i < lengths[m]; ++i) { \
        CAFFE_ENFORCE_LT(current, index_size); \
        IndexType idx = indices[current]; \
        CAFFE_ENFORCE( \
            0 <= idx && idx < data_size, \
            "Index ", \
            current, \
            " is out of bounds: ", \
            idx, \
            ", range 0 to ", \
            data_size); \
        ++current; \
      } \
    } \
    CAFFE_ENFORCE_EQ( \
        current, \
        index_size, \
        "Your input seems to be incorrect: the sum of lengths values should be " \
        "the size of the indices tensor, but it appears not."); \
  }

EMBEDDING_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);

EMBEDDING_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);

#undef EMBEDDING_SPECIALIZATION

} // namespace caffe2
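
For reference, here is a minimal sketch of how one of the specializations compiled in this file might be invoked. The values are made up for illustration; it assumes the EmbeddingLookup declaration from caffe2/perfkernels/embedding_lookup.h and a Caffe2 build to link against.

#include <cstdint>
#include <vector>

#include "caffe2/perfkernels/embedding_lookup.h"

int main() {
  // Embedding table: data_size rows of block_size floats each.
  const std::int64_t block_size = 4;
  const std::int64_t data_size = 3;
  std::vector<float> table = {
      0.f, 1.f, 2.f,  3.f,   // row 0
      4.f, 5.f, 6.f,  7.f,   // row 1
      8.f, 9.f, 10.f, 11.f}; // row 2

  // Two output segments: rows {0, 2} are summed into the first block,
  // row {1} into the second.
  std::vector<std::int32_t> indices = {0, 2, 1};
  std::vector<int> lengths = {2, 1};
  const std::int64_t index_size = indices.size();
  const std::int64_t output_size = lengths.size();

  std::vector<float> out(output_size * block_size, 0.f);
  caffe2::EmbeddingLookup<std::int32_t, float, float, false>(
      block_size,
      output_size,
      index_size,
      data_size,
      table.data(),
      indices.data(),
      lengths.data(),
      /*weights=*/nullptr, // plain sum reducer
      /*scale_bias=*/nullptr, // must be null for non-uint8 input
      /*normalize_by_lengths=*/false,
      out.data());
  // out now holds {8, 10, 12, 14} followed by {4, 5, 6, 7}.
  return 0;
}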