Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_reducer_ops.h
#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/embedding_lookup.h"

namespace caffe2 {

// A templated class that implements SparseLengths[Sum,WeightedSum,Mean].
template <
    typename T, // output type
    class InputTypes, // supported input types, such as TensorTypes<float>
    bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum
    bool USE_MEAN = 0, // Whether this is SparseLengthsMean
    bool USE_POSITIONAL_WEIGHT = 0
    // USE_WEIGHT = 1 and USE_POSITIONAL_WEIGHT = 1
    // -> SparseLengthsPositionalWeightedSum
    >
class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CPUContext);
  template <class... Args>
  explicit CPUSparseLengthsReductionOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...) {
    static_assert(
        !(USE_WEIGHT & USE_MEAN), "Cannot both specify weight and mean.");
  }

  ~CPUSparseLengthsReductionOp() {}

  // Currently, we support float and at::Half inputs for input data type, and
  // int32_t and int64_t for the index type.

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(DATA));
  }

  template <typename InputType>
  bool DoRunWithType() {
    return DispatchHelper<TensorTypes2<int32_t, int64_t>, InputType>::call(
        this, Input(INDICES));
  }

  template <typename InputType, typename IndexType>
  bool DoRunWithType2() {
    auto& dataInput = Input(DATA);
    auto& indicesInput = Input(INDICES);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
    CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
    const int64_t N = dataInput.size(0);
    const int D = dataInput.size_from_dim(1);
    const int64_t M = lengthsInput.size(0);
    const int64_t indices_size = indicesInput.numel();

    auto shape = dataInput.sizes().vec();
    shape[0] = M;
    auto* output = Output(0, shape, at::dtype<T>());
    T* out_data = output->template mutable_data<T>();

    const InputType* in_data = dataInput.template data<InputType>();
    const IndexType* indices = indicesInput.template data<IndexType>();
    const int* lengths = lengthsInput.template data<int>();
    const T* in_weight = nullptr;

    if (USE_WEIGHT) {
      // static if
      auto& weightInput = Input(WEIGHT);
      CAFFE_ENFORCE_EQ(1, weightInput.dim(), "WEIGHT must be a vector");
      if (!USE_POSITIONAL_WEIGHT) {
        CAFFE_ENFORCE_EQ(
            weightInput.numel(),
            indices_size,
            "Weight should have the same length as indices.");
      }
      in_weight = weightInput.template data<T>();
    }

    // delegate work to perfkernel that branches based on architecture
    EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
        D,
        M,
        indices_size,
        N,
        in_data,
        indices,
        lengths,
        in_weight,
        nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
        USE_MEAN,
        out_data);
    return true;
  }

  enum {
    DATA = 0, // Data input.
    WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum
    INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and
                              // 2 in SparseLengthsWeightedSum
    LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum,Mean],
                              // 3 in SparseLengthsWeightedSum
  };
};

} // namespace caffe2
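
For orientation, the template flags select which named operator the class implements. The aliases below are an illustrative sketch of typical instantiations; the actual type aliases and operator registrations live elsewhere in the Caffe2 tree (e.g. lengths_reducer_ops.cc) and may differ in detail. Template parameter order is <T, InputTypes, USE_WEIGHT, USE_MEAN, USE_POSITIONAL_WEIGHT>.

// Illustrative instantiations (assumed alias names; hedged sketch, not the
// exact registration file).
using SparseLengthsSumOp = // plain per-segment sum
    CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 0, 0>;
using SparseLengthsWeightedSumOp = // per-index weights via the WEIGHT input
    CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 1, 0>;
using SparseLengthsMeanOp = // per-segment sum divided by the segment length
    CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 0, 1>;
using SparseLengthsPositionalWeightedSumOp = // weight chosen by position in segment
    CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 1, 0, 1>;

REGISTER_CPU_OPERATOR(SparseLengthsSum, SparseLengthsSumOp);
REGISTER_CPU_OPERATOR(SparseLengthsWeightedSum, SparseLengthsWeightedSumOp);
REGISTER_CPU_OPERATOR(SparseLengthsMean, SparseLengthsMeanOp);

With USE_WEIGHT set, the input layout shifts by one position: DATA, WEIGHT, INDICES, LENGTHS instead of DATA, INDICES, LENGTHS, which is exactly what the enum at the bottom of the class encodes.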
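The heavy lifting is delegated to the EmbeddingLookup perfkernel, which picks an architecture-specific implementation at runtime. The following is a minimal, scalar sketch of the reduction it performs, assuming float data and int64_t indices; the function and parameter names are illustrative and this is not the perfkernel's API or implementation.

#include <cstdint>

// Scalar sketch of the segmented reduction behind SparseLengths[Sum,WeightedSum,Mean].
// data:    N x D row-major table to gather from
// indices: indices_size row ids, each expected to lie in [0, N)
// lengths: M segment lengths; they sum to indices_size
// weights: optional scaling factors (nullptr when USE_WEIGHT == 0)
// out:     M x D output, one reduced row per segment
void SparseLengthsReduceRef(
    int64_t D,
    int64_t M,
    int64_t indices_size,
    int64_t N,
    const float* data,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool use_positional_weight,
    bool use_mean,
    float* out) {
  int64_t pos = 0; // running position in `indices`
  for (int64_t m = 0; m < M; ++m) {
    float* out_row = out + m * D;
    for (int64_t d = 0; d < D; ++d) {
      out_row[d] = 0.f;
    }
    for (int i = 0; i < lengths[m]; ++i, ++pos) {
      const int64_t idx = indices[pos];
      // Positional weights are indexed by the position within the segment,
      // ordinary weights by the overall position in `indices`.
      const float w = weights == nullptr
          ? 1.f
          : (use_positional_weight ? weights[i] : weights[pos]);
      const float* row = data + idx * D;
      for (int64_t d = 0; d < D; ++d) {
        out_row[d] += w * row[d];
      }
    }
    if (use_mean && lengths[m] > 0) {
      for (int64_t d = 0; d < D; ++d) {
        out_row[d] /= lengths[m];
      }
    }
  }
}

In words: LENGTHS partitions INDICES into M consecutive segments; each output row is the (optionally weighted) sum of the gathered DATA rows for one segment, divided by the segment length when USE_MEAN is set.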