// NOTE(review): this chunk is a mangled extraction — the stray integers at the
// start of many lines (e.g. "1", "13", "24") are the ORIGINAL file's line
// numbers fused into the text, and several interior lines (the function's full
// parameter list at orig. lines 15-23, the local variable declarations such as
// A_size/B_index/A_index, the enclosing namespace, and closing braces) were
// dropped. Comments below describe only what the visible code shows; do not
// treat this text as compilable.
//
// ComputeDivGradient: given the forward result C = A / B (with broadcasting),
// accumulates the gradients dA and dB from the upstream gradient dC by walking
// every element of the broadcast output shape C_dims.
//   dA[i] += dC[i] / B[i]            (since dC/dA of A/B is 1/B)
//   dB[i] += -dC[i] * C[i] / B[i]    (since dC/dB of A/B is -A/B^2 = -C/B)
1 #include "caffe2/operators/elementwise_div_op.h" 2 #include "caffe2/utils/eigen_utils.h" 13 template <
typename TGrad,
typename TIn,
typename TOut>
14 void ComputeDivGradient(
// (orig. lines 15-23, the pointer/dims parameters, are missing here)
24 CPUContext* context) {
// Element counts of A, B and C: product of each dims array over ndim axes.
// (The `const int A_size = ...` style declarations were lost to extraction.)
26 std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
28 std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
30 std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
// Zero-initialize both gradient buffers before accumulating into them.
32 math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
34 math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
// Multi-dimensional index over C, advanced once per loop iteration.
35 std::vector<int> index(ndim, 0);
36 for (
int C_index = 0; C_index < C_size; ++C_index) {
// Map the current C multi-index to the (possibly broadcast) flat index in B.
38 math::utils::GetIndexFromDims(ndim, B_dims, index.data());
// dB accumulation: -dC * C / B, i.e. the -A/B^2 term expressed via C = A/B.
39 dB[B_index] += -dC[C_index] * C[C_index] / B[B_index];
// Map the same multi-index to the flat index in A.
42 math::utils::GetIndexFromDims(ndim, A_dims, index.data());
// dA accumulation: dC / B.
43 dA[A_index] += dC[C_index] / B[B_index];
// Advance the multi-index to the next element of C (row-major odometer step).
45 math::utils::IncreaseIndexInDims(ndim, C_dims, index.data());
// DivFunctor<CPUContext>::Backward — CPU gradient of elementwise division.
// NOTE(review): extraction dropped the parameter list (orig. lines 56-61,
// presumably dC/A/B/C pointers plus dA/dB outputs), the `return true;`
// statements, and the control flow between the two ComputeDivGradient calls
// (orig. lines 86 and 107) — the duplicated call strongly suggests an
// if/else there, likely handling the case where dA aliases dC (in-place
// gradient); TODO confirm against the upstream file.
52 template <
typename TGrad,
typename TIn,
typename TOut>
53 bool DivFunctor<CPUContext>::Backward(
54 const std::vector<int>& A_dims,
55 const std::vector<int>& B_dims,
62 CPUContext* context)
const {
// Fast path: identical shapes, no broadcasting — vectorized Eigen expressions.
63 if (A_dims == B_dims) {
64 const int size = std::accumulate(
65 A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
// dB = -dC * C / B, computed elementwise over the flat buffers.
// dB is written BEFORE dA, so the subsequent math::Div may safely write dA
// even if dA aliases dC.
66 EigenVectorMap<TGrad>(dB, size) =
67 -ConstEigenVectorArrayMap<TGrad>(dC, size) *
68 ConstEigenVectorArrayMap<TOut>(C, size) /
69 ConstEigenVectorArrayMap<TIn>(B, size);
// dA = dC / B, elementwise.
70 math::Div(size, dC, B, dA, context);
// Broadcast path: pad both shapes out to a common rank ...
73 const int ndim = std::max(A_dims.size(), B_dims.size());
74 std::vector<int> A_broadcast_dims(ndim);
75 std::vector<int> B_broadcast_dims(ndim);
76 std::vector<int> C_broadcast_dims(ndim);
// ... and compute the aligned per-axis dims for A, B and the output C.
// (The size/data arguments at orig. lines 78-81 are missing here.)
77 math::utils::ComputeBroadcastBinaryOpDims(
82 A_broadcast_dims.data(),
83 B_broadcast_dims.data(),
84 C_broadcast_dims.data());
// First ComputeDivGradient call (orig. line 86); trailing pointer arguments
// (orig. lines 91-99) lost to extraction.
86 ComputeDivGradient<TGrad, TIn, TOut>(
88 A_broadcast_dims.data(),
89 B_broadcast_dims.data(),
90 C_broadcast_dims.data(),
// Second ComputeDivGradient call (orig. line 107) — presumably the other arm
// of a dropped if/else; verify before editing.
107 ComputeDivGradient<TGrad, TIn, TOut>(
109 A_broadcast_dims.data(),
110 B_broadcast_dims.data(),
111 C_broadcast_dims.data(),
// Operator class for the Div gradient (the class header at orig. line ~128,
// including the class name, was dropped by extraction — the registration
// block below suggests it is DivGradientOp). Derives from Operator<CPUContext>
// and is marked final.
129 final :
public Operator<CPUContext> {
// Perfect-forwarding constructor (orig. line ~134, `explicit ...Op(Args&&...)`
// and the base-class forward, are missing). Reads four operator arguments:
133 template <
class... Args>
// "broadcast": enables legacy (axis-based) broadcasting; default false.
136 OP_SINGLE_ARG(
bool,
"broadcast", legacy_broadcast_,
false),
// "axis": explicit broadcast axis; -1 means unset.
137 OP_SINGLE_ARG(
int,
"axis", axis_, -1),
// "axis_str": single-character semantic axis name (looked up in `order`).
138 OP_SINGLE_ARG(
string,
"axis_str", axis_str_,
""),
// "order": layout string the semantic axis is resolved against.
139 OP_SINGLE_ARG(
string,
"order", order_,
"NCHW"),
// Constructor body: argument validation, only relevant in legacy-broadcast
// mode. (The `if (axis_ != -1)` arm and the CAFFE_ENFORCE_EQ call heads at
// orig. lines 143-146 are missing here.)
141 if (legacy_broadcast_) {
// axis and axis_str are mutually exclusive.
147 "Args axis and axis_str cannot be used simultaneously.");
148 }
else if (axis_str_.size()) {
// axis_str must be exactly one character ...
151 axis_str_.size(), 1,
"Unsupported axis string", axis_str_);
// ... and must occur in the order string; its position becomes the axis.
152 const size_t semantic_axis_ = order_.find(axis_str_);
156 "Unrecognizable axis string ",
158 " from order string ",
160 axis_ = semantic_axis_;
// Neither axis nor axis_str may be given when broadcast is not enabled.
163 axis_ == -1 && axis_str_.empty(),
164 "Do not specify axis or axis_str if broadcast is not enabled.");
// Type dispatch entry point (body at orig. lines 170-171 missing; presumably
// a DispatchHelper over the supported numeric types — TODO confirm).
169 bool RunOnDevice()
override {
173 template <
typename T>
// DoRunWithType: gathers input tensors, derives the broadcast shapes A_dims /
// B_dims, allocates the two gradient outputs, and delegates to
// functor_.Backward.
174 bool DoRunWithType() {
175 const T* dC_data =
nullptr;
176 const T* A_data =
nullptr;
177 const T* B_data =
nullptr;
178 const T* C_data =
nullptr;
179 std::vector<int> A_dims;
180 std::vector<int> B_dims;
// Three-input form: inputs are (B, C, dC); the forward input A is not needed
// because the Div gradient can be expressed through C = A/B.
183 if (InputSize() == 3) {
184 const auto& B =
Input(0);
185 const auto& C =
Input(1);
186 const auto& dC =
Input(2);
187 if (legacy_broadcast_) {
// Scalar B: treat A as a flat vector of C's total element count.
188 if (B.numel() == 1) {
189 A_dims = {
static_cast<int>(C.numel())};
// Otherwise decompose C's shape around the broadcast axis into
// (pre, n, post); B is treated as shape (n, 1).
193 std::tie(pre, n, post) =
194 elementwise_ops_utils::ComputeLegacyBroadcastSizes(C, B, axis_);
195 A_dims = {
static_cast<int>(pre),
197 static_cast<int>(post)};
198 B_dims = {
static_cast<int>(n), 1};
// Non-legacy path: use the tensors' own shapes (numpy-style broadcasting).
202 C.sizes().cbegin(), C.sizes().cend(), std::back_inserter(A_dims));
204 B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
206 B_data = B.template data<T>();
207 C_data = C.template data<T>();
208 dC_data = dC.template data<T>();
// dA takes C's shape (== A's shape in the forward pass), dB takes B's shape.
209 dA_sizes = C.sizes();
210 dB_sizes = B.sizes();
// Four-input form: inputs are (dC, A, B, C). (The `Input(1)` line for A at
// orig. line 213 was dropped by extraction.)
212 const auto& dC =
Input(0);
214 const auto& B =
Input(2);
215 const auto& C =
Input(3);
216 if (legacy_broadcast_) {
217 if (B.numel() == 1) {
218 A_dims = {
static_cast<int>(
A.numel())};
222 std::tie(pre, n, post) =
223 elementwise_ops_utils::ComputeLegacyBroadcastSizes(
A, B, axis_);
224 A_dims = {
static_cast<int>(pre),
226 static_cast<int>(post)};
227 B_dims = {
static_cast<int>(n), 1};
231 A.sizes().cbegin(),
A.sizes().cend(), std::back_inserter(A_dims));
233 B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
235 dC_data = dC.template data<T>();
236 A_data =
A.template data<T>();
237 B_data = B.template data<T>();
238 C_data = C.template data<T>();
239 dA_sizes =
A.sizes();
240 dB_sizes = B.sizes();
// Allocate both gradient outputs and run the functor's Backward (its
// argument list at orig. lines 247-255 is missing here).
242 auto* dA = Output(0, dA_sizes, at::dtype<T>());
243 auto* dB = Output(1, dB_sizes, at::dtype<T>());
244 auto* dA_data = dA->template mutable_data<T>();
245 auto* dB_data = dB->template mutable_data<T>();
246 return functor_.Backward(
// Members (the functor_ declaration and `int axis_;` at orig. lines ~257-260
// are missing). axis_/legacy_broadcast_/axis_str_/order_ mirror the ctor args.
259 const bool legacy_broadcast_;
261 const std::string axis_str_;
262 const std::string order_;
// Operator registration (macro arguments at orig. lines 268-275 missing;
// presumably registers Div and DivGradient for CPU — TODO confirm).
267 REGISTER_CPU_OPERATOR(
// Gradient maker for Div (its class header at orig. line ~276, likely
// `class GetDivGradient final : public GradientMakerBase`, was dropped).
277 using GradientMakerBase::GradientMakerBase;
279 std::vector<OperatorDef> GetGradientDefs()
override {
// Emits a single DivGradient op wired as:
//   inputs  = {dC (= GO(0)), A (= I(0)), B (= I(1)), C (= O(0))}
//   outputs = {dA (= GI(0)), dB (= GI(1))}
280 return SingleGradientDef(
283 std::vector<std::string>{GO(0), I(0), I(1), O(0)},
284 std::vector<std::string>{GI(0), GI(1)});
// Hook the maker into the gradient registry for the Div operator.
290 REGISTER_GRADIENT(Div, GetDivGradient);
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.