#ifndef CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
#define CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
// Reuse the gather helpers: BatchGather is Gather along axis = 1.
#include "caffe2/operators/gather_op.h"

namespace caffe2 {

template <class Context>
class BatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(BatchGatherOp);

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename TInd>
  bool DoRunWithType() {
    // BatchGather is a special case of Gather with axis = 1 and no
    // wrapping of negative indices.
    return gather_helper::gather_impl<TInd, Context>(
        this, DATA, INDICES, 0, 1, false);
  }

  INPUT_TAGS(DATA, INDICES);
};

template <class Context>
class BatchGatherGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Accept "axis" in case it was passed through from the Gather gradient;
  // default to 1 for plain batch gather.
  template <class... Args>
  explicit BatchGatherGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 1) {}
  virtual ~BatchGatherGradientOp() noexcept {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename TInd>
  bool DoRunWithType() {
    return DispatchHelper<
        TensorTypes2<float, GenericTensorImplementation>,
        TInd>::call(this, Input(DATA));
  }

  template <typename TInd, typename TData>
  bool DoRunWithType2() {
    auto& data = Input(DATA);
    auto& indices = Input(INDICES);
    auto& grad = Input(GRAD);

    // A negative axis counts from the back, so map it into [0, data.dim()).
    int axis = axis_;
    if (axis < 0) {
      axis = data.dim() + axis;
    }
    CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
    // Outer dimensions of DATA and GRAD must agree; they are preserved by
    // gathers along axis > 0.
    for (int acheck = 0; acheck < axis; acheck++) {
      CAFFE_ENFORCE_EQ(
          data.size(acheck),
          grad.size(acheck),
          "batch gather outer dimensions should match");
    }

    auto* output = Output(0, data.sizes(), at::dtype<TData>());
    TData* out_data = output->template mutable_data<TData>();
    if (data.numel() <= 0) {
      return true;
    }
    memset(out_data, 0, output->nbytes());

    const TData* grad_data = grad.template data<TData>();
    const TInd* idxs = indices.template data<TInd>();

    auto outer_dims_product = data.size_to_dim(axis);
    auto batch_size = data.size_from_dim(axis);
    auto block_size = data.size_from_dim(axis + 1);
    auto N = indices.numel();
    auto gathered_grad_batch_size = N * block_size;

    // Verify that every index lies within the gathered axis.
    auto src_indexing_axis_dim = data.dim(axis);
    gather_helper::check_indexarray_range<TInd>(
        idxs,
        N,
        src_indexing_axis_dim,
        false);

    // Scatter-add each gathered gradient row back into the row of the
    // data-shaped output selected by its index.
    for (auto batch = 0; batch < outer_dims_product; ++batch) {
      auto grad_batch_base = grad_data + batch * gathered_grad_batch_size;
      auto out_batch_base = out_data + batch * batch_size;

      for (auto i = 0; i < N; ++i) {
        auto idx = idxs[i];
        if (idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
        if (block_size == 1) {
          out_batch_base[idx] += grad_batch_base[i];
        } else {
          math::Add<TData, Context>(
              block_size,
              out_batch_base + idx * block_size,
              grad_batch_base + i * block_size,
              out_batch_base + idx * block_size,
              &context_);
        }
      }
    }
    return true;
  }

  template <typename TInd>
  bool DoRunWithOtherType2() {
    CAFFE_THROW(
        "BatchGatherGradient is not implemented on tensor of type ",
        Input(DATA).meta().name(),
        ", consider adding it as a type in the DispatchHelper list or "
        "implementing a generic version (which won't work for "
        "duplicated indices though)");
  }

  INPUT_TAGS(DATA, INDICES, GRAD);

 protected:
  int axis_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
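
// For reference, the gradient kernel above is a per-batch scatter-add: each
// gathered gradient row is accumulated into the output row selected by its
// index. Below is a minimal standalone sketch of that accumulation on plain
// vectors. It is illustrative only; the function name and the flat
// [batch, src_dim, block] layout are assumptions for the sketch, not part of
// this header, so it is kept commented out.
//
//   #include <cassert>
//   #include <cstdint>
//   #include <vector>
//
//   std::vector<float> batch_gather_grad_sketch(
//       const std::vector<float>& grad,   // [batch, N, block], flattened
//       const std::vector<int64_t>& idxs, // N indices into src_dim
//       int64_t batch,
//       int64_t src_dim,
//       int64_t block) {
//     const auto N = static_cast<int64_t>(idxs.size());
//     std::vector<float> out(batch * src_dim * block, 0.0f); // zeroed output
//     for (int64_t b = 0; b < batch; ++b) {
//       const float* grad_base = grad.data() + b * N * block;
//       float* out_base = out.data() + b * src_dim * block;
//       for (int64_t i = 0; i < N; ++i) {
//         int64_t idx = idxs[i];
//         if (idx < 0) {
//           idx += src_dim; // wrap negative indices, as in the operator
//         }
//         assert(idx >= 0 && idx < src_dim);
//         for (int64_t k = 0; k < block; ++k) {
//           out_base[idx * block + k] += grad_base[i * block + k];
//         }
//       }
//     }
//     return out;
//   }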