Caffe2 - C++ API
A deep learning, cross-platform ML framework
gather_fused_8bit_rowwise_op.cc
1 #include "caffe2/operators/gather_fused_8bit_rowwise_op.h"
2 
3 namespace caffe2 {
4 
5 OPERATOR_SCHEMA(GatherFused8BitRowwise)
6  .NumInputs(2)
7  .NumOutputs(1)
8  .SetDoc(R"DOC(
9 Perform the same operation as Gather, but operating on 8-bit rowwise quantized
10 matrices with fused storage (where each row stores quantized values, and then
11 the scale and offset).
12 DATA needs to have rank 2 and INDICES needs to have rank 1.
13 )DOC")
14  .Input(
15  0,
16  "DATA",
17  "uint8 tensor with rank 2 obtained with operator FloatToFused8BitRowwiseQuantized")
18  .Input(
19  1,
20  "INDICES",
21  "Integer vector containing indices of the first dimension of DATA for"
22  "the rows that are being gathered")
23  .Output(0, "OUTPUT", "output")
24  .TensorInferenceFunction([](const OperatorDef& def,
25  const vector<TensorShape>& in) {
26  vector<TensorShape> out(1);
27  for (auto d : in[1].dims()) {
28  out[0].add_dims(d);
29  }
30  for (int i = 1; i < in[0].dims_size(); ++i) {
31  out[0].add_dims(in[0].dims(i));
32  }
33  out[0].set_data_type(in[0].data_type());
34  return out;
35  });
36 
37 REGISTER_CPU_OPERATOR(
38  GatherFused8BitRowwise,
39  GatherFused8BitRowwiseOp<CPUContext>);
40 
41 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13