Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_gather_ops.cc
1 
17 #include "caffe2/operators/batch_gather_ops.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
22 REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
23 
24 OPERATOR_SCHEMA(BatchGather)
25  .NumInputs(2)
26  .NumOutputs(1)
27  .TensorInferenceFunction([](const OperatorDef& def,
28  const vector<TensorShape>& in) {
29  vector<TensorShape> out(1);
30  ArgumentHelper helper(def);
31 
32  vector<int> output_dims;
33  const auto& data_dims = GetDimsVector(in[0]);
34  const auto& indices_dims = GetDimsVector(in[1]);
35  output_dims.push_back(data_dims[0]);
36  output_dims.insert(
37  output_dims.end(), indices_dims.begin(), indices_dims.end());
38  output_dims.insert(
39  output_dims.end(), data_dims.begin() + 2, data_dims.end());
40 
41  out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT);
42  return out;
43  })
44  .SetDoc(R"DOC(
45 Batch gather operation, first dimension in DATA is the batch size.
46 Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather
47 entries of the outer-most dimension of DATA indexed by INDICES, and concatenate
48 them in an output tensor of rank (q - 1) + (r - 1).
49 
50 Example:
51  DATA = [
52  [1.0, 1.2, 2.4, 4.5],
53  [2.3, 3.4, 3.6, 2.3],
54  [4.5, 5.7, 1.2, 4.5],
55  ]
56  INDICES = [
57  [0, 2],
58  ]
59  OUTPUT = [
60  [1.0, 2.4],
61  [2.3, 3.6],
62  [4.5, 1.2],
63  ]
64 )DOC")
65  .Input(0, "DATA", "Tensor of rank r >= 2.")
66  .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
67  .Output(0, "OUTPUT", "Tensor of rank (q - 1) + (r - 1).");
68 
69 OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1);
70 
72  using GradientMakerBase::GradientMakerBase;
73  vector<OperatorDef> GetGradientDefs() override {
74  using Op = BatchGatherOp<CPUContext>;
75  return SingleGradientDef(
76  "BatchGatherGradient",
77  "",
78  vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
79  vector<string>{GI(0)});
80  }
81 };
82 
83 REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient);
84 
85 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually all that many simple operators need for their gradient.