// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// batch_matmul_op.cc
17 #include "caffe2/operators/batch_matmul_op.h"
18 #include "caffe2/core/operator_schema.h"
19 
20 namespace caffe2 {
21 
// Register the CPU implementation of BatchMatMul under the operator name
// "BatchMatMul" so it can be instantiated from an OperatorDef.
REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
23 
24 OPERATOR_SCHEMA(BatchMatMul)
25  .NumInputs(2)
26  .NumOutputs(1)
27  .SetDoc(R"DOC(
28 Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K),
29 B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges
30 from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being
31 two diemnsional, it behaves like normal matrix multiplication.
32 )DOC")
33  .Input(0, "A", "tensor of shape (dim0, dim1 ... M, K)")
34  .Input(1, "B", "tensor of shpae (dim0, dim2 ... K, N)")
35  .Output(0, "Y", "tensor of shape (dim0, dim1 ... M, N)")
36  .Arg(
37  "trans_a",
38  "Pass 1 to transpose the last two dimensions of A before "
39  "doing multiplication")
40  .Arg(
41  "trans_b",
42  "Pass 1 to transpose the last two dimensions of B before "
43  "doing multiplication")
44  .Arg(
45  "broadcast",
46  "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.")
47  .TensorInferenceFunction([](const OperatorDef& def,
48  const vector<TensorShape>& in) {
49  ArgumentHelper helper(def);
50  bool broadcast = helper.GetSingleArgument<int>("broadcast", 0);
51  if (!broadcast) {
52  const auto ndim = in[0].dims_size();
53  CAFFE_ENFORCE_GE(ndim, 2);
54  int a_dim0;
55  int b_dim1;
56  if (helper.GetSingleArgument<int>("trans_a", 0)) {
57  a_dim0 = in[0].dims(ndim - 1);
58  } else {
59  a_dim0 = in[0].dims(ndim - 2);
60  }
61 
62  if (helper.GetSingleArgument<int>("trans_b", 0)) {
63  b_dim1 = in[1].dims(ndim - 2);
64  } else {
65  b_dim1 = in[1].dims(ndim - 1);
66  }
67 
68  auto output_dims =
69  vector<TIndex>{in[0].dims().begin(), in[0].dims().end()};
70  output_dims[ndim - 2] = a_dim0;
71  output_dims[ndim - 1] = b_dim1;
72 
73  return vector<TensorShape>{
74  CreateTensorShape(vector<TIndex>{output_dims}, in[0].data_type())};
75  } else {
76  auto ndims_A = in[0].dims_size();
77  auto ndims_B = in[1].dims_size();
78  std::vector<TIndex> dims_A(ndims_A), dims_B(ndims_B);
79  for (int i = 0; i < ndims_A; ++i) {
80  dims_A[i] = in[0].dims(i);
81  }
82  for (int i = 0; i < ndims_B; ++i) {
83  dims_B[i] = in[1].dims(i);
84  }
85  bool A_broadcasted = false, B_broadcasted = false;
86  if (ndims_A == 1) {
87  dims_A.insert(dims_A.begin(), 1);
88  ndims_A = 2;
89  A_broadcasted = true;
90  }
91  if (ndims_B == 1) {
92  dims_B.push_back(1);
93  ndims_B = 2;
94  B_broadcasted = true;
95  }
96  size_t M, N, K, K_dim;
97  if (helper.GetSingleArgument<int>("trans_a", 0)) {
98  M = dims_A[ndims_A - 1];
99  K = dims_A[ndims_A - 2];
100  K_dim = ndims_A - 2;
101  } else {
102  M = dims_A[ndims_A - 2];
103  K = dims_A[ndims_A - 1];
104  K_dim = ndims_A - 1;
105  }
106  if (helper.GetSingleArgument<int>("trans_b", 0)) {
107  N = dims_B[ndims_B - 2];
108  } else {
109  N = dims_B[ndims_B - 1];
110  }
111 
112  std::vector<TIndex> new_dims;
113  if (ndims_A >= ndims_B) {
114  new_dims.assign(dims_A.begin(), dims_A.end() - 2);
115  } else {
116  new_dims.assign(dims_B.begin(), dims_B.end() - 2);
117  }
118  if (!A_broadcasted) {
119  new_dims.push_back(M);
120  }
121  if (!B_broadcasted) {
122  new_dims.push_back(N);
123  }
124  if (A_broadcasted && B_broadcasted) {
125  new_dims.push_back(1);
126  }
127  return vector<TensorShape>{
128  CreateTensorShape(vector<TIndex>{new_dims}, in[0].data_type())};
129  }
130  });
131 
132 class GetBatchMatMulGradient : public GradientMakerBase {
133  using GradientMakerBase::GradientMakerBase;
134  vector<OperatorDef> GetGradientDefs() override {
135  CAFFE_ENFORCE_EQ(def_.input_size(), 2);
136 
137  bool broadcast = false;
138  if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
139  broadcast = GetArgument(Def(), "broadcast").i();
140  }
141  CAFFE_ENFORCE(
142  !broadcast,
143  "Gradient is currently not supported with "
144  "broadcast=1 for BatchMatMul.");
145 
146  bool trans_a = 0;
147  bool trans_b = 0;
148 
149  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
150  trans_a = GetArgument(Def(), "trans_a").i();
151  }
152  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
153  trans_b = GetArgument(Def(), "trans_b").i();
154  }
155 
156  auto no_trans_arg = vector<Argument>();
157  auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)};
158  auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)};
159  auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
160  MakeArgument<int>("trans_b", 1)};
161 
162  if (ArgumentHelper::HasArgument(Def(), "use_scratch")) {
163  no_trans_arg.push_back(MakeArgument<int>("use_scratch", 1));
164  trans_a_arg.push_back(MakeArgument<int>("use_scratch", 1));
165  trans_b_arg.push_back(MakeArgument<int>("use_scratch", 1));
166  trans_both_arg.push_back(MakeArgument<int>("use_scratch", 1));
167  }
168 
169  if (trans_a) {
170  if (trans_b) {
171  // A'B':
172  // dA = B'G', dB = G'A'
173  return vector<OperatorDef>{CreateOperatorDef(
174  "BatchMatMul",
175  "",
176  vector<string>{I(1), GO(0)},
177  vector<string>{GI(0)},
178  trans_both_arg),
179  CreateOperatorDef(
180  "BatchMatMul",
181  "",
182  vector<string>{GO(0), I(0)},
183  vector<string>{GI(1)},
184  trans_both_arg)};
185  } else {
186  // A'B:
187  // dA = BG', dB = AG
188  return vector<OperatorDef>{CreateOperatorDef(
189  "BatchMatMul",
190  "",
191  vector<string>{I(1), GO(0)},
192  vector<string>{GI(0)},
193  trans_b_arg),
194  CreateOperatorDef(
195  "BatchMatMul",
196  "",
197  vector<string>{I(0), GO(0)},
198  vector<string>{GI(1)},
199  no_trans_arg)};
200  }
201  } else {
202  if (trans_b) {
203  // AB':
204  // dA = GB, dB = G'A
205  return vector<OperatorDef>{CreateOperatorDef(
206  "BatchMatMul",
207  "",
208  vector<string>{GO(0), I(1)},
209  vector<string>{GI(0)},
210  no_trans_arg),
211  CreateOperatorDef(
212  "BatchMatMul",
213  "",
214  vector<string>{GO(0), I(0)},
215  vector<string>{GI(1)},
216  trans_a_arg)};
217  } else {
218  // AB:
219  // dA = GB', dB = A'G
220  return vector<OperatorDef>{CreateOperatorDef(
221  "BatchMatMul",
222  "",
223  vector<string>{GO(0), I(1)},
224  vector<string>{GI(0)},
225  trans_b_arg),
226  CreateOperatorDef(
227  "BatchMatMul",
228  "",
229  vector<string>{I(0), GO(0)},
230  vector<string>{GI(1)},
231  trans_a_arg)};
232  }
233  }
234  }
235 
236  bool CopyArguments() const override {
237  return false;
238  }
239 };
240 
// Hook the gradient maker above into the BatchMatMul operator.
REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient);
242 
243 } // namespace caffe2
// Copyright (c) 2016-present, Facebook, Inc.