Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_gather_ops.h
1 #ifndef CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
2 #define CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 // Reuse helper logic from GatherOp, since BatchGather is the same as Gather with axis=1.
8 #include "caffe2/operators/gather_op.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class BatchGatherOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  USE_SIMPLE_CTOR_DTOR(BatchGatherOp)
17 
18  bool RunOnDevice() override {
20  this, this->template Input<Tensor>(INDICES, CPU));
21  }
22 
23  template <typename TInd>
24  bool DoRunWithType() {
25  // BatchGather is a special-case of Gather with Axis = 1.
26  return gather_helper::gather_impl<TInd, Context>(
27  this, DATA, INDICES, 0, 1, false);
28  }
29  INPUT_TAGS(DATA, INDICES);
30 };
31 
32 template <class Context>
33 class BatchGatherGradientOp final : public Operator<Context> {
34  public:
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36 
37  // Constructor to recieve axis in case it was passed for GatherOp gradient,
38  // use default of 1 for batch gather otherwise.
39  template <class... Args>
40  explicit BatchGatherGradientOp(Args&&... args)
41  : Operator<Context>(std::forward<Args>(args)...),
42  OP_SINGLE_ARG(int, "axis", axis_, 1) {}
43  virtual ~BatchGatherGradientOp() noexcept {}
44 
45  bool RunOnDevice() override {
47  this, this->template Input<Tensor>(INDICES, CPU));
48  }
49 
50  template <typename TInd>
51  bool DoRunWithType() {
52  return DispatchHelper<
54  TInd>::call(this, Input(DATA));
55  }
56 
57  template <typename TInd, typename TData>
58  bool DoRunWithType2() {
59  auto& data = Input(DATA);
60  auto& indices = Input(INDICES);
61  auto& grad = Input(GRAD);
62 
63  // ONNX allows negative axis to index from the back, valid range: [-r, r].
64  int axis = axis_;
65  if (axis < 0) {
66  axis = data.dim() + axis;
67  }
68 
69  CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
70  // Outer dimensions of input data and gradient should be the same
71  // because they are preserved for gathers with axis > 0.
72  for (int acheck = 0; acheck < axis; acheck++) {
73  CAFFE_ENFORCE_EQ(
74  data.size(acheck),
75  grad.size(acheck),
76  "batch gather outer dimensions should match");
77  }
78 
79  auto* output = Output(0, data.sizes(), at::dtype<TData>());
80  TData* out_data = output->template mutable_data<TData>();
81  if (data.numel() <= 0) {
82  return true;
83  }
84  memset(out_data, 0, output->nbytes());
85 
86  const TData* grad_data = grad.template data<TData>();
87  const TInd* idxs = indices.template data<TInd>();
88 
89  auto outer_dims_product = data.size_to_dim(axis);
90  auto batch_size = data.size_from_dim(axis);
91  auto block_size = data.size_from_dim(axis + 1);
92  auto N = indices.numel();
93  auto gathered_grad_batch_size = N * block_size;
94 
95  // Check indexing bounds.
96  auto src_indexing_axis_dim = data.dim(axis);
97  gather_helper::check_indexarray_range<TInd>(
98  idxs,
99  N,
100  src_indexing_axis_dim,
101  false);
102 
103  for (auto batch = 0; batch < outer_dims_product; ++batch) {
104  auto grad_batch_base = grad_data + batch * gathered_grad_batch_size;
105  auto out_batch_base = out_data + batch * batch_size;
106 
107  for (auto i = 0; i < N; ++i) {
108  auto idx = idxs[i];
109  if (idx < 0) {
110  idx = idx + src_indexing_axis_dim;
111  }
112  if (block_size == 1) {
113  out_batch_base[idx] += grad_batch_base[i];
114  } else {
115  math::Add(
116  block_size,
117  out_batch_base + idx * block_size,
118  grad_batch_base + i * block_size,
119  out_batch_base + idx * block_size,
120  &context_);
121  }
122  }
123  }
124  return true;
125  }
126 
127  template <typename TInd>
128  bool DoRunWithOtherType2() {
129  CAFFE_THROW(
130  "BatchGatherGradient is not implemented on tensor of type ",
131  Input(DATA).meta().name(),
132  "consider adding it as a type in the DispatchHelper list or "
133  "implementing a generic version (which won't work for "
134  "duplicated indices though)");
135  }
136 
137  INPUT_TAGS(DATA, INDICES, GRAD);
138 protected:
139  int axis_;
140 };
141 
142 } // namespace caffe2
143 
144 #endif // CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13