Caffe2 - C++ API
A deep learning, cross platform ML framework
lengths_reducer_fused_8bit_rowwise_ops.h
1 #ifndef CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
2 #define CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/operators/fused_rowwise_8bit_conversion_ops.h"
8 #include "caffe2/operators/reducer_functors.h"
9 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
10 #include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 
14 template <class Context, bool with_weights = 0, bool is_mean = 0>
15 class SparseLengthsFused8BitRowwiseOp : public Operator<Context> {
16  public:
17  static_assert(
18  !(with_weights && is_mean),
19  "Cannot have with_weights and is_mean a the same time");
20 
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22  USE_SIMPLE_CTOR_DTOR(SparseLengthsFused8BitRowwiseOp)
23 
24  bool RunOnDevice() override {
26  this, Input(INDICES));
27  }
28 
29  template <typename IndexType>
30  bool DoRunWithType() {
31  const auto& data = Input(DATA);
32  const auto& indices = Input(INDICES);
33  const auto& lengths = Input(LENGTHS);
34 
35  CAFFE_ENFORCE_EQ(indices.dim(), 1, "INDICES must be a vector");
36  CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS must be a vector");
37 
38  const float* weights = nullptr;
39  if (with_weights) {
40  const auto& weights_input = Input(WEIGHTS);
41  CAFFE_ENFORCE_EQ(weights_input.dim(), 1, "WEIGHTS must be a vector");
42  CAFFE_ENFORCE_EQ(
43  weights_input.numel(),
44  indices.numel(),
45  "WEIGHTS should have the same length as INDICES.");
46  weights = weights_input.template data<float>();
47  }
48 
49  CAFFE_ENFORCE_GT(data.size(1), 8, "DATA must have more than 8 columns");
50  // Subtract 8 from the #columns of data for the 4 bytes for scale and 4
51  // bytes for bias that we use in the fused representation (per row).
52  const std::vector<int64_t> shape = {lengths.size(0), data.size(1) - 8};
53  auto* output = Output(0, shape, at::dtype<float>());
54 
56  /*block_size=*/output->size(1),
57  /*output_size=*/output->size(0),
58  /*index_size=*/indices.numel(),
59  /*data_size=*/data.size(0),
60  /*input=*/data.template data<uint8_t>(),
61  /*indices=*/indices.template data<IndexType>(),
62  /*lengths=*/lengths.template data<int>(),
63  /*weights=*/weights,
64  /*normalize_by_lengths=*/is_mean,
65  /*out=*/output->template mutable_data<float>());
66 
67  return true;
68  }
69 
70  enum {
71  DATA = 0,
72  WEIGHTS = 1,
73  INDICES = 1 + with_weights,
74  LENGTHS = 2 + with_weights,
75  };
76 };
77 
78 } // namespace caffe2
79 
80 #endif // CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
void Fused8BitRowwiseEmbeddingLookup(const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, const std::int64_t data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.