// CPU backward pass for the element-wise reciprocal square root
// (Y = 1 / sqrt(X)).
// Because dY/dX = -0.5 * X^(-3/2) = -0.5 * Y^3, the input gradient can be
// computed from the upstream gradient dY and the forward output Y alone:
//   dX = dY * Y^3 * (-0.5)
// NOTE(review): this extract is garbled. Original line numbers are fused into
// the text, and lines 16-19 (the rest of the parameter list) plus the end of
// the function body are not visible here. Do not infer the full signature
// from this view.
1 #include "caffe2/operators/rsqrt_op.h" 3 #include "caffe2/utils/eigen_utils.h" 13 bool RsqrtGradientFunctor<CPUContext>::Forward(
14 const std::vector<int>& dY_dims,
// The second dims argument is unnamed and unused. The element count below is
// taken from dY_dims only, so dY and Y are presumably the same size
// (TODO confirm).
15 const std::vector<int>& ,
// Element count = product of dY's dimensions (1 for an empty dims list).
// NOTE(review): the product is accumulated in `int`, so it would overflow
// for tensors with more than INT_MAX elements.
20 const int size = std::accumulate(
21 dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
// Map the raw buffers as 1-D Eigen arrays of `size` elements and evaluate
// dX = dY * Y^3 * (-0.5) coefficient-wise in a single expression.
22 EigenVectorMap<T>(dX, size) = ConstEigenVectorMap<T>(dY, size).array() *
23 ConstEigenVectorMap<T>(Y, size).array().cube() *
static_cast<T>(-0.5);
// CPU kernel registrations. The first registration is built around
// RsqrtFunctor (the forward computation). The second is built around
// RsqrtGradientFunctor (the backward computation).
// NOTE(review): original lines 28-31 and 34-37 (presumably the registered op
// names and the element-wise op wrapper templates) are missing from this
// extract.
27 REGISTER_CPU_OPERATOR(
32 RsqrtFunctor<CPUContext>>);
33 REGISTER_CPU_OPERATOR(
38 RsqrtGradientFunctor<CPUContext>>);
// Schema for the forward Rsqrt operator.
// It declares input 0 "X" and output 0 "Y" (Y = 1 / sqrt(X) element-wise).
// The output has the same type and shape as the input, and output 0 may
// reuse input 0's storage (in-place execution).
// NOTE(review): original lines 41-42 (likely the input/output counts) are
// missing from this extract.
40 OPERATOR_SCHEMA(Rsqrt)
43 .AllowInplace({{0, 0}})
44 .IdenticalTypeAndShape()
45 .SetDoc(
"Computes the element-wise rsqrt of the input.")
46 .Input(0,
"X",
"ND input tensor")
47 .Output(0,
"Y",
"ND output tensor");
// Schema for the RsqrtGradient operator. Its first input may share storage
// with its output (in-place execution).
// NOTE(review): original lines 50-51 (likely the input/output counts) are
// missing from this extract.
49 OPERATOR_SCHEMA(RsqrtGradient)
52 .AllowInplace({{0, 0}});
// Gradient maker for Rsqrt. It emits a single gradient operator with these
// inputs and output:
//   - inputs: the output gradient GO(0) and the forward output O(0);
//   - output: the input gradient GI(0).
// The forward input X is not needed, because the gradient can be written
// purely in terms of Y: dX = -0.5 * dY * Y^3.
// NOTE(review): original lines 58, 61-62 (presumably the gradient op type and
// name arguments) and 65-69 (closing braces) are missing from this extract.
56 class GetRsqrtGradient final :
public GradientMakerBase {
57 using GradientMakerBase::GradientMakerBase;
59 std::vector<OperatorDef> GetGradientDefs()
override {
60 return SingleGradientDef(
// Gradient op inputs, in order: dY, then Y.
63 std::vector<std::string>{GO(0), O(0)},
// Gradient op output: dX.
64 std::vector<std::string>{GI(0)});
// Associates the Rsqrt operator with GetRsqrtGradient, which builds its
// backward operator.
70 REGISTER_GRADIENT(Rsqrt, GetRsqrtGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...