1 #include "caffe2/operators/batch_matmul_op.h" 2 #include "caffe2/core/operator_schema.h" 6 REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
8 vector<TensorShape> TensorInferenceForBatchMatMul(
9 const OperatorDef& def,
10 const vector<TensorShape>& in) {
11 ArgumentHelper helper(def);
12 bool broadcast = helper.GetSingleArgument<
int>(
"broadcast", 0);
14 const auto ndim = in[0].dims_size();
15 CAFFE_ENFORCE_GE(ndim, 2);
16 CAFFE_ENFORCE_GE(in[1].dims_size(), 2);
19 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
20 a_dim0 = in[0].dims(ndim - 1);
22 a_dim0 = in[0].dims(ndim - 2);
25 if (helper.GetSingleArgument<
int>(
"trans_b", 0)) {
26 b_dim1 = in[1].dims(ndim - 2);
28 b_dim1 = in[1].dims(ndim - 1);
31 auto output_dims = vector<int64_t>{in[0].dims().begin(), in[0].dims().end()};
32 output_dims[ndim - 2] = a_dim0;
33 output_dims[ndim - 1] = b_dim1;
35 return vector<TensorShape>{
36 CreateTensorShape(vector<int64_t>{output_dims}, in[0].data_type())};
38 auto ndims_A = in[0].dims_size();
39 auto ndims_B = in[1].dims_size();
40 std::vector<int64_t> dims_A(ndims_A), dims_B(ndims_B);
41 for (
int i = 0; i < ndims_A; ++i) {
42 dims_A[i] = in[0].dims(i);
44 for (
int i = 0; i < ndims_B; ++i) {
45 dims_B[i] = in[1].dims(i);
47 bool A_broadcasted =
false, B_broadcasted =
false;
49 dims_A.insert(dims_A.begin(), 1);
59 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
60 M = dims_A[ndims_A - 1];
62 M = dims_A[ndims_A - 2];
64 if (helper.GetSingleArgument<
int>(
"trans_b", 0)) {
65 N = dims_B[ndims_B - 2];
67 N = dims_B[ndims_B - 1];
70 std::vector<int64_t> new_dims;
71 if (ndims_A >= ndims_B) {
72 new_dims.assign(dims_A.begin(), dims_A.end() - 2);
74 new_dims.assign(dims_B.begin(), dims_B.end() - 2);
77 new_dims.push_back(M);
80 new_dims.push_back(N);
82 if (A_broadcasted && B_broadcasted) {
83 new_dims.push_back(1);
85 return vector<TensorShape>{
86 CreateTensorShape(vector<int64_t>{new_dims}, in[0].data_type())};
90 OpSchema::Cost CostInferenceForBatchMatMul(
91 const OperatorDef& def,
92 const vector<TensorShape>& in) {
93 CAFFE_ENFORCE_EQ(in.size(), 2,
"BatchMatMul requires two inputs");
95 ArgumentHelper helper(def);
96 struct OpSchema::Cost c;
97 const auto&
A = in[0];
98 const auto&
B = in[1];
99 const TensorShape Y = TensorInferenceForBatchMatMul(def, in)[0];
101 uint64_t nElemA = nElemFromDim(A);
102 uint64_t nElemB = nElemFromDim(
B);
103 uint64_t nElemY = nElemFromDim(Y);
105 auto ndims_A = A.dims_size();
107 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
108 K = in[0].dims(ndims_A - 2);
110 K = in[0].dims(ndims_A - 1);
113 c.flops = 2 * nElemY * K;
114 c.bytes_read = (nElemA + nElemB) *
sizeof(A.data_type());
115 c.bytes_written = nElemY *
sizeof(Y.data_type());
120 OPERATOR_SCHEMA(BatchMatMul)
124 Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K), 125 B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges 126 from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being 127 two diemnsional, it behaves like normal matrix multiplication. 129 .Input(0, "A",
"tensor of shape (dim0, dim1 ... M, K)")
130 .Input(1,
"B",
"tensor of shpae (dim0, dim2 ... K, N)")
131 .Output(0,
"Y",
"tensor of shape (dim0, dim1 ... M, N)")
134 "Pass 1 to transpose the last two dimensions of A before " 135 "doing multiplication")
138 "Pass 1 to transpose the last two dimensions of B before " 139 "doing multiplication")
142 "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.")
143 .TensorInferenceFunction(TensorInferenceForBatchMatMul)
144 .CostInferenceFunction(
146 .InheritOnnxSchema();
148 class GetBatchMatMulGradient :
public GradientMakerBase {
149 using GradientMakerBase::GradientMakerBase;
150 vector<OperatorDef> GetGradientDefs()
override {
151 CAFFE_ENFORCE_EQ(def_.input_size(), 2);
153 bool broadcast =
false;
154 if (ArgumentHelper::HasArgument(Def(),
"broadcast")) {
155 broadcast = GetArgument(Def(),
"broadcast").i();
159 "Gradient is currently not supported with " 160 "broadcast=1 for BatchMatMul.");
165 if (ArgumentHelper::HasArgument(Def(),
"trans_a")) {
166 trans_a = GetArgument(Def(),
"trans_a").i();
168 if (ArgumentHelper::HasArgument(Def(),
"trans_b")) {
169 trans_b = GetArgument(Def(),
"trans_b").i();
172 auto no_trans_arg = vector<Argument>();
173 auto trans_a_arg = vector<Argument>{MakeArgument<int>(
"trans_a", 1)};
174 auto trans_b_arg = vector<Argument>{MakeArgument<int>(
"trans_b", 1)};
175 auto trans_both_arg = vector<Argument>{MakeArgument<int>(
"trans_a", 1),
176 MakeArgument<int>(
"trans_b", 1)};
182 return vector<OperatorDef>{CreateOperatorDef(
185 vector<string>{I(1), GO(0)},
186 vector<string>{GI(0)},
191 vector<string>{GO(0), I(0)},
192 vector<string>{GI(1)},
197 return vector<OperatorDef>{CreateOperatorDef(
200 vector<string>{I(1), GO(0)},
201 vector<string>{GI(0)},
206 vector<string>{I(0), GO(0)},
207 vector<string>{GI(1)},
214 return vector<OperatorDef>{CreateOperatorDef(
217 vector<string>{GO(0), I(1)},
218 vector<string>{GI(0)},
223 vector<string>{GO(0), I(0)},
224 vector<string>{GI(1)},
229 return vector<OperatorDef>{CreateOperatorDef(
232 vector<string>{GO(0), I(1)},
233 vector<string>{GI(0)},
238 vector<string>{I(0), GO(0)},
239 vector<string>{GI(1)},
245 bool CopyArguments()
const override {
250 REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient);
// NOTE(review): the three lines below appear to be documentation-browser
// residue (doxygen tooltip text) accidentally captured into this file, not
// part of the translation unit; preserved here as comments.
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
// std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
// Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...