Caffe2 - C++ API
A deep learning, cross-platform ML framework
reshape_op.h
1 #ifndef CAFFE2_OPERATORS_RESHAPE_OP_H_
2 #define CAFFE2_OPERATORS_RESHAPE_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 // Takes a shape and data tensor and reshapes it
13 template <typename F, class Context>
14 class ReshapeOp : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  template <class... Args>
18  explicit ReshapeOp(Args&&... args)
19  : Operator<Context>(std::forward<Args>(args)...),
20  new_shape_(this->template GetRepeatedArgument<int64_t>("shape")) {}
21 
22  bool RunOnDevice() override {
23  if (InputSize() == 2) {
24  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
25  }
26  CAFFE_ENFORCE(
27  OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
28  return this->template DoRunWithType<int64_t>();
29  }
30 
31  template <typename T>
32  bool DoRunWithType() {
33  DoRunWithTypeImpl<T>(Input(0), Output(0));
34  return true;
35  }
36 
37  protected:
38  template <typename T>
39  void DoRunWithTypeImpl(const Tensor& input, Tensor* output) {
40  vector<int64_t> actual_new_shape = new_shape_;
41  if (InputSize() == 2) {
42  CAFFE_ENFORCE(
43  !OperatorBase::HasArgument("shape"),
44  "New shape is specified by the input blob, do not pass in "
45  "the argument `shape`.");
46 
47  auto& shape = Input(1);
48  CAFFE_ENFORCE(shape.dim() == 1, "Shape should be 1-D");
49 
50  const T* shape_data = shape.template data<T>();
51 
52  // Bit awkward, but needed so works on both CPU and CUDA contexts
53  std::vector<T> tmpv(shape.numel());
54  if (shape.numel() > 0) {
55  context_.CopyBytesToCPU(
56  shape.numel() * sizeof(T), shape_data, &tmpv[0]);
57  actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.numel());
58  }
59  }
60 
61  // Copy over the dimensions for those that are specified zero.
62  for (size_t i = 0; i < actual_new_shape.size() && i < input.dim(); ++i) {
63  if (actual_new_shape[i] == 0) {
64  actual_new_shape[i] = input.size(i);
65  }
66  }
67 
68  // Checks if the new shape is valid and fills in the missing dimension
69  // specified by -1.
70  // NOTE: At most one dimension can be -1.
71  auto total_size = input.numel();
72  T size = 1;
73  int unknown_idx = -1;
74  for (int i = 0; i < actual_new_shape.size(); ++i) {
75  const auto dim = actual_new_shape[i];
76  if (dim == -1) {
77  CAFFE_ENFORCE(
78  unknown_idx == -1,
79  "Argument `shape` has more than one missing dimension.");
80  unknown_idx = i;
81  } else {
82  size *= dim;
83  }
84  }
85  if (size == 0 && total_size != 0) {
86  CAFFE_THROW("Can not reshape a non-zero size (", total_size, ") tensor to zero size.");
87  }
88 
89  if (unknown_idx != -1) {
90  CAFFE_ENFORCE_NE(
91  size,
92  0,
93  "New shape at dim ",
94  unknown_idx,
95  " can not be inferred since new size is zero.");
96  CAFFE_ENFORCE(
97  total_size % size == 0,
98  "Argument `shape` does not agree with the input data.",
99  " (",
100  total_size,
101  " vs ",
102  size,
103  ")");
104  actual_new_shape[unknown_idx] = total_size / size;
105  } else {
106  CAFFE_ENFORCE_EQ(
107  total_size,
108  size,
109  "Argument `shape` does not agree with the input data.",
110  " (",
111  total_size,
112  " != ",
113  size,
114  ")");
115  }
116 
117  // Write the original shape to the second output.
118 
119  auto* old_shape = Output(1, {input.dim()}, at::dtype<T>());
120  T* old_shape_data = old_shape->template mutable_data<T>();
121  for (int i = 0; i < input.dim(); ++i) {
122  math::Set<T, Context>(1, input.size(i), old_shape_data + i, &context_);
123  }
124 
125  output->Resize(actual_new_shape);
126  if (output != &input) {
127  // If we are not doing in-place computation, a copy is needed.
128  context_.CopyItemsSameDevice(
129  input.dtype(),
130  input.numel(),
131  input.raw_data(),
132  output->raw_mutable_data(input.dtype()));
133  }
134  }
135 
136  private:
137  vector<int64_t> new_shape_;
138 };
139 
140 } // namespace caffe2
141 
142 #endif // CAFFE2_OPERATORS_RESHAPE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70