Caffe2 — C++ API
A deep-learning, cross-platform ML framework.
File: slice_op.h
1 
2 #pragma once
3 
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 namespace {
11 
12 template <class SIndex, class Context>
13 bool SliceImpl(
14  Tensor* output,
15  const Tensor& data,
16  const Tensor& starts,
17  const Tensor& ends,
18  Context* context,
19  Tensor* gdata = nullptr,
20  const Tensor* go = nullptr) {
21  bool backward = output == nullptr;
22 
23  auto* starts_data = starts.template data<SIndex>();
24  auto* ends_data = ends.template data<SIndex>();
25 
26  CAFFE_ENFORCE_EQ(starts.dim(), 1);
27  CAFFE_ENFORCE_EQ(ends.dim(), 1);
28  CAFFE_ENFORCE_GE(data.dim(), starts.numel());
29  CAFFE_ENFORCE_EQ(starts.numel(), ends.numel());
30 
31  std::vector<SIndex> starts_idx(data.dim());
32  std::vector<SIndex> ends_idx(data.dim());
33  std::vector<SIndex> dst_sizes(data.dim());
34 
35  for (int i = 0; i < data.dim(); ++i) {
36  if (i >= starts.numel()) {
37  starts_idx[i] = 0;
38  ends_idx[i] = data.sizes()[i];
39  continue;
40  }
41  if (data.sizes()[i] > 0) {
42  auto start = starts_data[i];
43  auto end = ends_data[i];
44  if (start < 0) {
45  start = data.sizes()[i] + 1 + start;
46  }
47  if (end < 0) {
48  end = data.sizes()[i] + 1 + end;
49  }
50  if (start > data.sizes()[i]) {
51  start = data.sizes()[i];
52  }
53  if (end > data.sizes()[i]) {
54  end = data.sizes()[i];
55  }
56  CAFFE_ENFORCE_GE(start, 0);
57  CAFFE_ENFORCE_GE(end, 0);
58  CAFFE_ENFORCE_GE(end, start);
59  starts_idx[i] = start;
60  ends_idx[i] = end;
61  dst_sizes[i] = end - start;
62  } else {
63  starts_idx[i] = 0;
64  ends_idx[i] = 0;
65  dst_sizes[i] = 0;
66  }
67  }
68 
69  if (data.numel() <= 0) {
70  // When the input is empty, we do not need to do copy.
71  if (!backward) {
72  output->Resize(dst_sizes);
73  output->raw_mutable_data(data.dtype());
74  }
75  return true;
76  }
77  // for now only supports slicing in 1 dimension
78  int dim = -1;
79  for (int i = 0; i < data.dim(); ++i) {
80  if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) {
81  CAFFE_ENFORCE_EQ(
82  dim, -1, "Currently only possible to slice in 1 dimension.");
83  dim = i;
84  }
85  }
86  if (dim == -1) {
87  if (!backward) {
88  output->CopyFrom(data, true /*async*/);
89  } else {
90  gdata->CopyFrom(*go, true /*async*/);
91  }
92  return true;
93  }
94  size_t unit = std::accumulate(
95  data.sizes().begin() + dim + 1,
96  data.sizes().end(),
97  1,
98  std::multiplies<SIndex>());
99  size_t num_blocks = std::accumulate(
100  data.sizes().begin(),
101  data.sizes().begin() + dim,
102  1,
103  std::multiplies<SIndex>());
104  if (!backward) {
105  output->Resize(dst_sizes);
106  } else {
107  gdata->ResizeLike(data);
108  }
109 
110  size_t itemsize = data.dtype().itemsize();
111 
112  if (!backward) {
113  char* src_bytes = (char*)data.raw_data();
114  char* dst_bytes = (char*)output->raw_mutable_data(data.dtype());
115 
116  size_t src_nbytes = data.nbytes();
117  size_t dst_nbytes = output->nbytes();
118 
119  size_t src_block_size = unit * data.sizes()[dim];
120  size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
121  size_t src_offset = unit * starts_idx[dim];
122 
123  if (num_blocks == 0 || dst_block_size == 0) {
124  return true;
125  }
126 
127  size_t src_block_size_bytes = itemsize * src_block_size;
128  size_t dst_block_size_bytes = itemsize * dst_block_size;
129 
130  char* src_offset_bytes = src_bytes + itemsize * src_offset;
131  char* dst_offset_bytes = dst_bytes;
132  for (size_t i = 0; i < num_blocks; ++i) {
133  char* local_src_offset_bytes =
134  src_offset_bytes + i * src_block_size_bytes;
135  char* local_dst_offset_bytes =
136  dst_offset_bytes + i * dst_block_size_bytes;
137  DCHECK_LE(
138  static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
139  static_cast<void*>(src_bytes + src_nbytes));
140  DCHECK_LE(
141  static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
142  static_cast<void*>(dst_bytes + dst_nbytes));
143  context->CopyItemsSameDevice(
144  data.dtype(),
145  dst_block_size,
146  (void*)local_src_offset_bytes,
147  (void*)local_dst_offset_bytes);
148  }
149  } else {
150  char* src_bytes = (char*)go->raw_data();
151  char* dst_bytes = (char*)gdata->raw_mutable_data(go->dtype());
152 
153  size_t src_nbytes = go->nbytes();
154  size_t dst_nbytes = gdata->nbytes();
155 
156  size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
157  size_t dst_block_size = unit * data.sizes()[dim];
158  size_t dst_offset = unit * starts_idx[dim];
159 
160  if (num_blocks == 0 || dst_block_size == 0) {
161  return true;
162  }
163 
164  size_t src_block_size_bytes = itemsize * src_block_size;
165  size_t dst_block_size_bytes = itemsize * dst_block_size;
166 
167  char* src_offset_bytes = src_bytes;
168  char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
169  // Zero out gradient blob before copy since we copy in fewer items than
170  // there is space for
171  math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);
172 
173  // If output tensor is empty, just return zeroed gradient tensor
174  if (!src_bytes) {
175  return true;
176  }
177 
178  for (size_t i = 0; i < num_blocks; ++i) {
179  char* local_src_offset_bytes =
180  src_offset_bytes + i * src_block_size_bytes;
181  char* local_dst_offset_bytes =
182  dst_offset_bytes + i * dst_block_size_bytes;
183  DCHECK_LE(
184  local_src_offset_bytes + src_block_size_bytes,
185  src_bytes + src_nbytes);
186  DCHECK_LE(
187  local_dst_offset_bytes + src_block_size_bytes,
188  dst_bytes + dst_nbytes);
189  context->CopyItemsSameDevice(
190  go->dtype(),
191  src_block_size,
192  (void*)local_src_offset_bytes,
193  (void*)local_dst_offset_bytes);
194  }
195  }
196  return true;
197 }
198 
199 } // namespace
200 
201 template <class Context>
202 class SliceOp : public Operator<Context> {
203  public:
204  USE_OPERATOR_CONTEXT_FUNCTIONS;
205  template <class... Args>
206  explicit SliceOp(Args&&... args)
207  : Operator<Context>(std::forward<Args>(args)...),
208  starts_(this->template GetRepeatedArgument<int64_t>("starts")),
209  ends_(this->template GetRepeatedArgument<int64_t>("ends")),
210  statically_inited_(false) {}
211 
212  bool RunOnDevice() override {
213  if (InputSize() > 1) {
214  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
215  } else {
216  return DoRunWithType<int64_t>();
217  }
218  }
219 
220  template <typename SIndex>
221  bool DoRunWithType() {
222  if (InputSize() > 1) {
223  ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
224  ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
225  } else {
226  if (!statically_inited_) {
227  CAFFE_ENFORCE(HasArgument("starts"));
228  CAFFE_ENFORCE(HasArgument("ends"));
229  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
230 
231  ReinitializeTensor(&starts_host_, {static_cast<int64_t>(starts_.size())}, at::dtype<SIndex>().device(CPU));
232  ReinitializeTensor(&ends_host_, {static_cast<int64_t>(ends_.size())}, at::dtype<SIndex>().device(CPU));
233 
234  memcpy(
235  starts_host_.template mutable_data<SIndex>(),
236  starts_.data(),
237  sizeof(SIndex) * starts_.size());
238  memcpy(
239  ends_host_.template mutable_data<SIndex>(),
240  ends_.data(),
241  sizeof(SIndex) * ends_.size());
242  statically_inited_ = true;
243  }
244  }
245 
246  const auto& data = Input(0);
247  auto output = Output(0);
248 
249  return SliceImpl<SIndex, Context>(
250  output, data, starts_host_, ends_host_, &context_);
251  }
252 
253  C10_DISABLE_COPY_AND_ASSIGN(SliceOp);
254 
255  protected:
256  std::vector<int64_t> starts_;
257  std::vector<int64_t> ends_;
258  bool statically_inited_;
259  Tensor starts_host_;
260  Tensor ends_host_;
261 };
262 
263 template <class Context>
264 class SliceGradientOp : public Operator<Context> {
265  public:
266  USE_OPERATOR_CONTEXT_FUNCTIONS;
267  template <class... Args>
268  explicit SliceGradientOp(Args&&... args)
269  : Operator<Context>(std::forward<Args>(args)...),
270  starts_(this->template GetRepeatedArgument<int64_t>("starts")),
271  ends_(this->template GetRepeatedArgument<int64_t>("ends")),
272  statically_inited_(false) {}
273 
274  C10_DISABLE_COPY_AND_ASSIGN(SliceGradientOp);
275 
276  bool RunOnDevice() override {
277  if (InputSize() == 4) {
278  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
279  } else {
280  return DoRunWithType<int64_t>();
281  }
282  }
283 
284  template <typename SIndex>
285  bool DoRunWithType() {
286  auto* gdata = Output(0);
287  auto& data = Input(0);
288 
289  if (InputSize() == 4) {
290  ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
291  ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
292 
293  auto& go = Input(3);
294 
295  return SliceImpl<SIndex, Context>(
296  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
297  } else {
298  if (!statically_inited_) {
299  CAFFE_ENFORCE(HasArgument("starts"));
300  CAFFE_ENFORCE(HasArgument("ends"));
301  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
302 
304  &starts_host_, {static_cast<int64_t>(starts_.size())}, at::dtype<SIndex>().device(CPU));
306  &ends_host_, {static_cast<int64_t>(ends_.size())}, at::dtype<SIndex>().device(CPU));
307 
308  memcpy(
309  starts_host_.template mutable_data<SIndex>(),
310  starts_.data(),
311  sizeof(SIndex) * starts_.size());
312  memcpy(
313  ends_host_.template mutable_data<SIndex>(),
314  ends_.data(),
315  sizeof(SIndex) * ends_.size());
316 
317  statically_inited_ = true;
318  }
319  auto& go = Input(1);
320 
321  return SliceImpl<SIndex, Context>(
322  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
323  }
324  }
325 
326  private:
327 
328  std::vector<int64_t> starts_;
329  std::vector<int64_t> ends_;
330  bool statically_inited_;
331  Tensor starts_host_;
332  Tensor ends_host_;
333 };
334 } // namespace caffe2
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitializes a Tensor to the given dims and options if necessary; note that this does nothing if the tensor already matches them.
Definition: tensor.cc:127
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieves a non-owning reference to the input at position 'idx' for this operator.
Definition: operator.h:702
A global dictionary that holds information about which Caffe2 modules have been loaded in the current runtime.
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks whether the operator has an argument of the given name.
Definition: operator.h:70