Caffe2 - C++ API
A deep learning, cross-platform ML framework
cbrt_op.cc
#include "caffe2/operators/cbrt_op.h"
#include "caffe2/utils/eigen_utils.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>
7 
8 namespace caffe2 {
9 
10 template <>
11 template <typename T>
12 bool CbrtGradientFunctor<CPUContext>::Forward(
13  const std::vector<int>& dY_dims,
14  const std::vector<int>& /* Y_dims */,
15  const T* dY,
16  const T* Y,
17  T* dX,
18  CPUContext* /* context */) const {
19  const int size = std::accumulate(
20  dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
21  EigenVectorMap<T>(dX, size) = ConstEigenVectorArrayMap<T>(dY, size) /
22  ConstEigenVectorArrayMap<T>(Y, size).square() / T(3);
23  return true;
24 }
25 
26 REGISTER_CPU_OPERATOR(
27  Cbrt,
28  UnaryElementwiseOp<
29  TensorTypes<float>,
30  CPUContext,
31  CbrtFunctor<CPUContext>>);
32 REGISTER_CPU_OPERATOR(
33  CbrtGradient,
34  BinaryElementwiseOp<
35  TensorTypes<float>,
36  CPUContext,
37  CbrtGradientFunctor<CPUContext>>);
38 
39 OPERATOR_SCHEMA(Cbrt)
40  .NumInputs(1)
41  .NumOutputs(1)
42  .AllowInplace({{0, 0}})
43  .IdenticalTypeAndShape()
44  .Input(0, "X", "*(type: Tensor`<float>`)* Input tensor.")
45  .Output(
46  0,
47  "Y",
48  "*(type: Tensor`<float>`)* Output tensor calculated as the cbrt of the input tensor, element-wise.");
49 
50 OPERATOR_SCHEMA(CbrtGradient)
51  .NumInputs(2)
52  .NumOutputs(1)
53  .AllowInplace({{0, 0}})
54  .IdenticalTypeAndShape();
55 
56 namespace {
57 
58 class GetCbrtGradient : public GradientMakerBase {
59  using GradientMakerBase::GradientMakerBase;
60  std::vector<OperatorDef> GetGradientDefs() override {
61  return SingleGradientDef(
62  "CbrtGradient",
63  "",
64  std::vector<std::string>{GO(0), O(0)},
65  std::vector<std::string>{GI(0)});
66  }
67 };
68 
69 } // namespace
70 
71 REGISTER_GRADIENT(Cbrt, GetCbrtGradient);
72 
73 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13