Caffe2 - C++ API
A deep-learning, cross-platform ML framework
reshape_op.h
1 
17 #ifndef CAFFE2_OPERATORS_RESHAPE_OP_H_
18 #define CAFFE2_OPERATORS_RESHAPE_OP_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/utils/math.h"
25 
26 namespace caffe2 {
27 
28 // Takes a shape and data tensor and reshapes it
29 template <typename F, class Context>
30 class ReshapeOp : public Operator<Context> {
31  public:
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33  ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
34  : Operator<Context>(operator_def, ws),
35  new_shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
36 
37  bool RunOnDevice() override {
38  if (InputSize() == 2) {
39  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
40  }
41  CAFFE_ENFORCE(
42  OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
43  return this->template DoRunWithType<int64_t>();
44  }
45 
46  template <typename T>
47  bool DoRunWithType() {
48  auto& input = Input(0);
49 
50  vector<int64_t> actual_new_shape = new_shape_;
51  if (InputSize() == 2) {
52  CAFFE_ENFORCE(
53  !OperatorBase::HasArgument("shape"),
54  "New shape is specified by the input blob, do not pass in "
55  "the argument `shape`.");
56 
57  auto& shape = Input(1);
58  CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D");
59 
60  const T* shape_data = shape.template data<T>();
61 
62  // Bit awkward, but needed so works on both CPU and CUDA contexts
63  std::vector<T> tmpv(shape.size());
64  context_.template CopyBytes<Context, CPUContext>(
65  shape.size() * sizeof(T), shape_data, &tmpv[0]);
66  actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.size());
67  }
68 
69  // Copy over the dimensions for those that are specified zero.
70  for (int i = 0; i < actual_new_shape.size(); ++i) {
71  if (actual_new_shape[i] == 0) {
72  actual_new_shape[i] = input.dim(i);
73  }
74  }
75 
76  // Checks if the new shape is valid and fills in the missing dimension
77  // specified by -1.
78  // NOTE: At most one dimension can be -1.
79  auto total_size = input.size_from_dim(0);
80  T size = 1;
81  int unknown_idx = -1;
82  for (int i = 0; i < actual_new_shape.size(); ++i) {
83  const auto dim = actual_new_shape[i];
84  if (dim == -1) {
85  CAFFE_ENFORCE(
86  unknown_idx == -1,
87  "Argument `shape` has more than one missing dimension.");
88  unknown_idx = i;
89  } else {
90  size *= dim;
91  }
92  }
93 
94  if (unknown_idx != -1) {
95  CAFFE_ENFORCE(
96  total_size % size == 0,
97  "Argument `shape` does not agree with the input data.",
98  " (",
99  total_size,
100  " vs ",
101  size,
102  ")");
103  actual_new_shape[unknown_idx] = total_size / size;
104  } else {
105  CAFFE_ENFORCE_EQ(
106  total_size,
107  size,
108  "Argument `shape` does not agree with the input data.",
109  " (",
110  total_size,
111  " != ",
112  size,
113  ")");
114  }
115 
116  // Write the original shape to the second output.
117  auto* old_shape = Output(1);
118  old_shape->Resize(input.ndim());
119  T* old_shape_data = old_shape->template mutable_data<T>();
120  for (int i = 0; i < input.ndim(); ++i) {
121  math::Set<T, Context>(1, input.dim(i), old_shape_data + i, &context_);
122  }
123 
124  auto* output = Output(0);
125  output->Resize(actual_new_shape);
126  if (output != &input) {
127  // If we are not doing in-place computation, a copy is needed.
128  context_.template CopyItems<Context, Context>(
129  input.meta(),
130  input.size(),
131  input.raw_data(),
132  output->raw_mutable_data(input.meta()));
133  }
134 
135  return true;
136  }
137 
138  private:
139  vector<int64_t> new_shape_;
140 };
141 
142 } // namespace caffe2
143 
144 #endif // CAFFE2_OPERATORS_RESHAPE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52