1 #include "caffe2/operators/elementwise_mul_op.h" 12 template <
typename TGrad,
typename TIn>
13 void ComputeMulGradient(
23 CPUContext* context) {
25 std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
27 std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
29 std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
30 math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
31 math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
32 std::vector<int> index(ndim, 0);
33 for (
int C_index = 0; C_index < C_size; ++C_index) {
35 math::utils::GetIndexFromDims(ndim, A_dims, index.data());
37 math::utils::GetIndexFromDims(ndim, B_dims, index.data());
38 dA[A_index] += dC[C_index] * B[B_index];
39 dB[B_index] += dC[C_index] * A[A_index];
40 math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
47 template <
typename TGrad,
typename TIn,
typename TOut>
48 bool MulFunctor<CPUContext>::Backward(
49 const std::vector<int>& A_dims,
50 const std::vector<int>& B_dims,
57 CPUContext* context)
const {
58 if (A_dims == B_dims) {
59 const int size = std::accumulate(
60 A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
61 math::Mul(size, dC, B, dA, context);
62 math::Mul(size, dC, A, dB, context);
65 const int ndim = std::max(A_dims.size(), B_dims.size());
66 std::vector<int> A_broadcast_dims(ndim);
67 std::vector<int> B_broadcast_dims(ndim);
68 std::vector<int> C_broadcast_dims(ndim);
69 math::utils::ComputeBroadcastBinaryOpDims(
74 A_broadcast_dims.data(),
75 B_broadcast_dims.data(),
76 C_broadcast_dims.data());
77 ComputeMulGradient<TGrad, TIn>(
79 A_broadcast_dims.data(),
80 B_broadcast_dims.data(),
81 C_broadcast_dims.data(),
91 REGISTER_CPU_OPERATOR(
93 BinaryElementwiseGradientOp<
96 MulFunctor<CPUContext>>);
100 class GetMulGradient final :
public GradientMakerBase {
101 using GradientMakerBase::GradientMakerBase;
103 std::vector<OperatorDef> GetGradientDefs()
override {
104 return SingleGradientDef(
107 std::vector<std::string>{GO(0), I(0), I(1)},
108 std::vector<std::string>{GI(0), GI(1)});
114 REGISTER_GRADIENT(Mul, GetMulGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...