Caffe2 - C++ API
A deep learning, cross platform ML framework
acos_op.cc
#include "caffe2/operators/acos_op.h"
#include "caffe2/utils/eigen_utils.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
6 
7 namespace caffe2 {
8 
9 template <>
10 template <typename T>
11 bool AcosGradientFunctor<CPUContext>::Forward(
12  const std::vector<int>& X_dims,
13  const std::vector<int>& /* dY_dims */,
14  const T* X,
15  const T* dY,
16  T* dX,
17  CPUContext* /* context */) const {
18  const int size = std::accumulate(
19  X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
20  ConstEigenVectorArrayMap<T> dY_arr(dY, size);
21  ConstEigenVectorArrayMap<T> X_arr(X, size);
22  EigenVectorMap<T>(dX, size) = -dY_arr * (T(1) - X_arr.square()).rsqrt();
23  return true;
24 }
25 
26 REGISTER_CPU_OPERATOR(
27  Acos,
28  UnaryElementwiseOp<
29  TensorTypes<float>,
30  CPUContext,
31  AcosFunctor<CPUContext>>);
32 REGISTER_CPU_OPERATOR(
33  AcosGradient,
34  BinaryElementwiseOp<
35  TensorTypes<float>,
36  CPUContext,
37  AcosGradientFunctor<CPUContext>>);
38 
39 OPERATOR_SCHEMA(Acos)
40  .NumInputs(1)
41  .NumOutputs(1)
42  .IdenticalTypeAndShape()
43  .SetDoc(R"DOC(
44 Calculates the arccosine of the given input tensor, element-wise.
45 )DOC")
46  .Input(0, "input", "Input tensor")
47  .Output(
48  0,
49  "output",
50  "The arccosine of the input tensor computed element-wise");
51 
52 OPERATOR_SCHEMA(AcosGradient)
53  .NumInputs(2)
54  .NumOutputs(1)
55  .IdenticalTypeAndShape();
56 
57 namespace {
58 
59 class GetAcosGradient : public GradientMakerBase {
60  using GradientMakerBase::GradientMakerBase;
61  std::vector<OperatorDef> GetGradientDefs() override {
62  return SingleGradientDef(
63  "AcosGradient",
64  "",
65  std::vector<std::string>{I(0), GO(0)},
66  std::vector<std::string>{GI(0)});
67  }
68 };
69 
70 } // namespace
71 
72 REGISTER_GRADIENT(Acos, GetAcosGradient);
73 
74 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13