Caffe2 - C++ API
A deep learning, cross-platform ML framework
gather_op.h
#ifndef GATHER_OP_H_
#define GATHER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

// Index-mapping helpers shared by the Gather and BatchGather ops.
namespace gather_helper {

// The new shape is the concatenation:
// [data dims before axis] + [indices dims] + [data dims after axis]
template <typename IndexType, typename DataDimsVec, typename IndexDimsVec>
static vector<IndexType> calc_output_shape_vector(
    const DataDimsVec& data_dims,
    const IndexDimsVec& indices_dims,
    int axis) {
  vector<IndexType> shape;
  // If the dimension we are indexing is empty, just use data_dims as the
  // shape. This replicates the behavior of
  // https://github.com/pytorch/pytorch/pull/13781, which is needed to let
  // workflows with an empty batch succeed.
  if (data_dims[axis] == 0) {
    shape.insert(shape.end(), data_dims.begin(), data_dims.end());
  } else {
    shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis);
    shape.insert(shape.end(), indices_dims.begin(), indices_dims.end());
    shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end());
  }
  return shape;
}
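
// Worked example (hypothetical shapes): with data_dims = {2, 5, 3},
// indices_dims = {4, 6}, and axis = 1, the output shape is
// {2} + {4, 6} + {3} = {2, 4, 6, 3}.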

// Checks with CAFFE_ENFORCE that every index falls within the size of the
// indexed dimension.
template <typename IndexType>
static void check_indexarray_range(
    const IndexType* indices,
    int64_t n,
    IndexType indexing_axis_dim,
    bool wrap_indices) {
  for (auto i = 0; i < n; ++i) {
    auto idx = indices[i];
    if (wrap_indices && idx < 0) {
      idx = idx + indexing_axis_dim;
    }
    CAFFE_ENFORCE(
        0 <= idx && idx < indexing_axis_dim,
        "INDICES element is out of DATA bounds, id=",
        idx,
        " axis_dim=",
        indexing_axis_dim);
  }
}
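
// Example (hypothetical values): with indexing_axis_dim = 5 and
// wrap_indices = true, idx = -1 wraps to 4 and passes, while idx = 5 or
// idx = -6 still fails the enforce.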

// Actual gather implementation; resizes the output and copies indexed data.
template <typename Index, typename Context>
static bool gather_impl(
    Operator<Context>* op,
    int dataIdx,
    int indicesIdx,
    int outputIdx,
    int axis,
    bool wrap_indices) {
  // If we end up using this on GPU, doing O(N) memcpys is probably not ideal.
  // TODO: implement prefetching if it starts to matter (TF does it).

  const Tensor& data = op->Input(dataIdx);
  const Tensor& indices = op->Input(indicesIdx);
  const TypeMeta dataType = data.dtype();
  size_t item_bytesize = dataType.itemsize();

  // ONNX allows a negative axis to index from the back; the valid range is
  // [-r, r - 1] for rank-r DATA.
  if (axis < 0) {
    axis = data.dim() + axis;
  }
  CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D");
  CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative");
  CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range");
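
  // Example (hypothetical): for 3-D DATA, axis = -1 normalizes to 2.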

  // New shape:
  // [data dims before axis] + [indices dims] + [data dims after axis]
  vector<int64_t> shape =
      calc_output_shape_vector<int64_t>(data.sizes(), indices.sizes(), axis);
  Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
  auto out = static_cast<char*>(output->raw_mutable_data(dataType));

  // Succeed if the output is empty, which can happen for an empty batch
  // whose data dimension has size 0.
  // This *must* be done AFTER output->raw_mutable_data() above, as that call
  // has an important allocation side effect that we rely on.
  if (output->numel() == 0) {
    return true;
  }

  const Index* idxs = indices.template data<Index>();
  auto src_base = static_cast<const char*>(data.raw_data());

  auto outer_dims_product = data.size_to_dim(axis);
  auto block_size = data.size_from_dim(axis + 1);
  auto block_bytesize = block_size * item_bytesize;

  auto src_indexing_axis_dim = data.size(axis);
  auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
  // Treat indices as a single block even if they have multiple dimensions.
  // The "gathered batch" is a cumulative result combining indexed blocks.
  auto N = indices.numel();
  auto gathered_batch_bytesize = N * block_size * item_bytesize;
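
  // Example (hypothetical sizes): float DATA of shape [2, 5, 3] with
  // axis = 1 gives block_size = 3, block_bytesize = 12, and
  // src_batch_bytesize = 60; with N = 4 indices,
  // gathered_batch_bytesize = 4 * 3 * 4 = 48.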

  check_indexarray_range<Index>(idxs, N, src_indexing_axis_dim, wrap_indices);

  // Special-case single-float copies for efficiency.
  if (data.template IsType<float>() && block_size == 1) {
    for (auto batch = 0; batch < outer_dims_product; ++batch) {
      const float* src_floats =
          (const float*)(src_base + batch * src_batch_bytesize);
      float* dst_floats = (float*)(out + batch * gathered_batch_bytesize);

      for (auto i = 0; i < N; ++i) {
        auto idx = idxs[i];
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
        dst_floats[i] = src_floats[idx];
      }
    }
  } else {
    // outer_dims_product specifies how many times we repeat the inner
    // dimensions, so we just iterate over it to cover all outer dimensions.
    for (auto batch = 0; batch < outer_dims_product; ++batch) {
      for (auto i = 0; i < N; ++i) {
        auto idx = idxs[i];
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }

        auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
        auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
        op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
      }
    }
  }
  return true;
}

} // namespace gather_helper

template <class Context>
class GatherOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 0) {
    // TBD: We may want to fix the old index-wrapping behavior once we have
    // operator versioning, and only apply it when needed, as otherwise it is
    // likely an error.
    // Right now we apply index wrapping by default only when axis == 0,
    // since we have ONNX conversion code that relies on it. For other axes
    // it must be specified explicitly via the argument, or you don't get it.
    if (OperatorBase::HasArgument("wrap_indices")) {
      wrap_indices_ = Operator<Context>::template GetSingleArgument<bool>(
          "wrap_indices", (false));
    } else {
      wrap_indices_ = (axis_ == 0);
    }
  }
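
  // Example (hypothetical): a Gather with axis = 1 and no "wrap_indices"
  // argument leaves wrap_indices_ false, so negative indices fail the range
  // check instead of wrapping around.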

  virtual ~GatherOp() noexcept {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename Index>
  bool DoRunWithType() {
    return gather_helper::gather_impl<Index, Context>(
        this, DATA, INDICES, 0, axis_, wrap_indices_);
  }

  INPUT_TAGS(DATA, INDICES);

 protected:
  int axis_;
  bool wrap_indices_;
};

} // namespace caffe2
#endif // GATHER_OP_H_
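
A minimal sketch of the shape helper used in isolation. This is not part of the header; the dimension values are hypothetical, and the include path assumes the header's usual location under caffe2/operators/:

#include <cstdint>
#include <iostream>
#include <vector>
#include "caffe2/operators/gather_op.h"

int main() {
  // DATA of shape [2, 5, 3], INDICES of shape [4], gathering along axis 1.
  std::vector<int64_t> data_dims{2, 5, 3};
  std::vector<int64_t> indices_dims{4};
  auto shape = caffe2::gather_helper::calc_output_shape_vector<int64_t>(
      data_dims, indices_dims, /*axis=*/1);
  for (auto d : shape) {
    std::cout << d << " "; // prints: 2 4 3
  }
  return 0;
}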