Caffe2 - C++ API
A deep learning, cross-platform ML framework
gather_ranges_to_dense_op.h
1 #ifndef CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
2 #define CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
3 
4 #include <math.h>
5 
6 #include "caffe2/core/common_omp.h"
7 #include "caffe2/core/context.h"
8 #include "caffe2/core/logging.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/types.h"
11 #include "caffe2/utils/math.h"
12 
13 #include <cstring>
14 #include <map>
15 #include <utility>
16 
17 namespace caffe2 {
18 template <class Context>
19 class GatherRangesToDenseOp final : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22  template <class... Args>
23  explicit GatherRangesToDenseOp(Args&&... args)
24  : Operator<Context>(std::forward<Args>(args)...),
25  lengths_(this->template GetRepeatedArgument<int>("lengths")) {
26  CAFFE_ENFORCE_GT(lengths_.size(), 0, "There has to be at least one length");
27  for (auto length : lengths_) {
28  CAFFE_ENFORCE_GT(length, 0, "Each length should be positive");
29  }
30  }
31 
32  bool RunOnDevice() override {
34  this, this->template Input<Tensor>(RANGES, CPU));
35  }
36 
37  template <typename Index>
38  bool DoRunWithType() {
39  auto& data = Input(DATA);
40  auto& ranges = Input(RANGES);
41  CAFFE_ENFORCE_EQ(data.dim(), 1, "Data has to be 1-D");
42  CAFFE_ENFORCE_EQ(ranges.dim(), 3, "Ranges has to be 3-D");
43  if (InputSize() == 3) {
44  auto& key = Input(KEY);
45  CAFFE_ENFORCE_EQ(key.dim(), 1, "Key has to be 1-D");
46  CAFFE_ENFORCE(
47  key.dtype().template Match<int64_t>(), "Key has to be type int64_t");
48  }
49  CAFFE_ENFORCE_EQ(
50  ranges.size(1),
51  lengths_.size(),
52  "Nummber of ranges should match number of lengths");
53  CAFFE_ENFORCE_EQ(
54  ranges.size(1),
55  OutputSize(),
56  "Nummber of ranges should match number of outputs");
57  CAFFE_ENFORCE_EQ(
58  ranges.size(2), 2, "Ranges last dimension should be of size 2");
59 
60  auto* rawData = static_cast<const char*>(data.raw_data());
61  auto* rangesData = ranges.template data<Index>();
62  int rangesDataOffset = 0;
63  auto itemsize = data.dtype().itemsize();
64 
65  auto batchSize = ranges.size(0);
66  vector<int64_t> outputDims{batchSize, 0};
67  vector<char*> outputRawData;
68  for (int i = 0; i < OutputSize(); ++i) {
69  auto* output = Output(i);
70  outputDims[1] = lengths_[i];
71  output->Resize(outputDims);
72  char* ptr = static_cast<char*>(output->raw_mutable_data(data.dtype()));
73  memset(ptr, 0, output->nbytes());
74  outputRawData.push_back(ptr);
75  }
76 
77  for (int i = 0; i < batchSize; ++i) {
78  for (int j = 0; j < OutputSize(); ++j) {
79  auto rangeStart = rangesData[rangesDataOffset++];
80  auto rangeLength = rangesData[rangesDataOffset++];
81  if (rangeLength == 0) {
82  // empty range, will be filled with zeros
83  continue;
84  }
85  CAFFE_ENFORCE_EQ(
86  rangeLength,
87  lengths_[j],
88  "Range lengths missmatch for output #",
89  j);
90 
91  if (InputSize() == 2) {
92  context_.CopyItemsSameDevice(
93  data.dtype(),
94  rangeLength,
95  rawData + rangeStart * itemsize,
96  outputRawData[j] + i * itemsize * lengths_[j]);
97  } else {
98  auto& key = Input(KEY);
99  auto* key_data = key.template data<int64_t>();
100  vector<std::pair<int64_t, const char*>> buffer;
101  for (int b_i = 0; b_i < rangeLength; ++b_i) {
102  int64_t one_key_item = key_data[rangeStart + b_i];
103  auto* one_data_item = rawData + (rangeStart + b_i) * itemsize;
104  buffer.emplace_back(one_key_item, one_data_item);
105  }
106  std::sort(
107  buffer.begin(),
108  buffer.end(),
109  [](const std::pair<int64_t, const char*>& left,
110  const std::pair<int64_t, const char*>& right) {
111  return left.first < right.first;
112  });
113  for (int b_i = 0; b_i < rangeLength; ++b_i) {
114  // Since this CPU only, directly copy to the destination.
115  std::memcpy(
116  outputRawData[j] + (i * lengths_[j] + b_i) * itemsize,
117  buffer[b_i].second,
118  itemsize);
119  }
120  }
121  }
122  }
123  CAFFE_ENFORCE_EQ(rangesDataOffset, ranges.numel());
124 
125  return true;
126  }
127 
128  INPUT_TAGS(DATA, RANGES, KEY);
129 
130  private:
131  vector<int> lengths_;
132 };
133 
134 } // namespace caffe2
135 
136 #endif // CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13