Caffe2 - C++ API
A deep-learning, cross-platform ML framework
lengths_reducer_ops.h
1 
17 #pragma once
18 #include "caffe2/core/context.h"
19 #include "caffe2/core/operator.h"
20 #include "caffe2/perfkernels/embedding_lookup.h"
21 
22 namespace caffe2 {
23 
24 // A templated class that implements SparseLengths[Sum,WeightedSum,Mean].
25 template <
26  typename T, // output type
27  class InputTypes, // supported input types, such as TensorTypes<float>
28  bool USE_WEIGHT = 0, // Whether it is SparseLengthsWeightedSum
29  bool USE_MEAN = 0 // Whether this is SparseLengthsMean
30  >
31 class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
32  public:
33  USE_OPERATOR_FUNCTIONS(CPUContext);
34  CPUSparseLengthsReductionOp(const OperatorDef& operator_def, Workspace* ws)
35  : Operator<CPUContext>(operator_def, ws) {
36  static_assert(
37  !(USE_WEIGHT & USE_MEAN), "Cannot both specify weight and mean.");
38  }
39 
41 
42  // Currently, we support float and float16 inputs for input data type, and
43  // int32_t and int64_t for the index type.
44 
45  bool RunOnDevice() override {
46  return DispatchHelper<InputTypes>::call(this, Input(DATA));
47  }
48 
49  template <typename InputType>
50  bool DoRunWithType() {
51  return DispatchHelper<TensorTypes2<int32_t, int64_t>, InputType>::call(
52  this, Input(INDICES));
53  }
54 
55  template <typename InputType, typename IndexType>
56  bool DoRunWithType2() {
57  auto& dataInput = Input(DATA);
58  auto& indicesInput = Input(INDICES);
59  auto& lengthsInput = Input(LENGTHS);
60 
61  CAFFE_ENFORCE_EQ(1, indicesInput.ndim(), "INDICES must be a vector");
62  CAFFE_ENFORCE_EQ(1, lengthsInput.ndim(), "LENGTHS must be a vector");
63  const TIndex N = dataInput.dim(0);
64  const int D = dataInput.size_from_dim(1);
65  const TIndex M = lengthsInput.dim(0);
66  const TIndex indices_size = indicesInput.size();
67 
68  auto* output = Output(0);
69  auto shape = dataInput.dims();
70  shape[0] = M;
71  output->Resize(shape);
72  T* out_data = output->template mutable_data<T>();
73 
74  const InputType* in_data = dataInput.template data<InputType>();
75  const IndexType* indices = indicesInput.template data<IndexType>();
76  const int* lengths = lengthsInput.template data<int>();
77  const T* in_weight = nullptr;
78 
79  if (USE_WEIGHT) { // static if
80  auto& weightInput = Input(WEIGHT);
81  CAFFE_ENFORCE_EQ(1, weightInput.ndim(), "WEIGHT must be a vector");
82  CAFFE_ENFORCE_EQ(
83  weightInput.size(),
84  indices_size,
85  "Weight should have the same length as indices.");
86  in_weight = weightInput.template data<T>();
87  }
88 
89  // delegate work to perfkernel that branches based on architecture
91  D,
92  M,
93  indices_size,
94  N,
95  in_data,
96  indices,
97  lengths,
98  in_weight,
99  nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
100  USE_MEAN,
101  out_data);
102  return true;
103  }
104 
105  private:
106  enum {
107  DATA = 0, // Data input.
108  WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum
109  INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and
110  // 2 in SparseLengthsWeightedSum
111  LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
112  // 3 in SparseLengthsWeightedSum
113  };
114 };
115 
116 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:82
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
void EmbeddingLookup(const TIndex block_size, const TIndex output_size, const TIndex index_size, const TIndex data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, const float *scale_bias, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.