Caffe2 - C++ API
A deep learning, cross-platform ML framework
reshape_op.cc
1 #include "caffe2/operators/reshape_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 REGISTER_CPU_OPERATOR(Reshape, ReshapeOp<float, CPUContext>);
7 
8 OPERATOR_SCHEMA(Reshape)
9  .NumInputs(1, 2)
10  .NumOutputs(2)
11  .TensorInferenceFunction(
12  [](const OperatorDef& def, const vector<TensorShape>& in) {
13  vector<TensorShape> out(2);
14 
15  // Do shape inference for old_shape
16  out[1].set_data_type(TensorProto::INT64);
17  out[1].add_dims(in[0].dims_size());
18 
19  ArgumentHelper helper(def);
20  if (!helper.HasArgument("shape")) {
21  // Cannot do shape inference for reshaped tensor from runtime data.
22  CAFFE_ENFORCE_EQ(
23  in.size(),
24  2,
25  "New shape must be specified by either the input blob or the "
26  "argument `shape`.");
27  out[0].set_unknown_shape(true);
28  return out;
29  }
30  CAFFE_ENFORCE_EQ(
31  in.size(),
32  1,
33  "New shape must not be specified by the input blob and the "
34  "argument `shape` at the same time.");
35 
36  // Infer the actual new shape
37  auto actualNewShape = helper.GetRepeatedArgument<int64_t>("shape");
38 
39  // Copy over the dimensions for those that are specified zero
40  // and check the eligibility of input
41  for (int i = 0; i < actualNewShape.size(); ++i) {
42  CAFFE_ENFORCE_GE(
43  actualNewShape[i],
44  -1,
45  "The dimensions in argument `shape` "
46  "must not be a negative number.");
47 
48  if (actualNewShape[i] == 0) {
49  CAFFE_ENFORCE_LT(
50  i,
51  in[0].dims_size(),
52  "Argument `shape` has a dimension set to zero that exceeds "
53  "the original dimension size.");
54  actualNewShape[i] = in[0].dims(i);
55  }
56  }
57 
58  // Check if the new shape is valid and fills in the missing dimension
59  // specified by -1.
60  int64_t totalSize = 1;
61  for (const auto d : in[0].dims()) {
62  totalSize *= d;
63  }
64  int64_t size = 1;
65  int unknownIdx = -1;
66  for (int i = 0; i < actualNewShape.size(); ++i) {
67  const auto dim = actualNewShape[i];
68  if (dim == -1) {
69  CAFFE_ENFORCE(
70  unknownIdx == -1,
71  "Argument `shape` has more than one missing dimension.");
72  unknownIdx = i;
73  } else {
74  size *= dim;
75  }
76  }
77 
78  if (unknownIdx != -1) {
79  CAFFE_ENFORCE(
80  totalSize % size == 0,
81  "Argument `shape` does not agree with the input data.",
82  " (",
83  totalSize,
84  " vs ",
85  size,
86  ")");
87  actualNewShape[unknownIdx] = totalSize / size;
88  } else {
89  CAFFE_ENFORCE_EQ(
90  totalSize,
91  size,
92  "Argument `shape` does not agree with the input data.",
93  " (",
94  totalSize,
95  " != ",
96  size,
97  ")");
98  }
99 
100  out[0].set_data_type(in[0].data_type());
101  for (const auto d : actualNewShape) {
102  out[0].add_dims(d);
103  }
104  return out;
105  })
106  .AllowInplace({{0, 0}})
107  .SetDoc(R"DOC(
108 Reshape the input tensor similar to numpy's
109 [reshape](https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html).
110 
111 Takes a tensor as input and an optional tensor specifying the new shape. When
112 the second input is absent, an extra argument shape must be specified. Outputs
113 the reshaped tensor as well as the original shape.
114 
115 At most one dimension of the new shape can be -1. In this case, the value is
116 inferred from the size of the tensor and the remaining dimensions. A dimension
117 could also be 0, in which case the actual dimension value is going to be copied
118 from the input tensor.
119 
120 Github Links:
121 
122 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/reshape_op.cc
123 
124 <details>
125 
126 <summary> <b>Example</b> </summary>
127 
128 **Code**
129 
130 ```
131 workspace.ResetWorkspace()
132 
133 op = core.CreateOperator(
134  "Reshape",
135  ["data"],
136  ["reshaped", "old_shape"],
137  shape=(3,2)
138 )
139 
140 workspace.FeedBlob("data", (np.random.randint(100, size=(6))))
141 print("data:", workspace.FetchBlob("data"))
142 workspace.RunOperatorOnce(op)
143 print("reshaped:", workspace.FetchBlob("reshaped"))
144 print("old_shape:", workspace.FetchBlob("old_shape"))
145 ```
146 
147 **Result**
148 
149 ```
150 data: [86 60 85 96 7 37]
151 reshaped: [[86 60]
152  [85 96]
153  [ 7 37]]
154 old_shape: [6]
155 ```
156 
157 </details>
158 
159 )DOC")
160  .Arg(
161  "shape",
162  "*(type: Tuple(int))* New shape. Do not set if using "
163  "`new_shape` input.")
164  .Input(0, "data", "*(type: Tensor)* Input tensor.")
165  .Input(
166  1,
167  "new_shape",
168  "*(type: Tensor`<int>`)* [OPTIONAL] Tensor containing new shape.")
169  .Output(0, "reshaped", "*(type: Tensor)* Reshaped output tensor.")
170  .Output(
171  1,
172  "old_shape",
173  "*(type: Tensor`<int>`)* Tensor containing old shape of `data`.")
174  .InheritOnnxSchema();
175 
176 class GetReshapeGradient : public GradientMakerBase {
177  using GradientMakerBase::GradientMakerBase;
178  vector<OperatorDef> GetGradientDefs() override {
179  return SingleGradientDef(
180  "Reshape",
181  "",
182  vector<string>{GO(0), O(1)},
183  vector<string>{GI(0), "_" + GI(0) + "_dims"});
184  }
185 
186  // Argument `shape` is no longer needed in backprop.
187  bool CopyArguments() const override {
188  return false;
189  }
190 };
191 
192 REGISTER_GRADIENT(Reshape, GetReshapeGradient);
193 
194 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13