Caffe2 - C++ API
A deep learning, cross-platform ML framework
slice_op.h
18 #pragma once
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 namespace {
27 
// Shared implementation for Slice (forward) and SliceGradient (backward).
//
// Forward mode (output != nullptr): copies the sub-tensor of `data`
// selected by `starts`/`ends` into `output`.
// Backward mode (output == nullptr, used by SliceGradientOp): scatters the
// output gradient `go` back into `gdata` (resized like `data`), zero-filling
// the positions that were sliced away.
//
// `starts`/`ends` are 1-D tensors of SIndex with one entry per leading
// dimension of `data`; trailing dimensions without an entry are kept whole.
// Negative indices count from one past the end (end == -1 selects through
// the last element — note the `+ 1` in the normalization below).
// Only ONE dimension may actually be sliced; this is enforced.
//
// Returns true on success (CAFFE_ENFORCE aborts on contract violations).
template <class SIndex, class Context>
bool SliceImpl(
    Tensor<Context>* output,
    const Tensor<Context>& data,
    const Tensor<Context>& starts,
    const Tensor<Context>& ends,
    Context* context,
    Tensor<Context>* gdata = nullptr,
    const Tensor<Context>* go = nullptr) {
  // Backward mode is signalled by a null output tensor; in that case
  // `gdata` and `go` must be supplied by the caller.
  bool backward = output == nullptr;

  auto* starts_data = starts.template data<SIndex>();
  auto* ends_data = ends.template data<SIndex>();

  CAFFE_ENFORCE_EQ(starts.ndim(), 1);
  CAFFE_ENFORCE_EQ(ends.ndim(), 1);
  CAFFE_ENFORCE_GE(data.ndim(), starts.size());
  CAFFE_ENFORCE_EQ(starts.size(), ends.size());

  // Normalized per-dimension slice bounds and resulting output shape.
  std::vector<SIndex> starts_idx(data.ndim());
  std::vector<SIndex> ends_idx(data.ndim());
  std::vector<SIndex> dst_sizes(data.ndim());

  for (int i = 0; i < data.ndim(); ++i) {
    if (i >= starts.size()) {
      // No slice spec for this trailing dimension: keep it whole.
      starts_idx[i] = 0;
      ends_idx[i] = data.dims()[i];
      continue;
    }
    if (data.dims()[i] > 0) {
      auto start = starts_data[i];
      auto end = ends_data[i];
      // Negative indices count from one past the end, so end == -1
      // maps to dims[i] (i.e. "through the last element").
      if (start < 0) {
        start = data.dims()[i] + 1 + start;
      }
      if (end < 0) {
        end = data.dims()[i] + 1 + end;
      }
      // Clamp over-large indices to the dimension size.
      if (start > data.dims()[i]) {
        start = data.dims()[i];
      }
      if (end > data.dims()[i]) {
        end = data.dims()[i];
      }
      // After normalization the range must be valid and non-inverted.
      CAFFE_ENFORCE_GE(start, 0);
      CAFFE_ENFORCE_GE(end, 0);
      CAFFE_ENFORCE_GE(end, start);
      starts_idx[i] = start;
      ends_idx[i] = end;
      dst_sizes[i] = end - start;
    } else {
      // Zero-sized dimension: the slice of it is necessarily empty.
      starts_idx[i] = 0;
      ends_idx[i] = 0;
      dst_sizes[i] = 0;
    }
  }

  if (data.size() <= 0) {
    // When the input is empty, we do not need to do copy.
    if (!backward) {
      output->Resize(dst_sizes);
      // Allocate (zero-byte) storage with the right dtype so downstream
      // consumers see a well-formed tensor.
      output->raw_mutable_data(data.meta());
    }
    return true;
  }
  // for now only supports slicing in 1 dimension
  int dim = -1;
  for (int i = 0; i < data.ndim(); ++i) {
    if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) {
      CAFFE_ENFORCE_EQ(
          dim, -1, "Currently only possible to slice in 1 dimension.");
      dim = i;
    }
  }
  if (dim == -1) {
    // Nothing is actually sliced away: a straight copy suffices in both
    // forward and backward mode.
    if (!backward) {
      output->CopyFrom(data, context);
    } else {
      gdata->CopyFrom(*go, context);
    }
    return true;
  }
  // `unit`: number of elements in one contiguous run below the sliced
  // dimension; `num_blocks`: number of such runs above it. Together they
  // let us treat the copy as num_blocks strided memcpy-style transfers.
  size_t unit = std::accumulate(
      data.dims().begin() + dim + 1,
      data.dims().end(),
      1,
      std::multiplies<SIndex>());
  size_t num_blocks = std::accumulate(
      data.dims().begin(),
      data.dims().begin() + dim,
      1,
      std::multiplies<SIndex>());
  if (!backward) {
    output->Resize(dst_sizes);
  } else {
    gdata->ResizeLike(data);
  }

  size_t itemsize = data.meta().itemsize();

  if (!backward) {
    // Forward: gather `num_blocks` contiguous sub-rows from `data` into
    // the densely packed `output`.
    char* src_bytes = (char*)data.raw_data();
    char* dst_bytes = (char*)output->raw_mutable_data(data.meta());

    size_t src_nbytes = data.nbytes();
    size_t dst_nbytes = output->nbytes();

    // Source blocks span the full dimension; destination blocks only the
    // selected [start, end) range, so reads are offset and strided wider.
    size_t src_block_size = unit * data.dims()[dim];
    size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
    size_t src_offset = unit * starts_idx[dim];

    if (num_blocks == 0 || dst_block_size == 0) {
      // Empty result; Resize above already produced the right shape.
      return true;
    }

    size_t src_block_size_bytes = itemsize * src_block_size;
    size_t dst_block_size_bytes = itemsize * dst_block_size;

    char* src_offset_bytes = src_bytes + itemsize * src_offset;
    char* dst_offset_bytes = dst_bytes;
    for (int i = 0; i < num_blocks; ++i) {
      char* local_src_offset_bytes =
          src_offset_bytes + i * src_block_size_bytes;
      char* local_dst_offset_bytes =
          dst_offset_bytes + i * dst_block_size_bytes;
      // Debug-only bounds checks: each transfer must stay inside both
      // tensors' allocations.
      DCHECK_LE(
          static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
          static_cast<void*>(src_bytes + src_nbytes));
      DCHECK_LE(
          static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
          static_cast<void*>(dst_bytes + dst_nbytes));
      // CopyItems handles non-POD element types via the TypeMeta.
      context->template CopyItems<Context, Context>(
          data.meta(),
          dst_block_size,
          (void*)local_src_offset_bytes,
          (void*)local_dst_offset_bytes);
    }
  } else {
    // Backward: scatter the packed gradient `go` back into the full-size
    // `gdata`; source/destination roles of the strides are swapped.
    char* src_bytes = (char*)go->raw_data();
    char* dst_bytes = (char*)gdata->raw_mutable_data(go->meta());

    size_t src_nbytes = go->nbytes();
    size_t dst_nbytes = gdata->nbytes();

    size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
    size_t dst_block_size = unit * data.dims()[dim];
    size_t dst_offset = unit * starts_idx[dim];

    if (num_blocks == 0 || dst_block_size == 0) {
      return true;
    }

    size_t src_block_size_bytes = itemsize * src_block_size;
    size_t dst_block_size_bytes = itemsize * dst_block_size;

    char* src_offset_bytes = src_bytes;
    char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
    // Zero out gradient blob before copy since we copy in fewer items than
    // there is space for
    math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);

    // If output tensor is empty, just return zeroed gradient tensor
    if (!src_bytes) {
      return true;
    }

    for (int i = 0; i < num_blocks; ++i) {
      char* local_src_offset_bytes =
          src_offset_bytes + i * src_block_size_bytes;
      char* local_dst_offset_bytes =
          dst_offset_bytes + i * dst_block_size_bytes;
      // Only src_block_size bytes are written into each destination block,
      // hence src_block_size_bytes appears in both bounds checks.
      DCHECK_LE(
          local_src_offset_bytes + src_block_size_bytes,
          src_bytes + src_nbytes);
      DCHECK_LE(
          local_dst_offset_bytes + src_block_size_bytes,
          dst_bytes + dst_nbytes);
      context->template CopyItems<Context, Context>(
          go->meta(),
          src_block_size,
          (void*)local_src_offset_bytes,
          (void*)local_dst_offset_bytes);
    }
  }
  return true;
}
214 
215 } // namespace
216 
217 template <class SIndex, class Context>
218 class SliceOp : public Operator<Context> {
219  public:
220  USE_OPERATOR_CONTEXT_FUNCTIONS;
221  SliceOp(const OperatorDef& operator_def, Workspace* ws)
222  : Operator<Context>(operator_def, ws),
223  starts_(OperatorBase::GetRepeatedArgument<SIndex>("starts")),
224  ends_(OperatorBase::GetRepeatedArgument<SIndex>("ends")),
225  statically_inited_(false) {}
226 
227  bool RunOnDevice() override {
228  auto* output = Output(0);
229  auto& data = Input(0);
230 
231  if (InputSize() > 1) {
232  starts_host_.template CopyFrom<Context>(Input(1));
233  ends_host_.template CopyFrom<Context>(Input(2));
234  } else {
235  if (!statically_inited_) {
236  CAFFE_ENFORCE(HasArgument("starts"));
237  CAFFE_ENFORCE(HasArgument("ends"));
238  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
239 
240  starts_host_.Resize(starts_.size());
241  ends_host_.Resize(ends_.size());
242 
243  memcpy(
244  starts_host_.template mutable_data<SIndex>(),
245  starts_.data(),
246  sizeof(SIndex) * starts_.size());
247  memcpy(
248  ends_host_.template mutable_data<SIndex>(),
249  ends_.data(),
250  sizeof(SIndex) * ends_.size());
251  statically_inited_ = true;
252  }
253  }
254 
255  return SliceImpl<SIndex, Context>(
256  output, data, starts_host_, ends_host_, &context_);
257  }
258 
259  DISABLE_COPY_AND_ASSIGN(SliceOp);
260 
261  private:
262  std::vector<SIndex> starts_;
263  std::vector<SIndex> ends_;
264  bool statically_inited_;
265  TensorCPU starts_host_;
266  TensorCPU ends_host_;
267 };
268 
269 template <class SIndex, class Context>
270 class SliceGradientOp : public Operator<Context> {
271  public:
272  USE_OPERATOR_CONTEXT_FUNCTIONS;
273  SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
274  : Operator<Context>(operator_def, ws),
275  starts_(OperatorBase::GetRepeatedArgument<SIndex>("starts")),
276  ends_(OperatorBase::GetRepeatedArgument<SIndex>("ends")),
277  statically_inited_(false) {}
278 
279  bool RunOnDevice() override {
280  auto* gdata = Output(0);
281  auto& data = Input(0);
282 
283  if (InputSize() == 4) {
284  starts_host_.template CopyFrom<Context>(Input(1));
285  ends_host_.template CopyFrom<Context>(Input(2));
286 
287  auto& go = Input(3);
288 
289  return SliceImpl<SIndex, Context>(
290  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
291  } else {
292  if (!statically_inited_) {
293  CAFFE_ENFORCE(HasArgument("starts"));
294  CAFFE_ENFORCE(HasArgument("ends"));
295  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
296 
297  starts_host_.Resize(starts_.size());
298  ends_host_.Resize(ends_.size());
299 
300  memcpy(
301  starts_host_.template mutable_data<SIndex>(),
302  starts_.data(),
303  sizeof(SIndex) * starts_.size());
304  memcpy(
305  ends_host_.template mutable_data<SIndex>(),
306  ends_.data(),
307  sizeof(SIndex) * ends_.size());
308 
309  statically_inited_ = true;
310  }
311  auto& go = Input(1);
312 
313  return SliceImpl<SIndex, Context>(
314  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
315  }
316  }
317 
318  DISABLE_COPY_AND_ASSIGN(SliceGradientOp);
319 
320  private:
321  std::vector<SIndex> starts_;
322  std::vector<SIndex> ends_;
323  bool statically_inited_;
324  TensorCPU starts_host_;
325  TensorCPU ends_host_;
326 };
327 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Resize(Ts... dim_source)
Resizes a tensor.
Definition: tensor.h:304
Copyright (c) 2016-present, Facebook, Inc.
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52