Caffe2 - C++ API
A deep learning, cross-platform ML framework
slice_op.cc
1 
17 #include "caffe2/operators/slice_op.h"
18 #include "caffe2/utils/math.h"
19 
20 namespace caffe2 {
21 
// CPU kernels for the forward Slice op and its gradient. Both are
// instantiated with `int` as the first template argument, which presumably is
// the index type used for `starts`/`ends` (see slice_op.h) -- TODO confirm.
REGISTER_CPU_OPERATOR(Slice, SliceOp<int, CPUContext>);
REGISTER_CPU_OPERATOR(SliceGradient, SliceGradientOp<int, CPUContext>);
24 
25 OPERATOR_SCHEMA(Slice)
26  .NumInputs(1, 3)
27  .NumOutputs(1)
28  .SetDoc(R"DOC(
29 Produces a slice of the input tensor. Currently, only slicing in a single
30 dimension is supported.
31 Slices are passed as 2 1D vectors or as two keyword argument lists with starting
32 and end indices for each dimension of the input `data` tensor. If a negative
33 value is passed for any of the start or end indices, it represents the number of
34 elements before the end of that dimension. End indices are non-inclusive unless
35 negative (end index -1 means up to and including the last element).
36 
37 
38 Example:
39 
40  data = [
41  [1, 2, 3, 4],
42  [5, 6, 7, 8],
43  ]
44  starts = [0, 1]
45  ends = [-1, 3]
46 
47  result = [
48  [2, 3],
49  [6, 7],
50  ]
51 )DOC")
52  .Input(0, "data", "Tensor of data to extract slices from.")
53  .Input(1, "starts", "1D tensor: start-indices for each dimension of data.")
54  .Input(2, "ends", "1D tensor: end-indices for each dimension of data.")
55  .Arg("starts", "List of starting indices")
56  .Arg("ends", "List of ending indices")
57  .TensorInferenceFunction([](const OperatorDef& def,
58  const vector<TensorShape>& in) {
59  if (in.size() > 1) {
60  // Cannot compute shape inference when the splits are defined
61  // in data.
62  return vector<TensorShape>();
63  }
64  auto const& data = in[0];
65 
66  ArgumentHelper helper(def);
67  auto starts = helper.GetRepeatedArgument<int>("starts", vector<int>());
68  auto ends = helper.GetRepeatedArgument<int>("ends", vector<int>());
69  vector<int> dst_sizes(data.dims_size());
70 
71  for (int i = 0; i < data.dims_size(); ++i) {
72  if (i >= starts.size()) {
73  continue;
74  }
75  if (data.dims_size() > 0) {
76  auto start = starts[i];
77  auto end = ends[i];
78  if (start < 0) {
79  start = data.dims(i) + 1 + start;
80  }
81  if (end < 0) {
82  end = data.dims(i) + 1 + end;
83  }
84  dst_sizes[i] = end - start;
85  } else {
86  dst_sizes[i] = 0;
87  }
88  }
89  return vector<TensorShape>{
90  CreateTensorShape(dst_sizes, data.data_type())};
91  })
92  .Output(0, "output", "Sliced data tensor.");
93 
94 OPERATOR_SCHEMA(SliceGradient);
95 
96 namespace {
97 struct GetSliceGradient : public GradientMakerBase {
98  using GradientMakerBase::GradientMakerBase;
99  vector<OperatorDef> GetGradientDefs() override {
100  if (def_.input_size() > 1) {
101  return vector<OperatorDef>{CreateOperatorDef(
102  "SliceGradient",
103  "",
104  std::vector<string>{I(0), I(1), I(2), GO(0)},
105  std::vector<string>{GI(0)})};
106  } else {
107  return vector<OperatorDef>{CreateOperatorDef(
108  "SliceGradient",
109  "",
110  std::vector<string>{I(0), GO(0)},
111  std::vector<string>{GI(0)})};
112  }
113  }
114 };
115 }
116 REGISTER_GRADIENT(Slice, GetSliceGradient);
117 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.