Caffe2 - C++ API
A deep learning, cross-platform ML framework
fused_8bit_rowwise_embedding_lookup.cc
1 
17 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
18 
19 #include "caffe2/core/types.h"
20 #include "caffe2/perfkernels/common.h"
21 #include "caffe2/perfkernels/typed_axpy.h"
22 #include "caffe2/utils/cpuid.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 // Base implementation does runtime dispatch for each segment of reduction
28 template <typename IndexType, typename InType, typename OutType>
29 static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
30  const TIndex block_size,
31  const TIndex output_size,
32  const TIndex index_size,
33  const TIndex data_size,
34  const InType* input,
35  const IndexType* indices,
36  const int* lengths,
37  const float* weights, // optional, can be null for sum reducer
38  bool normalize_by_lengths,
39  OutType* out) {
40  // block_size is the number of elements and fused_block_size is the size of
41  // an entire row, including scale and bias.
42  const auto scale_bias_offset = 8 / sizeof(InType);
43  const TIndex fused_block_size = block_size + scale_bias_offset;
44  TIndex current = 0;
45  for (int m = 0; m < output_size; ++m) {
46  memset(out, 0, sizeof(OutType) * block_size);
47  EigenVectorArrayMap<OutType> out_vector(out, block_size);
48  for (int i = 0; i < lengths[m]; ++i) {
49  CAFFE_ENFORCE_LT(current, index_size);
50  TIndex idx = indices[current];
51  CAFFE_ENFORCE(
52  0 <= idx && idx < data_size,
53  "Index ",
54  current,
55  " is out of bounds: ",
56  idx,
57  ", range 0 to ",
58  data_size);
59  CAFFE_ENFORCE_LT(idx, data_size);
60 #ifdef __GNUC__
61  if (current + 1 < index_size) {
62  __builtin_prefetch(
63  input + fused_block_size * indices[current + 1], 0, 1);
64  }
65 #endif // __GNUC__
66 
67  const float* scale_bias = reinterpret_cast<const float*>(
68  input + fused_block_size * indices[current] + block_size);
69 
70  const float weight = weights ? weights[current] : 1.0f;
71  const float scale = weight * scale_bias[0];
72  const float bias = weight * scale_bias[1];
73 
74  TypedAxpy<InType, OutType>(
75  block_size, scale, input + fused_block_size * indices[current], out);
76 
77  out_vector += bias;
78 
79  ++current;
80  }
81  if (normalize_by_lengths && lengths[m]) {
82  // hack: context is not really used
83  math::Scale<OutType, CPUContext>(
84  block_size, 1.f / lengths[m], out, out, nullptr);
85  }
86  out += block_size;
87  }
88  CAFFE_ENFORCE_EQ(
89  current,
90  index_size,
91  "Your input seems to be incorrect: the sum of lengths values should be "
92  "the size of the indices tensor, but it appears not.");
93 }
94 
95 // Proxy back to generic implementation
96 #define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \
97  IndexType, InType, OutType) \
98  void \
99  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##__base( \
100  const TIndex block_size, \
101  const TIndex output_size, \
102  const TIndex index_size, \
103  const TIndex data_size, \
104  const InType* input, \
105  const IndexType* indices, \
106  const int* lengths, \
107  const float* weights, \
108  bool normalize_by_lengths, \
109  OutType* out) { \
110  Fused8BitRowwiseEmbeddingLookupGenericSlow<IndexType, InType, OutType>( \
111  block_size, \
112  output_size, \
113  index_size, \
114  data_size, \
115  input, \
116  indices, \
117  lengths, \
118  weights, \
119  normalize_by_lengths, \
120  out); \
121  } \
122  template <> \
123  void Fused8BitRowwiseEmbeddingLookup( \
124  const TIndex block_size, \
125  const TIndex output_size, \
126  const TIndex index_size, \
127  const TIndex data_size, \
128  const InType* input, \
129  const IndexType* indices, \
130  const int* lengths, \
131  const float* weights, \
132  bool normalize_by_lengths, \
133  OutType* out) { \
134  const int32_t one = 1; \
135  CAFFE_ENFORCE_EQ( \
136  reinterpret_cast<const uint8_t*>(&one)[0], \
137  1, \
138  "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
139  AVX2_FMA_DO( \
140  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \
141  block_size, \
142  output_size, \
143  index_size, \
144  data_size, \
145  input, \
146  indices, \
147  lengths, \
148  weights, \
149  normalize_by_lengths, \
150  out); \
151  BASE_DO( \
152  Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \
153  block_size, \
154  output_size, \
155  index_size, \
156  data_size, \
157  input, \
158  indices, \
159  lengths, \
160  weights, \
161  normalize_by_lengths, \
162  out); \
163  }
164 
165 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float);
166 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float);
167 
168 #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION
169 
170 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.