Caffe2 - C++ API
A deep learning, cross-platform ML framework
concat_split_op.cc
#include "caffe2/operators/concat_split_op.h"

namespace caffe2 {
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
OPERATOR_SCHEMA(Split)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .Input(0, "input", "The tensor to split")
    .Input(1, "split", "Optional list of output lengths (see also arg 'split')")
    .Arg("axis", "Which axis to split on")
    .Arg("split", "length of each output")
    .Arg("order", "Either NHWC or NCHW, will split on C axis, defaults to NCHW")
    .SetDoc(R"DOC(
Split a tensor into a list of tensors, along the specified
'axis'. The lengths of the split can be specified using argument 'split' or
the optional second input blob to the operator. Otherwise, the tensor is
split into equal-sized parts.
)DOC");
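// Illustrative usage sketch: a Split op is typically declared in a NetDef
// protobuf; the blob names "X", "X_a", and "X_b" below are placeholders.
// Splitting an (N, 6, H, W) NCHW tensor along axis 1 into two equal parts
// yields two (N, 3, H, W) tensors:
//
//   OperatorDef op;
//   op.set_type("Split");
//   op.add_input("X");        // shape (N, 6, H, W)
//   op.add_output("X_a");     // shape (N, 3, H, W)
//   op.add_output("X_b");     // shape (N, 3, H, W)
//   Argument* axis_arg = op.add_arg();
//   axis_arg->set_name("axis");
//   axis_arg->set_i(1);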

namespace {
OpSchema::Cost CostInferenceForConcat(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  ArgumentHelper helper(def);
  const int axis = helper.HasArgument("axis")
      ? helper.GetSingleArgument<int>("axis", -1)
      : GetDimFromOrderString(
            helper.GetSingleArgument<string>("order", "NCHW"));
  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
  // Check that there is at least one input before indexing into it.
  CAFFE_ENFORCE_GT(in.size(), 0);
  const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
  if (add_axis) {
    out_shape.insert(out_shape.begin() + canonical_axis, in.size());
  } else {
    for (int i = 1; i < in.size(); ++i) {
      out_shape[canonical_axis] += in[i].dims(canonical_axis);
    }
  }
  int size = 1;
  for (auto& s : out_shape) {
    size *= s;
  }

  struct OpSchema::Cost cost;
  cost.flops = size;
  cost.bytes_moved = size * sizeof(float);
  cost.params_bytes = 0;
  return cost;
}
} // namespace
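// Worked example for the cost estimate above: concatenating two float tensors
// of shape (2, 3, 4, 4) along axis 1 gives an output of shape (2, 6, 4, 4),
// i.e. size = 2 * 6 * 4 * 4 = 192 elements, so CostInferenceForConcat reports
// flops = 192 and bytes_moved = 192 * sizeof(float) = 768.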

OPERATOR_SCHEMA(Concat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .Arg("axis", "Which axis to concat on")
    .Arg("order", "Either NHWC or NCHW, will concat on C axis, defaults to NCHW")
    .Arg(
        "add_axis",
        "Pass 1 to add the axis specified in arg 'axis' to all "
        "input tensors")
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      ArgumentHelper helper(def);
      const int axis = helper.HasArgument("axis")
          ? helper.GetSingleArgument<int>("axis", -1)
          : GetDimFromOrderString(
                helper.GetSingleArgument<string>("order", "NCHW"));
      bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
      // Check that there is at least one input before indexing into it.
      CAFFE_ENFORCE_GT(in.size(), 0);
      const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
      vector<int> split_shape(1, in.size());
      vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
      if (add_axis) {
        out_shape.insert(out_shape.begin() + canonical_axis, in.size());
      } else {
        for (int i = 1; i < in.size(); ++i) {
          out_shape[canonical_axis] += in[i].dims(canonical_axis);
        }
      }
      if (def.output_size() == 1) {
        return vector<TensorShape>{
            CreateTensorShape(out_shape, in[0].data_type())};
      }
      return vector<TensorShape>{
          CreateTensorShape(out_shape, in[0].data_type()),
          CreateTensorShape(split_shape, TensorProto::INT32)};
    })
    .CostInferenceFunction(CostInferenceForConcat)
    .SetDoc("Concatenate a list of tensors into a single tensor")
    .Output(0, "concat_result", "Concatenated tensor")
    .Output(1, "split_info", "The dimensions of the inputs.");
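// Illustrative shapes for the schema above (blob names are placeholders):
// concatenating NCHW inputs of shape (N, 3, H, W) and (N, 5, H, W) along the
// channel axis produces "concat_result" of shape (N, 8, H, W) and
// "split_info" holding the INT32 vector [3, 5]. With add_axis = 1 and
// axis = 0, three inputs of shape (H, W) are stacked into a (3, H, W) output.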

// Backward compatibility names.
REGISTER_CPU_OPERATOR(DepthSplit, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(DepthConcat, ConcatOp<CPUContext>);
OPERATOR_SCHEMA(DepthSplit)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .SetDoc("Backward compatible operator name for Split.");
OPERATOR_SCHEMA(DepthConcat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .SetDoc("Backward compatible operator name for Concat.");

class GetSplitGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> output_grads;
    for (int i = 0; i < def_.output_size(); ++i) {
      if (!GradOut(i).IsEmpty()) {
        output_grads.push_back(GO(i));
      }
    }
    if (output_grads.empty()) {
      return {};
    }
    return SingleGradientDef(
        "Concat", "", output_grads,
        vector<string>{GI(0), "_" + GI(0) + "_dims"});
  }
};
REGISTER_GRADIENT(Split, GetSplitGradient);
REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
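// Illustrative note: the gradient of Split is a Concat of the available
// output gradients back into the input gradient. For example, if the forward
// op is Split("X") -> ("Y0", "Y1"), the generated gradient op is roughly
// Concat("Y0_grad", "Y1_grad") -> ("X_grad", "_X_grad_dims"); the blob names
// assume the usual "<name>_grad" convention and are placeholders only.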

class GetConcatGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    if (GradOut(0).IsEmpty()) {
      return {};
    }
    vector<string> grads;
    for (int i = 0; i < def_.input_size(); ++i) {
      grads.push_back(GI(i));
    }
    return SingleGradientDef(
        "Split", "", vector<string>{GO(0), O(1)}, grads);
  }
};
REGISTER_GRADIENT(Concat, GetConcatGradient);
REGISTER_GRADIENT(DepthConcat, GetConcatGradient);
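// Illustrative note: the gradient of Concat is a Split of the output
// gradient, using the forward op's second output (split_info) as the
// optional "split" input so each input gradient recovers its original
// length. For example, Concat("A", "B") -> ("C", "split") yields a gradient
// op roughly of the form Split("C_grad", "split") -> ("A_grad", "B_grad");
// blob names are placeholders only.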
} // namespace caffe2