Caffe2 - C++ API
A deep learning, cross platform ML framework
gather_ranges_to_dense_op.h
1 
17 #ifndef CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
18 #define CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
19 
20 #include <math.h>
21 
22 #include "caffe2/core/common_omp.h"
23 #include "caffe2/core/context.h"
24 #include "caffe2/core/logging.h"
25 #include "caffe2/core/operator.h"
26 #include "caffe2/core/types.h"
27 #include "caffe2/utils/math.h"
28 
29 #include <map>
30 #include <utility>
31 
32 namespace caffe2 {
33 template <class Context>
34 class GatherRangesToDenseOp final : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  GatherRangesToDenseOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws),
39  lengths_(OperatorBase::GetRepeatedArgument<int>("lengths")) {
40  CAFFE_ENFORCE_GT(lengths_.size(), 0, "There has to be at least one length");
41  for (auto length : lengths_) {
42  CAFFE_ENFORCE_GT(length, 0, "Each length should be positive");
43  }
44  }
45 
46  bool RunOnDevice() override {
48  this, OperatorBase::Input<TensorCPU>(RANGES));
49  }
50 
51  template <typename Index>
52  bool DoRunWithType() {
53  auto& data = Input(DATA);
54  auto& ranges = Input(RANGES);
55  CAFFE_ENFORCE_EQ(data.ndim(), 1, "Data has to be 1-D");
56  CAFFE_ENFORCE_EQ(ranges.ndim(), 3, "Data has to be 3-D");
57  CAFFE_ENFORCE_EQ(
58  ranges.dim(1),
59  lengths_.size(),
60  "Nummber of ranges should match number of lengths");
61  CAFFE_ENFORCE_EQ(
62  ranges.dim(1),
63  OutputSize(),
64  "Nummber of ranges should match number of outputs");
65  CAFFE_ENFORCE_EQ(
66  ranges.dim(2), 2, "Ranges last dimension should be of size 2");
67 
68  auto* rawData = static_cast<const char*>(data.raw_data());
69  auto* rangesData = ranges.template data<Index>();
70  int rangesDataOffset = 0;
71  auto itemsize = data.meta().itemsize();
72 
73  auto batchSize = ranges.dim(0);
74  vector<TIndex> outputDims{batchSize, 0};
75  vector<char*> outputRawData;
76  for (int i = 0; i < OutputSize(); ++i) {
77  auto* output = Output(i);
78  outputDims[1] = lengths_[i];
79  output->Resize(outputDims);
80  char* ptr = static_cast<char*>(output->raw_mutable_data(data.meta()));
81  memset(ptr, 0, output->nbytes());
82  outputRawData.push_back(ptr);
83  }
84 
85  for (int i = 0; i < batchSize; ++i) {
86  for (int j = 0; j < OutputSize(); ++j) {
87  auto rangeStart = rangesData[rangesDataOffset++];
88  auto rangeLength = rangesData[rangesDataOffset++];
89  if (rangeLength == 0) {
90  // empty range, will be filled with zeros
91  continue;
92  }
93  CAFFE_ENFORCE_EQ(
94  rangeLength,
95  lengths_[j],
96  "Range lengths missmatch for output #",
97  j);
98  context_.template CopyItems<Context, Context>(
99  data.meta(),
100  rangeLength,
101  rawData + rangeStart * itemsize,
102  outputRawData[j] + i * itemsize * lengths_[j]);
103  }
104  }
105  CAFFE_ENFORCE_EQ(rangesDataOffset, ranges.size());
106 
107  return true;
108  }
109 
110  INPUT_TAGS(DATA, RANGES);
111 
112  private:
113  vector<int> lengths_;
114 };
115 
116 } // namespace caffe2
117 
118 #endif // CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.