Caffe2 - C++ API
A deep learning, cross-platform ML framework
reshape_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 // Takes a shape and data tensor and reshapes it
6 class IDEEPReshapeOp final : public IDEEPOperator {
7  public:
8  USE_IDEEP_DEF_ALIASES();
9  USE_IDEEP_OPERATOR_FUNCTIONS();
10 
11  IDEEPReshapeOp(const OperatorDef& operator_def, Workspace* ws)
12  : IDEEPOperator(operator_def, ws),
13  new_shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {}
14 
15  bool RunOnDevice() override {
16  ideep::tensor::dims actual_new_shape = new_shape_;
17  if (InputSize() == 2) {
18  CAFFE_ENFORCE(
19  !OperatorBase::HasArgument("shape"),
20  "New shape is specified by the input blob, do not pass in "
21  "the argument `shape`.");
22 
23  // shape info live on CPU
24  auto& shape = OperatorBase::Input<TensorCPU>(1, CPU);
25  CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D");
26  actual_new_shape.reserve(shape.size());
27  if (shape.template IsType<int>()) {
28  const int* shape_data = shape.template data<int>();
29  actual_new_shape.assign(shape_data, shape_data + shape.size());
30  } else if (shape.template IsType<int64_t>()) {
31  const int64_t* shape_data = shape.template data<int64_t>();
32  for (int i = 0; i < shape.size(); ++i) {
33  actual_new_shape.push_back(static_cast<int>(shape_data[i]));
34  }
35  } else {
36  CAFFE_THROW(
37  "IDEEP reshape only supports shape data in int32_t or int64_t");
38  }
39  } else {
40  CAFFE_ENFORCE(
41  OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
42  }
43 
44  auto& input = Input(0);
45  // Copy over the dimensions for those that are specified zero.
46  for (int i = 0; i < actual_new_shape.size() && i < input.ndims(); ++i) {
47  if (actual_new_shape[i] == 0) {
48  actual_new_shape[i] = input.get_dim(i);
49  }
50  }
51 
52  // Checks if the new shape is valid and fills in the missing dimension
53  // specified by -1.
54  // NOTE: At most one dimension can be -1.
55  auto total_size = input.get_nelems();
56  int size = 1;
57  int unknown_idx = -1;
58  for (int i = 0; i < actual_new_shape.size(); ++i) {
59  const auto dim = actual_new_shape[i];
60  if (dim == -1) {
61  CAFFE_ENFORCE(
62  unknown_idx == -1,
63  "Argument `shape` has more than one missing dimension.");
64  unknown_idx = i;
65  } else {
66  size *= dim;
67  }
68  }
69  if (size == 0 && total_size != 0) {
70  CAFFE_THROW(
71  "Can not reshape a non-zero size (",
72  total_size,
73  ") tensor to zero size.");
74  }
75 
76  if (unknown_idx != -1) {
77  CAFFE_ENFORCE_NE(
78  size,
79  0,
80  "New shape at dim ",
81  unknown_idx,
82  " can not be inferred since new size is zero.");
83  CAFFE_ENFORCE(
84  total_size % size == 0,
85  "Argument `shape` does not agree with the input data.",
86  " (",
87  total_size,
88  " vs ",
89  size,
90  ")");
91  actual_new_shape[unknown_idx] = total_size / size;
92  } else {
93  CAFFE_ENFORCE_EQ(
94  total_size,
95  size,
96  "Argument `shape` does not agree with the input data.",
97  " (",
98  total_size,
99  " != ",
100  size,
101  ")");
102  }
103 
104  // Write the original shape to the second output.
105  // shape info live on CPU
106  TensorCPU* old_shape = OperatorBase::Output<TensorCPU>(1, CPU);
107  old_shape->Resize(input.ndims());
108  int* old_shape_data = old_shape->template mutable_data<int>();
109  for (int i = 0; i < input.ndims(); ++i) {
110  old_shape_data[i] = input.get_dim(i);
111  }
112 
113  auto* output = Output(0);
114  if (output != &input) {
115  // If we are not doing in-place computation, a copy is needed.
116  output->reinit_like(input);
117  ideep::direct_copy::compute(input, *output);
118  }
119 
120  output->reshape(actual_new_shape);
121  return true;
122  }
123 
124  private:
125  ideep::tensor::dims new_shape_;
126 };
127 
128 REGISTER_IDEEP_OPERATOR(Reshape, IDEEPReshapeOp);
129 
130 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70