1 #ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_ 2 #define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/core/types.h" 7 #include "caffe2/utils/math.h" 12 inline int GetDimFromOrderString(
const string& str) {
13 auto order = StringToStorageOrder(str);
15 case StorageOrder::NHWC:
17 case StorageOrder::NCHW:
20 CAFFE_THROW(
"Unsupported storage order: ", str);
26 template <
class Context>
29 static const int kSplitOpInputSize = 2;
31 USE_OPERATOR_CONTEXT_FUNCTIONS;
32 template <
class... Args>
33 explicit SplitOp(Args&&... args)
35 split_(this->
template GetRepeatedArgument<int>(
"split")) {
39 "You shouldn't specify both the dim to split, and the order " 40 "in the case of 4-D images.");
42 axis_ = this->
template GetSingleArgument<int>(
"axis", -1);
44 add_axis_ = this->
template GetSingleArgument<int>(
"add_axis", 0);
46 axis_ = GetDimFromOrderString(
47 this->
template GetSingleArgument<string>(
"order",
"NCHW"));
52 bool RunOnDevice()
override;
62 template <
class Context>
65 USE_OPERATOR_CONTEXT_FUNCTIONS;
66 template <
class... Args>
72 "You shouldn't specify both the dim to split, and the order " 73 "in the case of 4-D images.");
75 axis_ = this->
template GetSingleArgument<int>(
"axis", 0);
77 axis_ = GetDimFromOrderString(
78 this->
template GetSingleArgument<string>(
"order",
"NCHW"));
82 bool RunOnDevice()
override;
86 Tensor inclusive_scan_buffer_{Context::GetDeviceType()};
87 Tensor inclusive_scan_length_buffer_{Context::GetDeviceType()};
92 template <
class Context>
95 USE_OPERATOR_CONTEXT_FUNCTIONS;
96 template <
class... Args>
102 "You shouldn't specify both the dim to concat, and the order " 103 "in the case of 4-D images.");
105 axis_ = this->
template GetSingleArgument<int>(
"axis", -1);
106 add_axis_ = this->
template GetSingleArgument<int>(
"add_axis", 0);
108 axis_ = GetDimFromOrderString(
109 this->
template GetSingleArgument<string>(
"order",
"NCHW"));
114 bool RunOnDevice()
override;
124 template <
class Context>
126 auto& input =
Input(0);
127 int canonical_axis = input.canonical_axis_index(axis_);
129 canonical_axis, input.dim(),
"Axis not in input ndim range.");
130 const int input_channels = input.dim32(canonical_axis);
131 const int* axis_data;
132 vector<int> equal_split;
133 if (InputSize() == kSplitOpInputSize) {
138 "If you set split with an input blob, do not pass in " 139 "split in the argument.");
140 auto& split_tensor = this->
template Input<Tensor>(1, CPU);
141 CAFFE_ENFORCE_EQ(split_tensor.numel(), OutputSize());
142 axis_data = split_tensor.template data<int>();
143 }
else if (split_.size() == 0) {
145 input_channels % OutputSize(),
147 "If you did not specify split explicitly, the number of " 148 "input channels should be divisible by the output size.");
149 equal_split.resize(OutputSize(), input_channels / OutputSize());
150 axis_data = equal_split.data();
156 "The number of splits specified should be equal to the " 157 "number of outputs.");
158 axis_data = split_.data();
162 add_axis_ ? OutputSize()
163 : std::accumulate(axis_data, axis_data + OutputSize(), 0),
165 "Sum of split dimensions do not match: should be ",
167 vector<int64_t> output_dims(input.sizes().vec());
168 int before = 1, after = 1;
169 for (
int i = 0; i < canonical_axis; ++i) {
170 before *= input.dim32(i);
172 for (
int i = canonical_axis + 1; i < input.dim(); ++i) {
173 after *= input.dim32(i);
176 output_dims.erase(output_dims.begin() + canonical_axis);
178 size_t input_offset = 0;
179 for (
int i = 0; i < OutputSize(); ++i) {
180 auto* output = Output(i);
181 auto axis_dim = add_axis_ ? 1 : axis_data[i];
183 output_dims[canonical_axis] = axis_data[i];
185 output->Resize(output_dims);
186 math::CopyMatrix<Context>(
190 static_cast<const char*
>(input.raw_data()) + input_offset,
191 input.dim32(canonical_axis) * after,
192 output->raw_mutable_data(input.dtype()),
195 input.dtype().copy());
196 input_offset += axis_dim * after * input.itemsize();
202 template <
class Context>
204 auto& input =
Input(0);
205 auto& length = this->
template Input<Tensor>(1, CPU);
206 auto length_length = length.numel();
208 length_length % OutputSize(),
210 "len(Lengths) should be divisible by OutputSize().");
211 int canonical_axis = input.canonical_axis_index(axis_);
213 canonical_axis, input.dim(),
"Axis not in input ndim range.");
214 const int input_channels = input.dim32(canonical_axis);
215 const auto* axis_data = length.template data<int>();
217 std::accumulate(axis_data, axis_data + length.numel(), 0),
219 "Sum of split dimensions do not match: should be ",
221 vector<int64_t> output_dims(input.sizes().vec());
222 int before = input.size_to_dim(canonical_axis);
223 int after = input.size_from_dim(canonical_axis + 1);
224 size_t input_offset = 0;
225 for (
int i = 0; i < OutputSize(); ++i) {
226 auto* output = Output(i);
227 const auto* axis_offset = axis_data + length_length / OutputSize() * i;
228 auto axis_dim = std::accumulate(
229 axis_offset, axis_offset + length_length / OutputSize(), 0);
230 output_dims[canonical_axis] = axis_dim;
231 output->Resize(output_dims);
232 math::CopyMatrix<Context>(
236 static_cast<const char*
>(input.raw_data()) + input_offset,
237 input.dim32(canonical_axis) * after,
238 output->raw_mutable_data(input.dtype()),
241 input.dtype().copy());
242 input_offset += axis_dim * after * input.itemsize();
247 template <
class Context>
249 auto* output = Output(0);
254 1, std::vector<int64_t>(1, InputSize()), at::dtype<int>().device(CPU));
255 int* axis_data = split->template mutable_data<int>();
256 auto& input_zero =
Input(0);
257 int adj_size = input_zero.dim() + (add_axis_ ? 1 : 0);
258 int canonical_axis = canonical_axis_index_(axis_, adj_size);
259 CAFFE_ENFORCE_LT(canonical_axis, adj_size,
"Axis not in input ndim range.");
260 for (
int i = 1; i < InputSize(); ++i) {
262 Input(i).dtype() == input_zero.dtype(),
263 "All inputs must have the same type, expected: ",
264 input_zero.dtype().name(),
266 Input(i).dtype().name(),
271 int before = 1, after = 1;
272 vector<int64_t> output_dims(input_zero.sizes().vec());
273 for (
int i = 0; i < input_zero.dim(); ++i) {
274 if (i == canonical_axis && !add_axis_) {
277 int dim = input_zero.dim32(i);
278 if (i < canonical_axis) {
284 for (
int j = 1; j < InputSize(); ++j) {
285 int dim_j =
Input(j).dim32(i);
288 "Expect dimension = ",
296 ". The input tensors can only have different dimensions " 297 "when arg 'add_axis' = 0 and along the axis = ",
307 int output_channels = 0;
308 for (
int i = 0; i < InputSize(); ++i) {
309 axis_data[i] = add_axis_ ? 1 :
Input(i).dim32(canonical_axis);
310 output_channels += axis_data[i];
313 output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
315 output_dims[canonical_axis] = output_channels;
317 output->Resize(output_dims);
318 size_t output_offset = 0;
319 for (
int i = 0; i < InputSize(); ++i) {
320 auto& input =
Input(i);
321 auto axis_dim = add_axis_ ? 1 : input.dim32(canonical_axis);
322 math::CopyMatrix<Context>(
328 static_cast<char*
>(output->raw_mutable_data(input_zero.dtype())) +
330 output_channels * after,
332 input_zero.dtype().copy());
333 output_offset += axis_dim * after * input.itemsize();
339 const OperatorDef& def,
340 const std::vector<TensorShape>& in);
342 std::vector<TensorShape> TensorInferenceForConcat(
343 const OperatorDef& def,
344 const std::vector<TensorShape>& in);
348 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
// NOTE(review): the lines below are documentation-tooltip residue that was
// accidentally appended after the include guard; preserved here as comments.
// const Tensor& Input(int idx, DeviceType type = Context::GetDeviceType())
//   Retrieve a non-owning reference to the input at position 'idx' for this operator.
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
// bool HasArgument(const string& name) const
//   Checks if the operator has an argument of the given name.