Caffe2 - C++ API
A deep learning, cross-platform ML framework
elementwise_div_gradient_op.cc
1 #include "caffe2/operators/elementwise_div_op.h"
2 #include "caffe2/utils/eigen_utils.h"
3 
#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <vector>
8 
9 namespace caffe2 {
10 
11 namespace {
12 
13 template <typename TGrad, typename TIn, typename TOut>
14 void ComputeDivGradient(
15  const int ndim,
16  const int* A_dims,
17  const int* B_dims,
18  const int* C_dims,
19  const TGrad* dC,
20  const TIn* B,
21  const TOut* C,
22  TGrad* dA,
23  TGrad* dB,
24  CPUContext* context) {
25  const int A_size =
26  std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
27  const int B_size =
28  std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
29  const int C_size =
30  std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
31  if (dA != nullptr) {
32  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
33  }
34  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
35  std::vector<int> index(ndim, 0);
36  for (int C_index = 0; C_index < C_size; ++C_index) {
37  const int B_index =
38  math::utils::GetIndexFromDims(ndim, B_dims, index.data());
39  dB[B_index] += -dC[C_index] * C[C_index] / B[B_index];
40  if (dA != nullptr) {
41  const int A_index =
42  math::utils::GetIndexFromDims(ndim, A_dims, index.data());
43  dA[A_index] += dC[C_index] / B[B_index];
44  }
45  math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
46  }
47 }
48 
49 } // namespace
50 
51 template <>
52 template <typename TGrad, typename TIn, typename TOut>
53 bool DivFunctor<CPUContext>::Backward(
54  const std::vector<int>& A_dims,
55  const std::vector<int>& B_dims,
56  const TGrad* dC,
57  const TIn* /* A */,
58  const TIn* B,
59  const TOut* C,
60  TGrad* dA,
61  TGrad* dB,
62  CPUContext* context) const {
63  if (A_dims == B_dims) {
64  const int size = std::accumulate(
65  A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
66  EigenVectorMap<TGrad>(dB, size) =
67  -ConstEigenVectorArrayMap<TGrad>(dC, size) *
68  ConstEigenVectorArrayMap<TOut>(C, size) /
69  ConstEigenVectorArrayMap<TIn>(B, size);
70  math::Div(size, dC, B, dA, context);
71  return true;
72  }
73  const int ndim = std::max(A_dims.size(), B_dims.size());
74  std::vector<int> A_broadcast_dims(ndim);
75  std::vector<int> B_broadcast_dims(ndim);
76  std::vector<int> C_broadcast_dims(ndim);
77  math::utils::ComputeBroadcastBinaryOpDims(
78  A_dims.size(),
79  A_dims.data(),
80  B_dims.size(),
81  B_dims.data(),
82  A_broadcast_dims.data(),
83  B_broadcast_dims.data(),
84  C_broadcast_dims.data());
85  if (dA == dC) {
86  ComputeDivGradient<TGrad, TIn, TOut>(
87  ndim,
88  A_broadcast_dims.data(),
89  B_broadcast_dims.data(),
90  C_broadcast_dims.data(),
91  dC,
92  B,
93  C,
94  nullptr,
95  dB,
96  context);
97  math::Div(
98  A_dims.size(),
99  A_dims.data(),
100  B_dims.size(),
101  B_dims.data(),
102  dC,
103  B,
104  dA,
105  context);
106  } else {
107  ComputeDivGradient<TGrad, TIn, TOut>(
108  ndim,
109  A_broadcast_dims.data(),
110  B_broadcast_dims.data(),
111  C_broadcast_dims.data(),
112  dC,
113  B,
114  C,
115  dA,
116  dB,
117  context);
118  }
119  return true;
120 }
121 
122 template <>
124  NumericTypes,
125  CPUContext,
129  final : public Operator<CPUContext> {
130  public:
131  USE_OPERATOR_FUNCTIONS(CPUContext);
132 
133  template <class... Args>
134  explicit BinaryElementwiseWithArgsGradientOp(Args&&... args)
135  : Operator<CPUContext>(std::forward<Args>(args)...),
136  OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
137  OP_SINGLE_ARG(int, "axis", axis_, -1),
138  OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
139  OP_SINGLE_ARG(string, "order", order_, "NCHW"),
140  functor_(*this) {
141  if (legacy_broadcast_) {
142  if (axis_ != -1) {
143  // Get axis from an explicit axis argument.
144  CAFFE_ENFORCE_EQ(
145  axis_str_.size(),
146  0,
147  "Args axis and axis_str cannot be used simultaneously.");
148  } else if (axis_str_.size()) {
149  // Get the axis index semantically.
150  CAFFE_ENFORCE_EQ(
151  axis_str_.size(), 1, "Unsupported axis string", axis_str_);
152  const size_t semantic_axis_ = order_.find(axis_str_);
153  CAFFE_ENFORCE_NE(
154  semantic_axis_,
155  string::npos,
156  "Unrecognizable axis string ",
157  axis_str_,
158  " from order string ",
159  order_);
160  axis_ = semantic_axis_;
161  } else {
162  CAFFE_ENFORCE(
163  axis_ == -1 && axis_str_.empty(),
164  "Do not specify axis or axis_str if broadcast is not enabled.");
165  }
166  }
167  }
168 
169  bool RunOnDevice() override {
170  return DispatchHelper<NumericTypes>::call(this, Input(1));
171  }
172 
173  template <typename T>
174  bool DoRunWithType() {
175  const T* dC_data = nullptr;
176  const T* A_data = nullptr;
177  const T* B_data = nullptr;
178  const T* C_data = nullptr;
179  std::vector<int> A_dims;
180  std::vector<int> B_dims;
181  at::IntArrayRef dA_sizes;
182  at::IntArrayRef dB_sizes;
183  if (InputSize() == 3) {
184  const auto& B = Input(0);
185  const auto& C = Input(1);
186  const auto& dC = Input(2);
187  if (legacy_broadcast_) {
188  if (B.numel() == 1) {
189  A_dims = {static_cast<int>(C.numel())};
190  B_dims = {1};
191  } else {
192  size_t pre, n, post;
193  std::tie(pre, n, post) =
194  elementwise_ops_utils::ComputeLegacyBroadcastSizes(C, B, axis_);
195  A_dims = {static_cast<int>(pre),
196  static_cast<int>(n),
197  static_cast<int>(post)};
198  B_dims = {static_cast<int>(n), 1};
199  }
200  } else {
201  std::copy(
202  C.sizes().cbegin(), C.sizes().cend(), std::back_inserter(A_dims));
203  std::copy(
204  B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
205  }
206  B_data = B.template data<T>();
207  C_data = C.template data<T>();
208  dC_data = dC.template data<T>();
209  dA_sizes = C.sizes();
210  dB_sizes = B.sizes();
211  } else {
212  const auto& dC = Input(0);
213  const auto& A = Input(1);
214  const auto& B = Input(2);
215  const auto& C = Input(3);
216  if (legacy_broadcast_) {
217  if (B.numel() == 1) {
218  A_dims = {static_cast<int>(A.numel())};
219  B_dims = {1};
220  } else {
221  size_t pre, n, post;
222  std::tie(pre, n, post) =
223  elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
224  A_dims = {static_cast<int>(pre),
225  static_cast<int>(n),
226  static_cast<int>(post)};
227  B_dims = {static_cast<int>(n), 1};
228  }
229  } else {
230  std::copy(
231  A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
232  std::copy(
233  B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
234  }
235  dC_data = dC.template data<T>();
236  A_data = A.template data<T>();
237  B_data = B.template data<T>();
238  C_data = C.template data<T>();
239  dA_sizes = A.sizes();
240  dB_sizes = B.sizes();
241  }
242  auto* dA = Output(0, dA_sizes, at::dtype<T>());
243  auto* dB = Output(1, dB_sizes, at::dtype<T>());
244  auto* dA_data = dA->template mutable_data<T>();
245  auto* dB_data = dB->template mutable_data<T>();
246  return functor_.Backward(
247  A_dims,
248  B_dims,
249  dC_data,
250  A_data,
251  B_data,
252  C_data,
253  dA_data,
254  dB_data,
255  &context_);
256  }
257 
258  private:
259  const bool legacy_broadcast_;
260  int axis_;
261  const std::string axis_str_;
262  const std::string order_;
263 
265 };
266 
267 REGISTER_CPU_OPERATOR(
268  DivGradient,
270  NumericTypes,
271  CPUContext,
273 
274 namespace {
275 
276 class GetDivGradient final : public GradientMakerBase {
277  using GradientMakerBase::GradientMakerBase;
278 
279  std::vector<OperatorDef> GetGradientDefs() override {
280  return SingleGradientDef(
281  "DivGradient",
282  "",
283  std::vector<std::string>{GO(0), I(0), I(1), O(0)},
284  std::vector<std::string>{GI(0), GI(1)});
285  }
286 };
287 
288 } // namespace
289 
290 REGISTER_GRADIENT(Div, GetDivGradient);
291 
292 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58