// CPU registration of the Gather operator: the op named Gather is implemented
// by GatherOp specialized for CPUContext.
// NOTE(review): the leading "4" here (and the numbers at the start of the
// following lines) look like source line numbers fused in during extraction,
// not code; this text will not compile until restored from upstream.
4 REGISTER_CPU_OPERATOR(Gather, GatherOp<CPUContext>);
// Schema for the Gather operator. The visible chained calls declare the DATA
// input, the INDICES input, the OUTPUT output, and a shape-inference function.
// NOTE(review): this is an extraction of the real file -- leading numbers are
// fused source line numbers, and the numbering gaps (e.g. 7-10, 73-75, 90+)
// show original lines are missing here. Restore from upstream before editing.
// NOTE(review): defects in the doc string below, to fix in the real file:
//   * "and concatenate them" should read "and concatenates them";
//   * the Python example contains "// Feed X into workspace": '//' is not a
//     Python comment marker ('#' is), and the example feeds blobs "DATA" and
//     "INDICES", not "X".
6 OPERATOR_SCHEMA(Gather)
11 The *Gather* op accepts a *DATA* tensor of rank $r >= 1$ and *INDICES* tensor of rank $q$ as inputs. It then gathers entries of the outer-most dimension of *DATA*, indexed by *INDICES*, and concatenate them in an output tensor of rank $q + (r - 1)$. 15 - https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.cc 16 - https://github.com/caffe2/caffe2/blob/master/caffe2/operators/gather_op.h 21 <summary> <b>Example</b> </summary> 27 workspace.ResetWorkspace() 29 op = core.CreateOperator( 34 data = np.array([[1., 1.2],[2.3, 3.4],[4.5, 5.7]]) 37 inds = np.array([[0, 1],[1, 2]]) 38 print("INDICES:\n",inds) 40 // Feed X into workspace 41 workspace.FeedBlob("DATA", data.astype(np.float32)) 42 workspace.FeedBlob("INDICES", inds.astype(np.int32)) 44 workspace.RunOperatorOnce(op) 45 print("OUTPUT:\n", workspace.FetchBlob("OUTPUT")) 72 .Input(0, "DATA",
"Input data tensor of rank $r>=1$")
// INDICES input: its description requires integer contents. The line that
// opens this .Input(...) call (and so its index) is not visible here.
76 "Input indices tensor of rank $q$. This tensor must contain integers.")
77 .Output(0,
"OUTPUT",
"Output tensor of rank $q+(r-1)$")
78 .TensorInferenceFunction([](
const OperatorDef& def,
79 const vector<TensorShape>& in) {
// Read the optional "axis" argument from the op definition; defaults to 0.
80 ArgumentHelper helper(def);
81 const int axis = helper.GetSingleArgument<
int>(
"axis", 0);
// in[0] is DATA (input 0 above); in[1] is presumably INDICES -- confirm,
// since the INDICES .Input(...) index is on a missing line.
82 const auto& data_dims = GetDimsVector(in[0]);
83 const auto& indices_dims = GetDimsVector(in[1]);
// Output dims are computed by the shared gather_helper from the DATA dims,
// the INDICES dims, and axis (rank q+(r-1) per the OUTPUT description).
85 vector<int> output_dims =
86 caffe2::gather_helper::calc_output_shape_vector<int>(
87 data_dims, indices_dims, axis);
88 vector<TensorShape> out(1);
// The single output keeps DATA's element type.
89 out[0] = CreateTensorShape(output_dims, in[0].data_type());
// NOTE(review): the lambda's return of `out` and the closing of this schema
// chain are on lines missing from this view.
// Gradient maker for Gather. It reads "dense_gradient" (default false) and
// "axis" (default 0) from the forward op's arguments and produces one of:
//   * a dense gradient: one op with inputs (INDICES, GO(0), DATA) writing GI(DATA);
//   * a sparse gradient: SetSparse(DATA, INDICES, GO(0)) and no gradient ops;
//   * a "BatchGatherGradient" op with the "axis" argument forwarded, which
//     rejects an explicitly given dense_gradient that is not true.
// NOTE(review): the condition choosing among these paths (presumably an
// axis == 0 check, given the "axis > 0" message below), the dense op's type
// name, and the closing braces are on lines missing from this view.
94 class GetGatherGradient :
public GradientMakerBase {
95 using GradientMakerBase::GradientMakerBase;
97 vector<OperatorDef> GetGradientDefs()
override {
98 ArgumentHelper argsHelper(def_);
99 const bool dense_gradient =
100 argsHelper.GetSingleArgument<
bool>(
"dense_gradient",
false);
101 const int axis = argsHelper.GetSingleArgument<
int>(
"axis", 0);
108 using Op = GatherOp<CPUContext>;
// Dense path: a single op built from (INDICES, GO(0), DATA) that produces
// the DATA gradient GI(DATA).
111 if (dense_gradient) {
112 return vector<OperatorDef>{CreateOperatorDef(
115 vector<string>{I(Op::INDICES), GO(0), I(Op::DATA)},
116 vector<string>{GI(Op::DATA)})};
// Sparse path: record a sparse gradient for DATA from INDICES and GO(0) and
// emit no gradient ops (empty vector).
// NOTE(review): GO(0) is passed without reshaping; per the schema its rank
// is q+(r-1), so confirm consumers accept the un-flattened shape.
123 SetSparse(Op::DATA, I(Op::INDICES), GO(0));
124 return vector<OperatorDef>();
// If dense_gradient was given explicitly, it must be true on this path (the
// enforcing call itself is on a missing line).
// NOTE(review): `dense_gradient == true` is equivalent to `dense_gradient`.
130 if (argsHelper.HasArgument(
"dense_gradient")) {
132 dense_gradient ==
true,
133 "Gather with axis > 0 must use dense_gradient");
// Forward "axis" to BatchGatherGradient. Its inputs are ordered
// (DATA, INDICES, GO(0)), unlike the dense path above.
// NOTE(review): GI(0) here vs GI(Op::DATA) above -- these agree only if
// Op::DATA == 0; prefer GI(Op::DATA) for consistency.
136 Argument axisArg = MakeArgument<int>(
"axis", axis);
137 return SingleGradientDef(
138 "BatchGatherGradient",
142 vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
143 vector<string>{GI(0)},
144 std::vector<Argument>{axisArg});
// Associates the Gather operator with GetGatherGradient as its gradient maker.
147 REGISTER_GRADIENT(Gather, GetGatherGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...