Caffe2 - C++ API
A deep-learning, cross-platform ML framework
lengths_reducer_fused_8bit_rowwise_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
18 #define CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/operators/fused_rowwise_8bit_conversion_ops.h"
24 #include "caffe2/operators/reducer_functors.h"
25 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
26 #include "caffe2/utils/math.h"
27 
28 namespace caffe2 {
29 
30 template <class Context, bool with_weights = 0, bool is_mean = 0>
31 class SparseLengthsFused8BitRowwiseOp : public Operator<Context> {
32  public:
33  static_assert(
34  !(with_weights && is_mean),
35  "Cannot have with_weights and is_mean a the same time");
36 
37  USE_OPERATOR_CONTEXT_FUNCTIONS;
38  USE_SIMPLE_CTOR_DTOR(SparseLengthsFused8BitRowwiseOp)
39 
40  bool RunOnDevice() override {
42  this, Input(INDICES));
43  }
44 
45  template <typename IndexType>
46  bool DoRunWithType() {
47  const auto& data = Input(DATA);
48  const auto& indices = Input(INDICES);
49  const auto& lengths = Input(LENGTHS);
50  auto* output = Output(0);
51 
52  CAFFE_ENFORCE_EQ(indices.ndim(), 1, "INDICES must be a vector");
53  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be a vector");
54 
55  const float* weights = nullptr;
56  if (with_weights) {
57  const auto& weights_input = Input(WEIGHTS);
58  CAFFE_ENFORCE_EQ(weights_input.ndim(), 1, "WEIGHTS must be a vector");
59  CAFFE_ENFORCE_EQ(
60  weights_input.size(),
61  indices.size(),
62  "WEIGHTS should have the same length as INDICES.");
63  weights = weights_input.template data<float>();
64  }
65 
66  CAFFE_ENFORCE_GT(data.dim(1), 8, "DATA must have more than 8 columns");
67  // Subtract 8 from the #columns of data for the 4 bytes for scale and 4
68  // bytes for bias that we use in the fused representation (per row).
69  const std::vector<TIndex> shape = {lengths.dim(0), data.dim(1) - 8};
70  output->Resize(shape);
71 
73  /*block_size=*/output->dim(1),
74  /*output_size=*/output->dim(0),
75  /*index_size=*/indices.size(),
76  /*data_size=*/data.dim(0),
77  /*input=*/data.template data<uint8_t>(),
78  /*indices=*/indices.template data<IndexType>(),
79  /*lengths=*/lengths.template data<int>(),
80  /*weights=*/weights,
81  /*normalize_by_lengths=*/is_mean,
82  /*out=*/output->template mutable_data<float>());
83 
84  return true;
85  }
86 
87  private:
88  enum {
89  DATA = 0,
90  WEIGHTS = 1,
91  INDICES = 1 + with_weights,
92  LENGTHS = 2 + with_weights,
93  };
94 };
95 
96 } // namespace caffe2
97 
98 #endif // CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
Copyright (c) 2016-present, Facebook, Inc.
void Fused8BitRowwiseEmbeddingLookup(const TIndex block_size, const TIndex output_size, const TIndex index_size, const TIndex data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.