// The included header presumably declares the BatchGatherOp and
// BatchGatherGradientOp templates used below -- confirm against the header.
// The two statements register the CPUContext instantiations of the forward
// operator (BatchGather) and of its gradient operator (BatchGatherGradient).
1 #include "caffe2/operators/batch_gather_ops.h" 5 REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
6 REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
// Schema for the BatchGather operator: a shape-inference callback, the
// operator description, and the descriptions of its two inputs and one output.
// NOTE(review): this listing has lines missing (the embedded numbering jumps
// 8 -> 11, 21 -> 25 and 28 -> 44), so the end of the inference lambda and the
// delimiters of the description text are not visible here. Consult the
// complete file before changing any code in this chain.
8 OPERATOR_SCHEMA(BatchGather)
11 .TensorInferenceFunction([](
const OperatorDef& def,
12 const vector<TensorShape>& in) {
// Exactly one output shape is produced.
13 vector<TensorShape> out(1);
// NOTE(review): `helper` is not used in any of the visible lines -- possibly
// removable; confirm against the full source.
14 ArgumentHelper helper(def);
// in[0] is the DATA shape and in[1] the INDICES shape (see the Input
// descriptions further down).
15 const auto& data_dims = GetDimsVector(in[0]);
16 const auto& indices_dims = GetDimsVector(in[1]);
// The trailing argument 1 is presumably the gather axis, matching the
// "axis == 1" wording of the operator description -- confirm against
// gather_helper::calc_output_shape_vector.
18 vector<int> output_dims =
19 caffe2::gather_helper::calc_output_shape_vector<int>(
20 data_dims, indices_dims, 1);
// NOTE(review): the inferred element type is hard-coded to FLOAT and does not
// depend on in[0]; confirm this is intended when DATA is not a float tensor.
21 out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT);
25 Batch gather operation, first dimension in DATA is the batch size. 26 Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather 27 entries of the second outer dimension (axis == 1) of DATA indexed by INDICES, 28 and concatenate them in an output tensor of rank q + (r - 1). 44 .Input(0, "DATA",
"Tensor of rank r >= 2.")
45 .Input(1,
"INDICES",
"Tensor of int32/int64 indices, of any rank q.")
46 .Output(0,
"OUTPUT",
"Tensor of rank q + (r - 1).")
// Schema for the gradient operator: exactly three inputs and one output.
// The gradient-definition code that follows lists those inputs as DATA,
// INDICES and GO(0), and the output as GI(0).
49 OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1);
// Fragment of the gradient-maker class for BatchGather.
// NOTE(review): the class header, the definition of `Op`, the call that
// receives the arguments below, and the closing braces are not visible in
// this listing (the embedded numbering skips 50-51, 54-55, 57 and 60+).
// Inherit the base-class constructors.
52 using GradientMakerBase::GradientMakerBase;
// Builds the operator definition(s) that compute the gradient of BatchGather.
53 vector<OperatorDef> GetGradientDefs()
override {
// Operator type of the gradient op; it matches the name registered with
// REGISTER_CPU_OPERATOR and OPERATOR_SCHEMA above.
56 "BatchGatherGradient",
// Gradient-op inputs: the forward DATA and INDICES inputs plus GO(0),
// presumably the gradient of forward output 0 -- confirm against
// GradientMakerBase.
58 vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)},
// Gradient-op output: GI(0), presumably the gradient of forward input 0
// (DATA) -- confirm against GradientMakerBase.
59 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually the case for many ...