Caffe2 - C++ API
A deep learning, cross platform ML framework
reshape_op.cc
1 
17 #include "caffe2/operators/reshape_op.h"
18 #include "caffe2/utils/math.h"
19 
20 namespace caffe2 {
21 
// Register ReshapeOp<float, CPUContext> as the CPU implementation of the
// operator named "Reshape".
22 REGISTER_CPU_OPERATOR(Reshape, ReshapeOp<float, CPUContext>);
23 
24 OPERATOR_SCHEMA(Reshape)
25  .NumInputs(1, 2)
26  .NumOutputs(2)
27  .TensorInferenceFunction(
28  [](const OperatorDef& def, const vector<TensorShape>& in) {
29  vector<TensorShape> out(2);
30 
31  // Do shape inference for old_shape
32  out[1].set_data_type(TensorProto::INT64);
33  out[1].add_dims(in[0].dims_size());
34 
35  ArgumentHelper helper(def);
36  if (!helper.HasArgument("shape")) {
37  // Cannot do shape inference for reshaped tensor from runtime data.
38  CAFFE_ENFORCE_EQ(
39  in.size(),
40  2,
41  "New shape must be specified by either the input blob or the "
42  "argument `shape`.");
43  out[0].set_unknown_shape(true);
44  return out;
45  }
46  CAFFE_ENFORCE_EQ(
47  in.size(),
48  1,
49  "New shape must not be specified by the input blob and the "
50  "argument `shape` at the same time.");
51 
52  // Infer the actual new shape
53  auto actualNewShape = helper.GetRepeatedArgument<int64_t>("shape");
54 
55  // Copy over the dimensions for those that are specified zero
56  // and check the eligibility of input
57  for (int i = 0; i < actualNewShape.size(); ++i) {
58  CAFFE_ENFORCE_GE(
59  actualNewShape[i],
60  -1,
61  "The dimensions in argument `shape` "
62  "must not be a negative number.");
63 
64  if (actualNewShape[i] == 0) {
65  CAFFE_ENFORCE_LT(
66  i,
67  in[0].dims_size(),
68  "Argument `shape` has a dimension set to zero that exceeds "
69  "the original dimension size.");
70  actualNewShape[i] = in[0].dims(i);
71  }
72  }
73 
74  // Check if the new shape is valid and fills in the missing dimension
75  // specified by -1.
76  int64_t totalSize = 1;
77  for (const auto d : in[0].dims()) {
78  totalSize *= d;
79  }
80  int64_t size = 1;
81  int unknownIdx = -1;
82  for (int i = 0; i < actualNewShape.size(); ++i) {
83  const auto dim = actualNewShape[i];
84  if (dim == -1) {
85  CAFFE_ENFORCE(
86  unknownIdx == -1,
87  "Argument `shape` has more than one missing dimension.");
88  unknownIdx = i;
89  } else {
90  size *= dim;
91  }
92  }
93 
94  if (unknownIdx != -1) {
95  CAFFE_ENFORCE(
96  totalSize % size == 0,
97  "Argument `shape` does not agree with the input data.",
98  " (",
99  totalSize,
100  " vs ",
101  size,
102  ")");
103  actualNewShape[unknownIdx] = totalSize / size;
104  } else {
105  CAFFE_ENFORCE_EQ(
106  totalSize,
107  size,
108  "Argument `shape` does not agree with the input data.",
109  " (",
110  totalSize,
111  " != ",
112  size,
113  ")");
114  }
115 
116  out[0].set_data_type(in[0].data_type());
117  for (const auto d : actualNewShape) {
118  out[0].add_dims(d);
119  }
120  return out;
121  })
122  .AllowInplace({{0, 0}})
123  .SetDoc(R"DOC(
124 Reshape the input tensor similar to numpy.reshape.
125 
126 It takes a tensor as input and an optional tensor specifying the new shape.
127 When the second input is absent, an extra argument `shape` must be specified.
128 It outputs the reshaped tensor as well as the original shape.
129 
130 At most one dimension of the new shape can be -1. In this case, the value is
131 inferred from the size of the tensor and the remaining dimensions. A dimension
132 could also be 0, in which case the actual dimension value is going to be copied
133 from the input tensor.
134 )DOC")
135  .Arg("shape", "New shape")
136  .Input(0, "data", "An input tensor.")
137  .Input(1, "new_shape", "New shape.")
138  .Output(0, "reshaped", "Reshaped data.")
139  .Output(1, "old_shape", "Original shape.");
140 
142  using GradientMakerBase::GradientMakerBase;
143  vector<OperatorDef> GetGradientDefs() override {
144  return SingleGradientDef(
145  "Reshape",
146  "",
147  vector<string>{GO(0), O(1)},
148  vector<string>{GI(0), "_" + GI(0) + "_dims"});
149  }
150 
151  // Argument `shape` is no longer needed in backprop.
152  bool CopyArguments() const override {
153  return false;
154  }
155 };
156 
// Register GetReshapeGradient as the gradient maker for the Reshape operator.
157 REGISTER_GRADIENT(Reshape, GetReshapeGradient);
158 
159 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...