Caffe2 - C++ API
A deep learning, cross-platform ML framework
gather_op.cc
1 #include "gather_op.h"
2 namespace caffe2 {
3 
4 REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
5 
6 OPERATOR_SCHEMA(Gather)
7  .NumInputs(2)
8  .NumOutputs(1)
9  .SetDoc(R"DOC(
10 
11 The *Gather* op accepts a *DATA* tensor of rank $r >= 1$ and *INDICES* tensor of rank $q$ as inputs. It then gathers entries of the outer-most dimension of *DATA*, indexed by *INDICES*, and concatenate them in an output tensor of rank $q + (r - 1)$.
12 
13 Github Links:
14 
15 - https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.cc
16 - https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.h
17 
18 
19 <details>
20 
21 <summary> <b>Example</b> </summary>
22 
23 **Code**
24 
25 ```
26 
27 workspace.ResetWorkspace()
28 
29 op = core.CreateOperator(
30  "Gather",
31  ["DATA", "INDICES"],
32  ["OUTPUT"]
33 )
34 data = np.array([[1., 1.2],[2.3, 3.4],[4.5, 5.7]])
35 print("DATA:\n",data)
36 
37 inds = np.array([[0, 1],[1, 2]])
38 print("INDICES:\n",inds)
39 
40 // Feed X into workspace
41 workspace.FeedBlob("DATA", data.astype(np.float32))
42 workspace.FeedBlob("INDICES", inds.astype(np.int32))
43 
44 workspace.RunOperatorOnce(op)
45 print("OUTPUT:\n", workspace.FetchBlob("OUTPUT"))
46 
47 ```
48 
49 **Result**
50 
51 ```
52 
53 DATA:
54  [[1. 1.2]
55  [2.3 3.4]
56  [4.5 5.7]]
57 INDICES:
58  [[0 1]
59  [1 2]]
60 OUTPUT:
61  [[[1. 1.2]
62  [2.3 3.4]]
63 
64  [[2.3 3.4]
65  [4.5 5.7]]]
66 
67 ```
68 
69 </details>
70 
71 )DOC")
72  .Input(0, "DATA", "Input data tensor of rank $r>=1$")
73  .Input(
74  1,
75  "INDICES",
76  "Input indices tensor of rank $q$. This tensor must contain integers.")
77  .Output(0, "OUTPUT", "Output tensor of rank $q+(r-1)$")
78  .TensorInferenceFunction([](const OperatorDef& def,
79  const vector<TensorShape>& in) {
80  ArgumentHelper helper(def);
81  const int axis = helper.GetSingleArgument<int>("axis", 0);
82  const auto& data_dims = GetDimsVector(in[0]);
83  const auto& indices_dims = GetDimsVector(in[1]);
84 
85  vector<int> output_dims =
86  caffe2::gather_helper::calc_output_shape_vector<int>(
87  data_dims, indices_dims, axis);
88  vector<TensorShape> out(1);
89  out[0] = CreateTensorShape(output_dims, in[0].data_type());
90  return out;
91  })
92  .InheritOnnxSchema();
93 
94 class GetGatherGradient : public GradientMakerBase {
95  using GradientMakerBase::GradientMakerBase;
96 
97  vector<OperatorDef> GetGradientDefs() override {
98  ArgumentHelper argsHelper(def_);
99  const bool dense_gradient =
100  argsHelper.GetSingleArgument<bool>("dense_gradient", false);
101  const int axis = argsHelper.GetSingleArgument<int>("axis", 0);
102 
103  // TBD: While it hasn't been used yet, we need to add wrap_indices support
104  // to gradients next.
105  // if (argsHelper.HasArgument("wrap_indices_")) {
106  // }
107 
108  using Op = GatherOp<CPUContext>;
109 
110  if (axis == 0) {
111  if (dense_gradient) {
112  return vector<OperatorDef>{CreateOperatorDef(
113  "SparseToDense",
114  "",
115  vector<string>{I(Op::INDICES), GO(0), I(Op::DATA)},
116  vector<string>{GI(Op::DATA)})};
117  } else {
118  // For now we don't do any reshaping as the consumer of this op would
119  // probably be ScatterUpdate which is intenionally ignores shapes. We
120  // might need to revisit it in the future for correctness purposes. The
121  // right shape for the output woild be to flatten INDICES and collapse
122  // first X dims of GRAD
123  SetSparse(Op::DATA, I(Op::INDICES), GO(0));
124  return vector<OperatorDef>();
125  }
126  }
127 
128  // TBD: This is misleading to use dense_gradient by default for axis 0
129  // and not othewise....
130  if (argsHelper.HasArgument("dense_gradient")) {
131  CAFFE_ENFORCE(
132  dense_gradient == true,
133  "Gather with axis > 0 must use dense_gradient");
134  }
135 
136  Argument axisArg = MakeArgument<int>("axis", axis);
137  return SingleGradientDef(
138  "BatchGatherGradient",
139  "",
140  // This is the order as expected by BatchGatherGradient indices,
141  // different from SpartseToDense above.
142  vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
143  vector<string>{GI(0)},
144  std::vector<Argument>{axisArg});
145  }
146 };
147 REGISTER_GRADIENT(Gather, GetGatherGradient);
148 
149 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13