Caffe2 - C++ API
A deep-learning, cross-platform ML framework
embedding_lookup.cc
1 
17 #include "caffe2/perfkernels/embedding_lookup.h"
18 
19 #include "caffe2/core/types.h"
20 #include "caffe2/perfkernels/common.h"
21 #include "caffe2/perfkernels/typed_axpy.h"
22 #include "caffe2/utils/cpuid.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 // Base implementation does runtime dispatch for each segment of reduction
28 template <typename IndexType, typename InType, typename OutType>
29 static void EmbeddingLookupGenericSlow(
30  const TIndex block_size,
31  const TIndex output_size,
32  const TIndex index_size,
33  const TIndex data_size,
34  const InType* input,
35  const IndexType* indices,
36  const int* lengths,
37  const float* weights, // optional, can be null for sum reducer
38  const float* scale_bias, // optional scale & bias params for uint8 input
39  bool normalize_by_lengths,
40  OutType* out) {
41  TIndex current = 0;
42  for (int m = 0; m < output_size; ++m) {
43  memset(out, 0, sizeof(OutType) * block_size);
44  EigenVectorArrayMap<OutType> out_vector(out, block_size);
45  for (int i = 0; i < lengths[m]; ++i) {
46  CAFFE_ENFORCE_LT(current, index_size);
47  TIndex idx = indices[current];
48  CAFFE_ENFORCE(
49  0 <= idx && idx < data_size,
50  "Index ",
51  current,
52  " is out of bounds: ",
53  idx,
54  ", range 0 to ",
55  data_size);
56  CAFFE_ENFORCE_LT(idx, data_size);
57 #ifdef __GNUC__
58  if (current + 1 < index_size) {
59  __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
60  }
61 #endif // __GNUC__
62 
63  float w = 1.f, b = 0.f;
64  if (weights) {
65  w = weights[current];
66  }
67  if (scale_bias) {
68  b = w * scale_bias[2 * indices[current] + 1];
69  w = w * scale_bias[2 * indices[current]];
70  }
71 
72  TypedAxpy<InType, OutType>(
73  block_size, w, input + block_size * indices[current], out);
74 
75  if (scale_bias) {
76  out_vector = out_vector + b;
77  }
78 
79  ++current;
80  }
81  if (normalize_by_lengths && lengths[m]) {
82  // hack: context is not really used
83  math::Scale<OutType, CPUContext>(
84  block_size, 1.f / lengths[m], out, out, nullptr);
85  }
86  out += block_size;
87  }
88  CAFFE_ENFORCE_EQ(
89  current,
90  index_size,
91  "Your input seems to be incorrect: the sum of lengths values should be "
92  "the size of the indices tensor, but it appears not.");
93 }
94 
// EMBEDDING_SPECIALIZATION(IdxT, InT, OutT) emits two definitions for one
// (index type, input type, output type) combination:
//   1. EmbeddingLookup_<IdxT>_<InT>_<OutT>__base: forwards every argument
//      unchanged to the generic EmbeddingLookupGenericSlow implementation.
//   2. The explicit specialization of EmbeddingLookup: passes the function
//      name stem EmbeddingLookup_<IdxT>_<InT>_<OutT> plus all arguments to
//      AVX2_FMA_DO and then to BASE_DO.
// NOTE(review): AVX2_FMA_DO / BASE_DO are defined elsewhere (presumably
// caffe2/perfkernels/common.h). They appear to select an AVX2+FMA kernel when
// available and otherwise call the __base function above; confirm there.
#define EMBEDDING_SPECIALIZATION(IdxT, InT, OutT)                       \
  void EmbeddingLookup_##IdxT##_##InT##_##OutT##__base(                 \
      const TIndex block_size,                                          \
      const TIndex output_size,                                         \
      const TIndex index_size,                                          \
      const TIndex data_size,                                           \
      const InT* input,                                                 \
      const IdxT* indices,                                              \
      const int* lengths,                                               \
      const float* weights,                                             \
      const float* scale_bias,                                          \
      bool normalize_by_lengths,                                        \
      OutT* out) {                                                      \
    EmbeddingLookupGenericSlow<IdxT, InT, OutT>(                        \
        block_size, output_size, index_size, data_size, input, indices, \
        lengths, weights, scale_bias, normalize_by_lengths, out);       \
  }                                                                     \
  template <>                                                           \
  void EmbeddingLookup(                                                 \
      const TIndex block_size,                                          \
      const TIndex output_size,                                         \
      const TIndex index_size,                                          \
      const TIndex data_size,                                           \
      const InT* input,                                                 \
      const IdxT* indices,                                              \
      const int* lengths,                                               \
      const float* weights,                                             \
      const float* scale_bias,                                          \
      bool normalize_by_lengths,                                        \
      OutT* out) {                                                      \
    AVX2_FMA_DO(                                                        \
        EmbeddingLookup_##IdxT##_##InT##_##OutT,                        \
        block_size, output_size, index_size, data_size, input, indices, \
        lengths, weights, scale_bias, normalize_by_lengths, out);       \
    BASE_DO(                                                            \
        EmbeddingLookup_##IdxT##_##InT##_##OutT,                        \
        block_size, output_size, index_size, data_size, input, indices, \
        lengths, weights, scale_bias, normalize_by_lengths, out);       \
  }
162 
163 EMBEDDING_SPECIALIZATION(int32_t, float, float);
164 EMBEDDING_SPECIALIZATION(int64_t, float, float);
165 EMBEDDING_SPECIALIZATION(int32_t, float16, float);
166 EMBEDDING_SPECIALIZATION(int64_t, float16, float);
167 EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float);
168 EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float);
169 
170 #undef EMBEDDING_SPECIALIZATION
171 
172 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.