4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 10 namespace gather_helper {
// Computes the output shape of a gather along `axis`.
//
// Normal case: [data dims before axis] + [all indices dims] + [data dims after
// axis]. Example: data [A, B, C], indices [I, J], axis 1 -> [A, I, J, C].
// Special case: if data's size along `axis` is 0, `shape` becomes a copy of
// data_dims (the indices dims are not spliced in).
//
// Template params: IndexType is the element type of the resulting shape
// vector. DataDimsVec must support operator[] and begin()/end() with iterator
// arithmetic. IndexDimsVec needs begin()/end(). The caller in this file passes
// Tensor::sizes() for both.
//
// NOTE(review): this listing is missing several source lines of this function
// (the `axis` parameter declaration, the `else` separating the two branches,
// the return statement and closing braces). The comments describe the evident
// intent of the visible lines; confirm against the full source.
14 template <
typename IndexType,
typename DataDimsVec,
typename IndexDimsVec>
15 static vector<IndexType> calc_output_shape_vector(
16 const DataDimsVec& data_dims,
17 const IndexDimsVec& indices_dims,
19 vector<IndexType> shape;
// Indexed dimension is empty: keep the data shape unchanged.
23 if (data_dims[axis] == 0) {
24 shape.insert(shape.end(), data_dims.begin(), data_dims.end());
// Otherwise splice: leading data dims, then every indices dim, then the
// trailing data dims. The dim at `axis` itself is replaced by the indices dims.
26 shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis);
27 shape.insert(shape.end(), indices_dims.begin(), indices_dims.end());
28 shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end());
// Checks every one of the `n` entries of `indices` against the size of the
// indexed axis. An out-of-range entry fails with the message
// "INDICES element is out of DATA bounds, id=...". The failure is presumably
// raised via a CAFFE_ENFORCE whose opening line is not in this listing.
//
// When `wrap_indices` is true, a negative index is first shifted by
// `indexing_axis_dim`, so -1 refers to the last element. The (possibly
// shifted) value must then lie in [0, indexing_axis_dim).
//
// `n` is the number of indices to check; gather_impl passes indices.numel().
//
// NOTE(review): this listing omits several source lines: the `n` and
// `wrap_indices` parameter declarations, the closing braces, the enforce-call
// head, and the trailing message arguments. Confirm against the full source.
34 template <
typename IndexType>
35 static void check_indexarray_range(
36 const IndexType* indices,
38 IndexType indexing_axis_dim,
// NOTE(review): `auto i = 0` deduces `int`. If `n` is a 64-bit count, this
// counter can overflow for more than INT_MAX indices; consider a 64-bit index.
41 for (
auto i = 0; i < n; ++i) {
42 auto idx = indices[i];
43 if (wrap_indices && idx < 0) {
44 idx = idx + indexing_axis_dim;
// Bounds condition and error message of the enforce call (call head not
// shown): the index must be within [0, indexing_axis_dim).
47 0 <= idx && idx < indexing_axis_dim,
48 "INDICES element is out of DATA bounds, id=",
// Shared gather kernel: gathers slices of DATA along `axis` at the positions
// listed in INDICES and writes them to an output of `op`.
//
// Inputs and outputs are addressed by position on `op`:
//   - Input(dataIdx) is DATA.
//   - Input(indicesIdx) is INDICES.
//   - Output(outputIdx) receives the result.
// GatherOp calls this as gather_impl(this, DATA, INDICES, 0, axis_, wrap_indices_).
// `Index` is the element type INDICES is read as. `wrap_indices` enables
// negative indices (idx += size of DATA along axis).
//
// The output shape is computed by calc_output_shape_vector. INDICES is
// consumed as a flat list of N = indices.numel() entries, regardless of rank.
//
// NOTE(review): this listing omits several source lines of this function:
//   - the dataIdx/indicesIdx/outputIdx/axis/wrap_indices parameter declarations,
//   - the `if` guarding the negative-axis adjustment,
//   - the `idx` declarations in both copy loops,
//   - the `else` between the fast and generic paths,
//   - the return statements and closing braces.
// Comments describe the evident intent of the visible lines; confirm against
// the full source.
56 template <
typename Index,
typename Context>
57 static bool gather_impl(
58 Operator<Context>* op,
67 const Tensor& data = op->Input(dataIdx);
68 const Tensor& indices = op->Input(indicesIdx);
69 const TypeMeta dataType = data.dtype();
70 size_t item_bytesize = dataType.itemsize();
// Count a negative axis from the back. This is presumably guarded by
// `if (axis < 0)` on a line not shown here.
74 axis = data.dim() + axis;
// After normalization, the axis must lie in [0, data.dim()).
76 CAFFE_ENFORCE_GE(data.dim(), axis + 1,
"DATA should be at least [axis+1]-D");
77 CAFFE_ENFORCE_GE(axis, 0,
"Axis should be non-negative");
78 CAFFE_ENFORCE_LT(axis, data.dim(),
"Axis out of range");
// Allocate the output with DATA's dtype. `out` is addressed bytewise below.
82 vector<int64_t> shape =
83 calc_output_shape_vector<int64_t>(data.sizes(), indices.sizes(), axis);
84 Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
85 auto out =
static_cast<char*
>(output->raw_mutable_data(dataType));
// Empty output: nothing to copy. This check deliberately follows the
// raw_mutable_data() call above. The body (not shown) presumably returns early.
91 if (output->numel() == 0) {
95 const Index* idxs = indices.template data<Index>();
96 auto src_base =
static_cast<const char*
>(data.raw_data());
// Layout bookkeeping. DATA is treated as [outer batches][axis][block]:
//  - outer_dims_product: number of outer batches (presumably the product of
//    the dims before axis),
//  - block_size / block_bytesize: elements / bytes copied per index,
//  - src_batch_bytesize: bytes of one outer batch of DATA,
//  - gathered_batch_bytesize: bytes of one outer batch of the output
//    (N blocks, one per index).
98 auto outer_dims_product = data.size_to_dim(axis);
99 auto block_size = data.size_from_dim(axis + 1);
100 auto block_bytesize = block_size * item_bytesize;
102 auto src_indexing_axis_dim = data.size(axis);
103 auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
106 auto N = indices.numel();
107 auto gathered_batch_bytesize = N * block_size * item_bytesize;
// Validate every index once, up front, so the copy loops below can index
// without per-element bounds checks.
109 check_indexarray_range<Index>(idxs, N, src_indexing_axis_dim, wrap_indices);
// Fast path: float data with one element per index uses plain element
// assignment instead of per-block CopyItemsSameDevice calls.
// NOTE(review): `auto batch = 0` and `auto i = 0` deduce `int`. If
// outer_dims_product or N are 64-bit counts, these counters can overflow on
// very large tensors; consider 64-bit loop indices.
112 if (data.template IsType<float>() && block_size == 1) {
113 for (
auto batch = 0; batch < outer_dims_product; ++batch) {
114 const float* src_floats =
115 (
const float*)(src_base + batch * src_batch_bytesize);
116 float* dst_floats = (
float*)(out + batch * gathered_batch_bytesize);
118 for (
auto i = 0; i < N; ++i) {
// `idx` is declared on a line not shown (presumably from idxs[i]).
// Negative values are wrapped only when wrap_indices is set.
120 if (wrap_indices && idx < 0) {
121 idx = idx + src_indexing_axis_dim;
123 dst_floats[i] = src_floats[idx];
// Generic path (presumably the `else` branch): for each outer batch and each
// index, copy one block of block_size items from DATA to the output.
129 for (
auto batch = 0; batch < outer_dims_product; ++batch) {
130 for (
auto i = 0; i < N; ++i) {
132 if (wrap_indices && idx < 0) {
133 idx = idx + src_indexing_axis_dim;
// Byte offsets: `src` is source block `idx` within this DATA batch; `dst` is
// destination slot `i` within this output batch.
136 auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
137 auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
138 op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
// Caffe2 Gather operator. It gathers slices of DATA along `axis_` at the
// positions in INDICES and writes them to output 0, by delegating to
// gather_helper::gather_impl.
//
// Arguments read in the constructor:
//   "axis"         -> axis_ (via OP_SINGLE_ARG, presumably defaulting to 0).
//   "wrap_indices" -> wrap_indices_ (presumably read with a default of false).
//                     When it is not supplied, wrapping appears to be enabled
//                     only for axis_ == 0 (see the fallback assignment below).
//
// NOTE(review): this listing omits:
//   - the class header and the constructor signature/initializer list,
//   - the branch choosing between the two wrap_indices_ assignments,
//   - the first line of RunOnDevice's body,
//   - the member declarations.
// Confirm against the full source.
147 template <
class Context>
150 USE_OPERATOR_CONTEXT_FUNCTIONS;
152 template <
class... Args>
155 OP_SINGLE_ARG(
int,
"axis", axis_, 0) {
// Explicit "wrap_indices" argument; the head of this call is not shown.
164 "wrap_indices", (
false));
// Fallback when "wrap_indices" is not supplied (assumed; see note above):
// negative-index wrapping is on only for axis 0.
166 wrap_indices_ = (axis_ == 0) ?
true :
false;
170 virtual ~GatherOp() noexcept {}
// Presumably dispatches on the dtype of the INDICES input (fetched as a CPU
// tensor) to DoRunWithType<Index>(). The dispatch call itself is not shown.
172 bool RunOnDevice()
override {
174 this, this->
template Input<Tensor>(INDICES, CPU));
// Runs the gather with indices read as `Index`: inputs DATA/INDICES, output 0.
177 template <
typename Index>
178 bool DoRunWithType() {
179 return gather_helper::gather_impl<Index, Context>(
180 this, DATA, INDICES, 0, axis_, wrap_indices_);
// Input position tags used above (presumably DATA = input 0, INDICES = input 1).
183 INPUT_TAGS(DATA, INDICES);
191 #endif // GATHER_OP_H_ A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.