Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_gather_ops.cc
1 #include "caffe2/operators/batch_gather_ops.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
6 REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
7 
8 OPERATOR_SCHEMA(BatchGather)
9  .NumInputs(2)
10  .NumOutputs(1)
11  .TensorInferenceFunction([](const OperatorDef& def,
12  const vector<TensorShape>& in) {
13  vector<TensorShape> out(1);
14  ArgumentHelper helper(def);
15  const auto& data_dims = GetDimsVector(in[0]);
16  const auto& indices_dims = GetDimsVector(in[1]);
17 
18  vector<int> output_dims =
19  caffe2::gather_helper::calc_output_shape_vector<int>(
20  data_dims, indices_dims, 1);
21  out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT);
22  return out;
23  })
24  .SetDoc(R"DOC(
25 Batch gather operation, first dimension in DATA is the batch size.
26 Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather
27 entries of the second outer dimension (axis == 1) of DATA indexed by INDICES,
28 and concatenate them in an output tensor of rank q + (r - 1).
29 
30 Example:
31  DATA = [
32  [1.0, 1.2, 2.4, 4.5],
33  [2.3, 3.4, 3.6, 2.3],
34  [4.5, 5.7, 1.2, 4.5],
35  ]
36  INDICES = [0, 2]
37 
38  OUTPUT = [
39  [1.0, 2.4],
40  [2.3, 3.6],
41  [4.5, 1.2],
42  ]
43 )DOC")
44  .Input(0, "DATA", "Tensor of rank r >= 2.")
45  .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
46  .Output(0, "OUTPUT", "Tensor of rank q + (r - 1).")
47  .InheritOnnxSchema();
48 
49 OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1);
50 
52  using GradientMakerBase::GradientMakerBase;
53  vector<OperatorDef> GetGradientDefs() override {
54  using Op = BatchGatherOp<CPUContext>;
55  return SingleGradientDef(
56  "BatchGatherGradient",
57  "",
58  vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
59  vector<string>{GI(0)});
60  }
61 };
62 
63 REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient);
64 
65 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
static vector<OperatorDef> SingleGradientDef(const Args&... args)
A helper function that creates a single operator def, which is usually the case for many ...