Caffe2 - C++ API
A deep learning, cross platform ML framework
concat_split_op.h
1 
17 #ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
18 #define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/core/types.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 namespace {
28 inline int GetDimFromOrderString(const string& str) {
29  auto order = StringToStorageOrder(str);
30  switch (order) {
31  case StorageOrder::NHWC:
32  return 3;
33  case StorageOrder::NCHW:
34  return 1;
35  default:
36  CAFFE_THROW("Unsupported storage order: ", str);
37  return -1;
38  }
39 }
40 } // namespace
41 
42 template <class Context>
43 class SplitOp final : public Operator<Context> {
44  public:
45  USE_OPERATOR_CONTEXT_FUNCTIONS;
46  SplitOp(const OperatorDef& operator_def, Workspace* ws)
47  : Operator<Context>(operator_def, ws),
48  split_(OperatorBase::GetRepeatedArgument<int>("split")) {
49  CAFFE_ENFORCE(
51  "You shouldn't specify both the dim to split, and the order "
52  "in the case of 4-D images.");
53  if (OperatorBase::HasArgument("axis")) {
54  axis_ = OperatorBase::GetSingleArgument<int>("axis", -1);
55  // only exists for computing the gradient of a Concat with 'add_axis'
56  add_axis_ = OperatorBase::GetSingleArgument<int>("add_axis", 0);
57  } else {
58  axis_ = GetDimFromOrderString(
59  OperatorBase::GetSingleArgument<string>("order", "NCHW"));
60  add_axis_ = 0;
61  }
62  }
63 
64  bool RunOnDevice() override;
65 
66  protected:
67  int axis_;
68  int add_axis_;
69  vector<int> split_;
70  // Input: X, optionally split
71  // The split tensor is stored in CPU.
72 };
73 
74 template <class Context>
75 class ConcatOp final : public Operator<Context> {
76  public:
77  USE_OPERATOR_CONTEXT_FUNCTIONS;
78  ConcatOp(const OperatorDef& operator_def, Workspace* ws)
79  : Operator<Context>(operator_def, ws) {
80  CAFFE_ENFORCE(
82  "You shouldn't specify both the dim to concat, and the order "
83  "in the case of 4-D images.");
84  if (OperatorBase::HasArgument("axis")) {
85  axis_ = OperatorBase::GetSingleArgument<int>("axis", -1);
86  add_axis_ = OperatorBase::GetSingleArgument<int>("add_axis", 0);
87  } else {
88  axis_ = GetDimFromOrderString(
89  OperatorBase::GetSingleArgument<string>("order", "NCHW"));
90  add_axis_ = 0;
91  }
92  }
93 
94  bool RunOnDevice() override;
95 
96  protected:
97  int axis_;
98  int add_axis_;
99  // Input: a number of tensors. Output: Y, split
100  // The split are stored in CPU.
101 };
102 
103 // Implementations
104 template <class Context>
106  auto& input = Input(0);
107  int canonical_axis = input.canonical_axis_index(axis_);
108  CAFFE_ENFORCE_LT(
109  canonical_axis, input.ndim(), "Axis not in input ndim range.");
110  const int input_channels = input.dim32(canonical_axis);
111  const int* axis_data;
112  vector<int> equal_split;
113  if (InputSize() == 2) {
114  // We obtain split from the input tensor.
115  CAFFE_ENFORCE_EQ(
116  split_.size(),
117  0,
118  "If you set split with an input blob, do not pass in "
119  "split in the argument.");
120  auto& split_tensor = OperatorBase::Input<TensorCPU>(1);
121  CAFFE_ENFORCE_EQ(split_tensor.size(), OutputSize());
122  axis_data = split_tensor.template data<int>();
123  } else if (split_.size() == 0) {
124  CAFFE_ENFORCE_EQ(
125  input_channels % OutputSize(),
126  0,
127  "If you did not specify split explicitly, the number of "
128  "input channels should be divisible by the output size.");
129  equal_split.resize(OutputSize(), input_channels / OutputSize());
130  axis_data = equal_split.data();
131  } else {
132  // We obtain split from the parameters.
133  CAFFE_ENFORCE_EQ(
134  split_.size(),
135  OutputSize(),
136  "The number of splits specified should be equal to the "
137  "number of outputs.");
138  axis_data = split_.data();
139  }
140 
141  CAFFE_ENFORCE_EQ(
142  add_axis_ ? OutputSize()
143  : std::accumulate(axis_data, axis_data + OutputSize(), 0),
144  input_channels,
145  "Sum of split dimensions do not match: should be ",
146  input_channels);
147  vector<TIndex> output_dims(input.dims());
148  int before = 1, after = 1;
149  for (int i = 0; i < canonical_axis; ++i) {
150  before *= input.dim32(i);
151  }
152  for (int i = canonical_axis + 1; i < input.ndim(); ++i) {
153  after *= input.dim32(i);
154  }
155  if (add_axis_) {
156  output_dims.erase(output_dims.begin() + canonical_axis);
157  }
158  size_t input_offset = 0;
159  for (int i = 0; i < OutputSize(); ++i) {
160  auto* output = Output(i);
161  auto axis_dim = add_axis_ ? 1 : axis_data[i];
162  if (!add_axis_) {
163  output_dims[canonical_axis] = axis_data[i];
164  }
165  output->Resize(output_dims);
166  math::CopyMatrix<Context>(
167  input.itemsize(),
168  before,
169  axis_dim * after,
170  static_cast<const char*>(input.raw_data()) + input_offset,
171  input.dim32(canonical_axis) * after,
172  output->raw_mutable_data(input.meta()),
173  axis_dim * after,
174  &context_,
175  input.meta().copy());
176  input_offset += axis_dim * after * input.itemsize();
177  }
178  return true;
179 }
180 
181 template <class Context>
183  auto* output = Output(0);
184  TensorCPU* split = OperatorBase::Output<TensorCPU>(1);
185  split->Resize(vector<TIndex>(1, InputSize()));
186  int* axis_data = split->template mutable_data<int>();
187  auto& input_zero = Input(0);
188  int adj_size = input_zero.ndim() + (add_axis_ ? 1 : 0);
189  int canonical_axis = canonical_axis_index_(axis_, adj_size);
190  CAFFE_ENFORCE_LT(
191  canonical_axis,
192  adj_size,
193  "Axis not in input ndim range.");
194  for (int i = 1; i < InputSize(); ++i) {
195  CAFFE_ENFORCE(
196  Input(i).meta() == input_zero.meta(),
197  "All inputs must have the same type, expected: ",
198  input_zero.meta().name(),
199  " but got: ",
200  Input(i).meta().name(),
201  " for input: ",
202  i);
203  }
204 
205  int before = 1, after = 1;
206  vector<TIndex> output_dims(input_zero.dims());
207  for (int i = 0; i < input_zero.ndim(); ++i) {
208  if (i == canonical_axis && !add_axis_) {
209  continue;
210  }
211  int dim = input_zero.dim32(i);
212  if (i < canonical_axis) {
213  before *= dim;
214  } else { // i > canonical_axis || i == canonical_axis && add_axis_
215  after *= dim;
216  }
217  // check the input dims are compatible.
218  for (int j = 1; j < InputSize(); ++j) {
219  int dim_j = Input(j).dim32(i);
220  CAFFE_ENFORCE(
221  dim == dim_j,
222  "Expect dimension = ",
223  dim,
224  " got ",
225  dim_j,
226  " at axis = ",
227  i,
228  " for input: ",
229  j,
230  ". The input tensors can only have different dimensions "
231  "when arg 'add_axis' = 0 and along the axis = ",
232  canonical_axis,
233  " <",
234  Input(0).dims(),
235  "> vs <",
236  Input(j).dims(),
237  ">.");
238  }
239  }
240 
241  int output_channels = 0;
242  for (int i = 0; i < InputSize(); ++i) {
243  axis_data[i] = add_axis_ ? 1 : Input(i).dim32(canonical_axis);
244  output_channels += axis_data[i];
245  }
246  if (add_axis_) {
247  output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
248  } else {
249  output_dims[canonical_axis] = output_channels;
250  }
251  output->Resize(output_dims);
252  size_t output_offset = 0;
253  for (int i = 0; i < InputSize(); ++i) {
254  auto& input = Input(i);
255  auto axis_dim = add_axis_ ? 1 : input.dim32(canonical_axis);
256  math::CopyMatrix<Context>(
257  input.itemsize(),
258  before,
259  axis_dim * after,
260  input.raw_data(),
261  axis_dim * after,
262  static_cast<char*>(output->raw_mutable_data(input_zero.meta())) +
263  output_offset,
264  output_channels * after,
265  &context_,
266  input_zero.meta().copy());
267  output_offset += axis_dim * after * input.itemsize();
268  }
269  return true;
270 }
271 
272 } // namespace caffe2
273 
274 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Definition: tensor.h:673
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:304
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52