Caffe2 - C++ API
A deep learning, cross platform ML framework
sqrt_op.cc
1 #include "caffe2/operators/sqrt_op.h"
2 
3 #include <string>
4 #include <vector>
5 
6 namespace caffe2 {
7 
8 REGISTER_CPU_OPERATOR(
9  Sqrt,
10  UnaryElementwiseOp<
11  TensorTypes<float>,
12  CPUContext,
13  SqrtFunctor<CPUContext>>);
14 
15 // Input: X, output: Y
16 OPERATOR_SCHEMA(Sqrt)
17  .NumInputs(1)
18  .NumOutputs(1)
19  .AllowInplace({{0, 0}})
20  .IdenticalTypeAndShape()
21  .SetDoc(R"DOC(
22 Performs element-wise square-root ($\sqrt{x}$) of input tensor $X$.
23 
24 Github Link:
25 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/sqrt_op.cc
26 
27 <details>
28 
29 <summary> <b>Example</b> </summary>
30 
31 **Code**
32 
33 ```
34 
35 workspace.ResetWorkspace()
36 
37 op = core.CreateOperator(
38  "Sqrt",
39  ["X"],
40  ["Y"],
41 )
42 
43 workspace.FeedBlob("X", (np.random.randint(10, size=(3,3))).astype(np.float32))
44 print("X:", workspace.FetchBlob("X"))
45 workspace.RunOperatorOnce(op)
46 print("Y:", workspace.FetchBlob("Y"))
47 
48 ```
49 
50 **Result**
51 
52 ```
53 
54 X:
55 [[8. 3. 3.]
56  [4. 0. 0.]
57  [1. 2. 5.]]
58 Y:
59 [[2.8284268 1.7320508 1.7320508 ]
60  [1.9999999 0. 0. ]
61  [0.99999994 1.4142134 2.236068 ]]
62 
63 ```
64 
65 </details>
66 )DOC")
67 .Input(0, "X", "*(type: Tensor`<float>`)* Input data tensor.")
68 .Output(0, "Y", "*(type: Tensor`<float>`)* Output tensor.");
69 
70 namespace {
71 
73  using GradientMakerBase::GradientMakerBase;
74  std::vector<OperatorDef> GetGradientDefs() override {
75  Argument scale_arg;
76  scale_arg.set_name("scale");
77  scale_arg.set_f(0.5);
78  return std::vector<OperatorDef>{CreateOperatorDef(
79  "Scale",
80  "",
81  std::vector<std::string>{GO(0)},
82  std::vector<std::string>{GI(0)},
83  std::vector<Argument>{scale_arg}),
84  CreateOperatorDef(
85  "Div",
86  "",
87  std::vector<std::string>{GI(0), O(0)},
88  std::vector<std::string>{GI(0)})};
89  }
90 };
91 
92 } // namespace
93 
94 REGISTER_GRADIENT(Sqrt, GetSqrtGradient);
95 
96 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13