Caffe2 - C++ API
A deep learning, cross-platform ML framework
elementwise_mul_gradient_op.cc
#include "caffe2/operators/elementwise_mul_op.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <vector>
8 namespace caffe2 {
9 
10 namespace {
11 
12 template <typename TGrad, typename TIn>
13 void ComputeMulGradient(
14  const int ndim,
15  const int* A_dims,
16  const int* B_dims,
17  const int* C_dims,
18  const TGrad* dC,
19  const TIn* A,
20  const TIn* B,
21  TGrad* dA,
22  TGrad* dB,
23  CPUContext* context) {
24  const int A_size =
25  std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
26  const int B_size =
27  std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
28  const int C_size =
29  std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
30  math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
31  math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
32  std::vector<int> index(ndim, 0);
33  for (int C_index = 0; C_index < C_size; ++C_index) {
34  const int A_index =
35  math::utils::GetIndexFromDims(ndim, A_dims, index.data());
36  const int B_index =
37  math::utils::GetIndexFromDims(ndim, B_dims, index.data());
38  dA[A_index] += dC[C_index] * B[B_index];
39  dB[B_index] += dC[C_index] * A[A_index];
40  math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
41  }
42 }
43 
44 } // namespace
45 
46 template <>
47 template <typename TGrad, typename TIn, typename TOut>
48 bool MulFunctor<CPUContext>::Backward(
49  const std::vector<int>& A_dims,
50  const std::vector<int>& B_dims,
51  const TGrad* dC,
52  const TIn* A,
53  const TIn* B,
54  const TOut* /* C */,
55  TGrad* dA,
56  TGrad* dB,
57  CPUContext* context) const {
58  if (A_dims == B_dims) {
59  const int size = std::accumulate(
60  A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
61  math::Mul(size, dC, B, dA, context);
62  math::Mul(size, dC, A, dB, context);
63  return true;
64  }
65  const int ndim = std::max(A_dims.size(), B_dims.size());
66  std::vector<int> A_broadcast_dims(ndim);
67  std::vector<int> B_broadcast_dims(ndim);
68  std::vector<int> C_broadcast_dims(ndim);
69  math::utils::ComputeBroadcastBinaryOpDims(
70  A_dims.size(),
71  A_dims.data(),
72  B_dims.size(),
73  B_dims.data(),
74  A_broadcast_dims.data(),
75  B_broadcast_dims.data(),
76  C_broadcast_dims.data());
77  ComputeMulGradient<TGrad, TIn>(
78  ndim,
79  A_broadcast_dims.data(),
80  B_broadcast_dims.data(),
81  C_broadcast_dims.data(),
82  dC,
83  A,
84  B,
85  dA,
86  dB,
87  context);
88  return true;
89 }
90 
91 REGISTER_CPU_OPERATOR(
92  MulGradient,
93  BinaryElementwiseGradientOp<
94  NumericTypes,
95  CPUContext,
96  MulFunctor<CPUContext>>);
97 
98 namespace {
99 
100 class GetMulGradient final : public GradientMakerBase {
101  using GradientMakerBase::GradientMakerBase;
102 
103  std::vector<OperatorDef> GetGradientDefs() override {
104  return SingleGradientDef(
105  "MulGradient",
106  "",
107  std::vector<std::string>{GO(0), I(0), I(1)},
108  std::vector<std::string>{GI(0), GI(1)});
109  }
110 };
111 
112 } // namespace
113 
114 REGISTER_GRADIENT(Mul, GetMulGradient);
115 
116 } // namespace caffe2
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58