Caffe2 - C++ API
A deep learning, cross-platform ML framework
flatten_op.cc
1 #include "caffe2/operators/flatten_op.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
6 
7 OPERATOR_SCHEMA(Flatten)
8  .NumInputs(1)
9  .NumOutputs(1)
10  .TensorInferenceFunction(TensorInferenceForFlatten)
11  .SetDoc(R"DOC(
12 Flattens the input tensor into a 2D matrix. If input tensor has shape
13 $(d_0, d_1, ..., d_n)$ then the output will have shape
14 $\bigl((d_0 * d_1 * ... * d_{(axis-1)}), (d_{axis} * d_{(axis+1)} * ... * d_n)\bigr)$.
15 
16 Github Links:
17 
18 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/flatten_op.cc
19 
20 <details>
21 
22 <summary> <b>Example</b> </summary>
23 
24 **Code**
25 
26 ```
27 workspace.ResetWorkspace()
28 
29 op = core.CreateOperator(
30  "Flatten",
31  ["X"],
32  ["Y"],
33  axis=1
34 )
35 
36 workspace.FeedBlob("X", np.random.rand(1,3,2,2))
37 print("X:", workspace.FetchBlob("X"))
38 workspace.RunOperatorOnce(op)
39 print("Y:", workspace.FetchBlob("Y"))
40 ```
41 
42 **Result**
43 
44 ```
45 X: [[[[0.53432311 0.23734561]
46  [0.56481598 0.52152617]]
47 
48  [[0.33662627 0.32472711]
49  [0.17939016 0.97175851]]
50 
51  [[0.87226421 0.49045439]
52  [0.92470531 0.30935077]]]]
53 Y: [[0.53432311 0.23734561 0.56481598 0.52152617 0.33662627 0.32472711
54  0.17939016 0.97175851 0.87226421 0.49045439 0.92470531 0.30935077]]
55 ```
56 
57 </details>
58 
59 )DOC")
60  .Input(0, "X", "*(type: Tensor)* Input Tensor of rank >= axis.")
61  .Output(
62  0,
63  "Y",
64  "*(type: Tensor)* A 2D tensor with the contents of the input tensor, "
65  "with input dimensions up to `axis` flattened to the outer dimension "
66  "of the output and the remaining input dimensions flattened into the "
67  "inner dimension of the output.")
68  .Arg(
69  "axis",
70  "*(type: int; default: 1)* Indicates up to which input dimensions "
71  "(exclusive) should be flattened to the outer dimension of the output.")
72  .InheritOnnxSchema();
73 
74 class GetFlattenGradient : public GradientMakerBase {
75  using GradientMakerBase::GradientMakerBase;
76  vector<OperatorDef> GetGradientDefs() override {
77  return SingleGradientDef(
78  "ResizeLike", "", vector<string>{GO(0), I(0)}, vector<string>{GI(0)});
79  }
80 };
81 
82 REGISTER_GRADIENT(Flatten, GetFlattenGradient);
83 
84 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13