Caffe2 - C++ API
A deep learning, cross-platform ML framework
remove_data_blocks_op.h
1 #ifndef CAFFE2_OPERATORS_REMOVE_DATA_BLOCKS_OP_H_
2 #define CAFFE2_OPERATORS_REMOVE_DATA_BLOCKS_OP_H_
3 
4 #include <algorithm>
5 #include <vector>
6 
7 #include "caffe2/core/context.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class RemoveDataBlocksOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  USE_SIMPLE_CTOR_DTOR(RemoveDataBlocksOp);
17  USE_DISPATCH_HELPER;
18 
19  bool RunOnDevice() override {
20  if (Input(INDICES).sizes()[0] == 0) {
21  Output(0)->CopyFrom(Input(0));
22  return true;
23  } else {
24  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(INDICES));
25  }
26  }
27 
28  template <typename T>
29  bool DoRunWithType() {
30  const auto& data = Input(DATA);
31  const auto& indices = Input(INDICES);
32  CAFFE_ENFORCE(data.dim() > 0, "DATA should be at leat 1-D.");
33  CAFFE_ENFORCE(indices.dim() == 1, "INDICES should be 1-D.");
34 
35  const auto outer_size = data.sizes()[0];
36  const auto block_size = data.size_from_dim(1);
37  const auto block_size_bytes = block_size * data.dtype().itemsize();
38  auto indices_size = indices.sizes()[0];
39  const char* data_ptr = (char*)data.raw_data();
40  const auto* ind_ptr = indices.template data<T>();
41 
42  std::vector<T> ind_vec;
43  for (int64_t i = 0; i < indices_size; i++) {
44  ind_vec.push_back(ind_ptr[i]);
45  }
46  std::sort(ind_vec.begin(), ind_vec.end());
47  CAFFE_ENFORCE(ind_vec[0] >= 0, "The min index should be larger than zero.");
48  CAFFE_ENFORCE(
49  ind_vec[indices_size - 1] < outer_size,
50  "The max index should be smaller than the data outer size.");
51  // removes duplicate indices
52  ind_vec.erase(std::unique(ind_vec.begin(), ind_vec.end()), ind_vec.end());
53  indices_size = ind_vec.size();
54 
55  auto* output = Output(0);
56  auto shape = data.sizes().vec();
57  shape[0] -= indices_size;
58  output->Resize(shape);
59  char* out_ptr = (char*)output->raw_mutable_data(data.dtype());
60 
61  ind_vec.insert(ind_vec.begin(), -1);
62  int64_t ind_vec_size = ind_vec.size();
63  for (auto i = 0; i < ind_vec_size; i++) {
64  int64_t interval_start = ind_vec[i] + 1;
65  int64_t interval_end =
66  (i == ind_vec_size - 1) ? outer_size : ind_vec[i + 1];
67  auto num_items = interval_end - interval_start;
68  context_.CopyItemsSameDevice(
69  data.dtype(),
70  num_items * block_size,
71  data_ptr + block_size_bytes * interval_start,
72  out_ptr);
73  out_ptr += block_size_bytes * num_items;
74  }
75 
76  return true;
77  }
78 
79  private:
80  INPUT_TAGS(DATA, INDICES);
81 };
82 
83 } // namespace caffe2
84 
85 #endif // CAFFE2_OPERATORS_REMOVE_DATA_BLOCKS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13