Caffe2 - C++ API
A deep learning, cross platform ML framework
fused_8bit_rowwise_embedding_lookup.cc
1 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
2 
3 #include "caffe2/core/types.h"
4 #include "caffe2/perfkernels/common.h"
5 #include "caffe2/perfkernels/typed_axpy.h"
6 #include "caffe2/utils/cpuid.h"
7 #include "caffe2/utils/eigen_utils.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
16 template <
17  typename IndexType,
18  typename InType,
19  typename OutType,
20  bool IS_WEIGHT_POSITIONAL = false>
21 static bool Fused8BitRowwiseEmbeddingLookupGenericSlow(
22  const int64_t block_size,
23  const int64_t output_size,
24  const int64_t index_size,
25  const int64_t data_size,
26  const InType* input,
27  const IndexType* indices,
28  const int* lengths,
29  const float* weights, // optional, can be null for sum reducer
30  bool normalize_by_lengths,
31  OutType* out) {
32  // block_size is the number of elements and fused_block_size is the size of
33  // an entire row, including scale and bias.
34  const auto scale_bias_offset = 8 / sizeof(InType);
35  const int64_t fused_block_size = block_size + scale_bias_offset;
36  int64_t current = 0;
37  for (int m = 0; m < output_size; ++m) {
38  memset(out, 0, sizeof(OutType) * block_size);
39  EigenVectorArrayMap<OutType> out_vector(out, block_size);
40  if (current + lengths[m] > index_size) {
41  return false;
42  }
43  for (int i = 0; i < lengths[m]; ++i) {
44  int64_t idx = indices[current];
45  if (idx < 0 || idx >= data_size) {
46  return false;
47  }
48 #ifdef __GNUC__
49  if (current + 1 < index_size) {
50  __builtin_prefetch(
51  input + fused_block_size * indices[current + 1], 0, 1);
52  }
53 #endif // __GNUC__
54 
55  const float* scale_bias = reinterpret_cast<const float*>(
56  input + fused_block_size * indices[current] + block_size);
57 
58  float weight = 1.0f;
59  if (weights) {
60  weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
61  }
62  const float scale = weight * scale_bias[0];
63  const float bias = weight * scale_bias[1];
64 
65  TypedAxpy<InType, OutType>(
66  block_size, scale, input + fused_block_size * indices[current], out);
67 
68  out_vector += bias;
69 
70  ++current;
71  }
72  if (normalize_by_lengths && lengths[m]) {
73  // hack: context is not really used
74  math::Scale<float, OutType, CPUContext>(
75  block_size, 1.f / lengths[m], out, out, nullptr);
76  }
77  out += block_size;
78  }
79  return current == index_size;
80 }
81 
82 // Proxy back to generic implementation
83 #define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \
84  bool \
85  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \
86  const int64_t block_size, \
87  const int64_t output_size, \
88  const int64_t index_size, \
89  const int64_t data_size, \
90  const uint8_t* input, \
91  const IndexType* indices, \
92  const int* lengths, \
93  const float* weights, \
94  bool normalize_by_lengths, \
95  OutType* out) { \
96  return Fused8BitRowwiseEmbeddingLookupGenericSlow< \
97  IndexType, \
98  uint8_t, \
99  OutType, \
100  false>( \
101  block_size, \
102  output_size, \
103  index_size, \
104  data_size, \
105  input, \
106  indices, \
107  lengths, \
108  weights, \
109  normalize_by_lengths, \
110  out); \
111  } \
112  decltype( \
113  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \
114  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \
115  bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \
116  const int64_t block_size, \
117  const int64_t output_size, \
118  const int64_t index_size, \
119  const int64_t data_size, \
120  const uint8_t* input, \
121  const IndexType* indices, \
122  const int* lengths, \
123  const float* weights, \
124  bool normalize_by_lengths, \
125  OutType* out) { \
126  const int32_t one = 1; \
127  CAFFE_ENFORCE_EQ( \
128  reinterpret_cast<const uint8_t*>(&one)[0], \
129  1, \
130  "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
131  AVX2_FMA_DO( \
132  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \
133  block_size, \
134  output_size, \
135  index_size, \
136  data_size, \
137  input, \
138  indices, \
139  lengths, \
140  weights, \
141  normalize_by_lengths, \
142  out); \
143  BASE_DO( \
144  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \
145  block_size, \
146  output_size, \
147  index_size, \
148  data_size, \
149  input, \
150  indices, \
151  lengths, \
152  weights, \
153  normalize_by_lengths, \
154  out); \
155  } \
156  template <> \
157  void Fused8BitRowwiseEmbeddingLookup<IndexType, uint8_t, OutType, false>( \
158  const int64_t block_size, \
159  const int64_t output_size, \
160  const int64_t index_size, \
161  const int64_t data_size, \
162  const uint8_t* input, \
163  const IndexType* indices, \
164  const int* lengths, \
165  const float* weights, \
166  bool normalize_by_lengths, \
167  OutType* out) { \
168  bool success = \
169  Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \
170  block_size, \
171  output_size, \
172  index_size, \
173  data_size, \
174  input, \
175  indices, \
176  lengths, \
177  weights, \
178  normalize_by_lengths, \
179  out); \
180  if (success) { \
181  return; \
182  } \
183  int64_t current = 0; \
184  for (int m = 0; m < output_size; ++m) { \
185  for (int i = 0; i < lengths[m]; ++i) { \
186  CAFFE_ENFORCE_LT(current, index_size); \
187  IndexType idx = indices[current]; \
188  CAFFE_ENFORCE( \
189  0 <= idx && idx < data_size, \
190  "Index ", \
191  current, \
192  " is out of bounds: ", \
193  idx, \
194  ", range 0 to ", \
195  data_size); \
196  ++current; \
197  } \
198  } \
199  CAFFE_ENFORCE_EQ( \
200  current, \
201  index_size, \
202  "Your input seems to be incorrect: the sum of lengths values should be " \
203  "the size of the indices tensor, but it appears not."); \
204  }
205 
206 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, float);
207 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, float);
208 
209 #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION
210 
211 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13