Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_matmul_op.cc
#include "caffe2/operators/batch_matmul_op.h"

#include <algorithm>

#include "caffe2/core/operator_schema.h"
3 
4 namespace caffe2 {
5 
// Register the CPU implementation of BatchMatMul (declared in batch_matmul_op.h).
REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
7 
8 vector<TensorShape> TensorInferenceForBatchMatMul(
9  const OperatorDef& def,
10  const vector<TensorShape>& in) {
11  ArgumentHelper helper(def);
12  bool broadcast = helper.GetSingleArgument<int>("broadcast", 0);
13  if (!broadcast) {
14  const auto ndim = in[0].dims_size();
15  CAFFE_ENFORCE_GE(ndim, 2);
16  CAFFE_ENFORCE_GE(in[1].dims_size(), 2);
17  int a_dim0;
18  int b_dim1;
19  if (helper.GetSingleArgument<int>("trans_a", 0)) {
20  a_dim0 = in[0].dims(ndim - 1);
21  } else {
22  a_dim0 = in[0].dims(ndim - 2);
23  }
24 
25  if (helper.GetSingleArgument<int>("trans_b", 0)) {
26  b_dim1 = in[1].dims(ndim - 2);
27  } else {
28  b_dim1 = in[1].dims(ndim - 1);
29  }
30 
31  auto output_dims = vector<int64_t>{in[0].dims().begin(), in[0].dims().end()};
32  output_dims[ndim - 2] = a_dim0;
33  output_dims[ndim - 1] = b_dim1;
34 
35  return vector<TensorShape>{
36  CreateTensorShape(vector<int64_t>{output_dims}, in[0].data_type())};
37  } else {
38  auto ndims_A = in[0].dims_size();
39  auto ndims_B = in[1].dims_size();
40  std::vector<int64_t> dims_A(ndims_A), dims_B(ndims_B);
41  for (int i = 0; i < ndims_A; ++i) {
42  dims_A[i] = in[0].dims(i);
43  }
44  for (int i = 0; i < ndims_B; ++i) {
45  dims_B[i] = in[1].dims(i);
46  }
47  bool A_broadcasted = false, B_broadcasted = false;
48  if (ndims_A == 1) {
49  dims_A.insert(dims_A.begin(), 1);
50  ndims_A = 2;
51  A_broadcasted = true;
52  }
53  if (ndims_B == 1) {
54  dims_B.push_back(1);
55  ndims_B = 2;
56  B_broadcasted = true;
57  }
58  size_t M, N;
59  if (helper.GetSingleArgument<int>("trans_a", 0)) {
60  M = dims_A[ndims_A - 1];
61  } else {
62  M = dims_A[ndims_A - 2];
63  }
64  if (helper.GetSingleArgument<int>("trans_b", 0)) {
65  N = dims_B[ndims_B - 2];
66  } else {
67  N = dims_B[ndims_B - 1];
68  }
69 
70  std::vector<int64_t> new_dims;
71  if (ndims_A >= ndims_B) {
72  new_dims.assign(dims_A.begin(), dims_A.end() - 2);
73  } else {
74  new_dims.assign(dims_B.begin(), dims_B.end() - 2);
75  }
76  if (!A_broadcasted) {
77  new_dims.push_back(M);
78  }
79  if (!B_broadcasted) {
80  new_dims.push_back(N);
81  }
82  if (A_broadcasted && B_broadcasted) {
83  new_dims.push_back(1);
84  }
85  return vector<TensorShape>{
86  CreateTensorShape(vector<int64_t>{new_dims}, in[0].data_type())};
87  }
88 }
89 
90 OpSchema::Cost CostInferenceForBatchMatMul(
91  const OperatorDef& def,
92  const vector<TensorShape>& in) {
93  CAFFE_ENFORCE_EQ(in.size(), 2, "BatchMatMul requires two inputs");
94 
95  ArgumentHelper helper(def);
96  struct OpSchema::Cost c;
97  const auto& A = in[0];
98  const auto& B = in[1];
99  const TensorShape Y = TensorInferenceForBatchMatMul(def, in)[0];
100 
101  uint64_t nElemA = nElemFromDim(A);
102  uint64_t nElemB = nElemFromDim(B);
103  uint64_t nElemY = nElemFromDim(Y);
104 
105  auto ndims_A = A.dims_size();
106  size_t K;
107  if (helper.GetSingleArgument<int>("trans_a", 0)) {
108  K = in[0].dims(ndims_A - 2);
109  } else {
110  K = in[0].dims(ndims_A - 1);
111  }
112 
113  c.flops = 2 * nElemY * K;
114  c.bytes_read = (nElemA + nElemB) * sizeof(A.data_type());
115  c.bytes_written = nElemY * sizeof(Y.data_type());
116  c.params_bytes = 0;
117  return c;
118 }
119 
120 OPERATOR_SCHEMA(BatchMatMul)
121  .NumInputs(2)
122  .NumOutputs(1)
123  .SetDoc(R"DOC(
124 Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K),
125 B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges
126 from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being
127 two diemnsional, it behaves like normal matrix multiplication.
128 )DOC")
129  .Input(0, "A", "tensor of shape (dim0, dim1 ... M, K)")
130  .Input(1, "B", "tensor of shpae (dim0, dim2 ... K, N)")
131  .Output(0, "Y", "tensor of shape (dim0, dim1 ... M, N)")
132  .Arg(
133  "trans_a",
134  "Pass 1 to transpose the last two dimensions of A before "
135  "doing multiplication")
136  .Arg(
137  "trans_b",
138  "Pass 1 to transpose the last two dimensions of B before "
139  "doing multiplication")
140  .Arg(
141  "broadcast",
142  "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.")
143  .TensorInferenceFunction(TensorInferenceForBatchMatMul)
144  .CostInferenceFunction(
145  OpSchema::CostInferenceFunctionType(CostInferenceForBatchMatMul))
146  .InheritOnnxSchema();
147 
148 class GetBatchMatMulGradient : public GradientMakerBase {
149  using GradientMakerBase::GradientMakerBase;
150  vector<OperatorDef> GetGradientDefs() override {
151  CAFFE_ENFORCE_EQ(def_.input_size(), 2);
152 
153  bool broadcast = false;
154  if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
155  broadcast = GetArgument(Def(), "broadcast").i();
156  }
157  CAFFE_ENFORCE(
158  !broadcast,
159  "Gradient is currently not supported with "
160  "broadcast=1 for BatchMatMul.");
161 
162  bool trans_a = 0;
163  bool trans_b = 0;
164 
165  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
166  trans_a = GetArgument(Def(), "trans_a").i();
167  }
168  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
169  trans_b = GetArgument(Def(), "trans_b").i();
170  }
171 
172  auto no_trans_arg = vector<Argument>();
173  auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)};
174  auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)};
175  auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
176  MakeArgument<int>("trans_b", 1)};
177 
178  if (trans_a) {
179  if (trans_b) {
180  // A'B':
181  // dA = B'G', dB = G'A'
182  return vector<OperatorDef>{CreateOperatorDef(
183  "BatchMatMul",
184  "",
185  vector<string>{I(1), GO(0)},
186  vector<string>{GI(0)},
187  trans_both_arg),
188  CreateOperatorDef(
189  "BatchMatMul",
190  "",
191  vector<string>{GO(0), I(0)},
192  vector<string>{GI(1)},
193  trans_both_arg)};
194  } else {
195  // A'B:
196  // dA = BG', dB = AG
197  return vector<OperatorDef>{CreateOperatorDef(
198  "BatchMatMul",
199  "",
200  vector<string>{I(1), GO(0)},
201  vector<string>{GI(0)},
202  trans_b_arg),
203  CreateOperatorDef(
204  "BatchMatMul",
205  "",
206  vector<string>{I(0), GO(0)},
207  vector<string>{GI(1)},
208  no_trans_arg)};
209  }
210  } else {
211  if (trans_b) {
212  // AB':
213  // dA = GB, dB = G'A
214  return vector<OperatorDef>{CreateOperatorDef(
215  "BatchMatMul",
216  "",
217  vector<string>{GO(0), I(1)},
218  vector<string>{GI(0)},
219  no_trans_arg),
220  CreateOperatorDef(
221  "BatchMatMul",
222  "",
223  vector<string>{GO(0), I(0)},
224  vector<string>{GI(1)},
225  trans_a_arg)};
226  } else {
227  // AB:
228  // dA = GB', dB = A'G
229  return vector<OperatorDef>{CreateOperatorDef(
230  "BatchMatMul",
231  "",
232  vector<string>{GO(0), I(1)},
233  vector<string>{GI(0)},
234  trans_b_arg),
235  CreateOperatorDef(
236  "BatchMatMul",
237  "",
238  vector<string>{I(0), GO(0)},
239  vector<string>{GI(1)},
240  trans_a_arg)};
241  }
242  }
243  }
244 
245  bool CopyArguments() const override {
246  return false;
247  }
248 };
249 
// Hook the gradient maker above into the gradient registry for BatchMatMul.
REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient);
251 
252 } // namespace caffe2
Definition: any.cpp:108
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...