Caffe2 - C++ API
A deep learning, cross platform ML framework
slice_op.cc
#include "caffe2/operators/slice_op.h"

#include <algorithm>
#include <cstdint>

#include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 REGISTER_CPU_OPERATOR(Slice, SliceOp<CPUContext>);
7 REGISTER_CPU_GRADIENT_OPERATOR(SliceGradient, SliceGradientOp<CPUContext>);
8 
9 OPERATOR_SCHEMA(Slice)
10  .NumInputs(1, 3)
11  .NumOutputs(1)
12  .DisallowInputFillers() // the filler cannot be enabled without output dims
13  .SetDoc(R"DOC(
14 Produces a slice of the input tensor.
15 
16 - Currently, only slicing in a single dimension is supported.
17 
18 - Start and end indices are either passed as two 1D input tensors or using the `starts` and `ends` arguments.
19 
20 - If a negative value is passed for any of the start or end indices, it represents the number of elements before the end of that dimension. End indices are non-inclusive unless negative (end index -1 means up to and including the last element).
21 
22 Github Links:
23 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/slice_op.cc
24 
25 <details>
26 
27 <summary> <b>Example</b> </summary>
28 
29 **Code**
30 
31 ```
32 
33 workspace.ResetWorkspace()
34 
35 op = core.CreateOperator(
36  "Slice",
37  ["X"],
38  ["Y"],
39  starts=(0,1),
40  ends=(-1,3)
41 )
42 
43 workspace.FeedBlob("X", np.array([[1,2,3,4],[5,6,7,8]]))
44 print("X:", workspace.FetchBlob("X"))
45 workspace.RunOperatorOnce(op)
46 print("Y:", workspace.FetchBlob("Y"))
47 
48 ```
49 
50 **Result**
51 
52 ```
53 
54 X:
55 [[1 2 3 4]
56  [5 6 7 8]]
57 Y:
58 [[2 3]
59  [6 7]]
60 
61 ```
62 
63 </details>
64 
65 )DOC")
66  .Input(0, "X", "(*Tensor*): tensor to extract slices from")
67  .Input(
68  1,
69  "starts",
70  "(*Tensor`<int>`*): 1D tensor of start-indices for each dimension of data")
71  .Input(
72  2,
73  "ends",
74  "(*Tensor`<int>`*): 1D tensor of end-indices for each dimension of data")
75  .Arg("starts", "(*Tuple(int)*): list of starting indices")
76  .Arg("ends", "(*Tuple(int)*): list of ending indices")
77  .TensorInferenceFunction([](const OperatorDef& def,
78  const vector<TensorShape>& in) {
79  if (in.size() > 1) {
80  // Cannot compute shape inference when the splits are defined
81  // in data.
82  return vector<TensorShape>();
83  }
84  auto const& data = in[0];
85 
86  ArgumentHelper helper(def);
87  auto starts = helper.GetRepeatedArgument<int>("starts", vector<int>());
88  auto ends = helper.GetRepeatedArgument<int>("ends", vector<int>());
89  vector<int> dst_sizes(data.dims_size());
90 
91  for (int i = 0; i < data.dims_size(); ++i) {
92  if (i >= starts.size()) {
93  continue;
94  }
95  if (data.dims_size() > 0) {
96  auto start = starts[i];
97  auto end = ends[i];
98  if (start < 0) {
99  start = data.dims(i) + 1 + start;
100  }
101  if (end < 0) {
102  end = data.dims(i) + 1 + end;
103  }
104  dst_sizes[i] = end - start;
105  } else {
106  dst_sizes[i] = 0;
107  }
108  }
109  return vector<TensorShape>{
110  CreateTensorShape(dst_sizes, data.data_type())};
111  })
112  .Output(0, "Y", "(*Tensor*): sliced output tensor")
113  .InheritOnnxSchema();
114 
115 GRADIENT_OPERATOR_SCHEMA(SliceGradient);
116 
117 namespace {
118 struct GetSliceGradient : public GradientMakerBase {
119  using GradientMakerBase::GradientMakerBase;
120  vector<OperatorDef> GetGradientDefs() override {
121  if (def_.input_size() > 1) {
122  return vector<OperatorDef>{CreateOperatorDef(
123  "SliceGradient",
124  "",
125  std::vector<string>{I(0), I(1), I(2), GO(0)},
126  std::vector<string>{GI(0)})};
127  } else {
128  return vector<OperatorDef>{CreateOperatorDef(
129  "SliceGradient",
130  "",
131  std::vector<string>{I(0), GO(0)},
132  std::vector<string>{GI(0)})};
133  }
134  }
135 };
136 }
137 REGISTER_GRADIENT(Slice, GetSliceGradient);
138 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13