Caffe2 - C++ API
A deep learning, cross-platform ML framework
1 #include "caffe2/operators/matmul_op.h"
3 namespace caffe2 {
// Registers the CPU implementation of MatMul, instantiated for float.
REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
8  .NumInputs(2, 3)
9  .NumOutputs(1)
10  .TensorInferenceFunction([](const OperatorDef& def,
11  const vector<TensorShape>& in) {
12  vector<TensorShape> out(1);
13  out[0].set_data_type(in[0].data_type());
14  ArgumentHelper arg_helper(def);
15  int axis_a = arg_helper.GetSingleArgument<int>("axis_a", 1);
16  int axis_b = arg_helper.GetSingleArgument<int>("axis_b", 1);
17  int trans_a = arg_helper.GetSingleArgument<bool>("trans_a", false);
18  int trans_b = arg_helper.GetSingleArgument<bool>("trans_b", false);
19  int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
20  int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size());
22  int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
23  int N = size_from_dim_(canonical_axis_b, GetDimsVector(in[1]));
24  if (trans_a) {
25  M = size_from_dim_(canonical_axis_a, GetDimsVector(in[0]));
26  }
27  if (trans_b) {
28  N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
29  }
31  out[0].add_dims(M);
32  out[0].add_dims(N);
34  return out;
35  })
36  .SetDoc(R"DOC(
37 Matrix multiplication $Y = A * B$, where `A` has size (M x K), `B` has size
38 (K x N), and `Y` will have a size (M x N). To transpose `A` or `B` before
39 multiplication, pass 1 to the `trans_a` and/or `trans_b` arguments, which
40 separate the first and second dimensions of the respective matrices using
41 `axis_a` and `axis_b`.
43 Github Links:
45 -
47 <details>
49 <summary> <b>Example</b> </summary>
51 **Code**
53 ```
54 workspace.ResetWorkspace()
56 op = core.CreateOperator(
57  "MatMul",
58  ["A", "B"],
59  ["Y"],
60 )
62 workspace.FeedBlob("A", np.random.randint(10, size=(3,3)).astype(np.float32))
63 workspace.FeedBlob("B", np.random.randint(10, size=(3,3)).astype(np.float32))
64 print("A:", workspace.FetchBlob("A"))
65 print("B:", workspace.FetchBlob("B"))
66 workspace.RunOperatorOnce(op)
67 print("Y:", workspace.FetchBlob("Y"))
68 ```
70 **Result**
72 ```
73 A: [[1. 8. 3.]
74  [6. 4. 4.]
75  [5. 4. 7.]]
76 B: [[4. 0. 3.]
77  [3. 1. 1.]
78  [8. 5. 8.]]
79 Y: [[52. 23. 35.]
80  [68. 24. 54.]
81  [88. 39. 75.]]
82 ```
84 </details>
86 )DOC")
87  .Input(
88  0,
89  "A",
90  "*(type: Tensor`<float>`)* 2D matrix of size (M x K).")
91  .Input(
92  1,
93  "B",
94  "*(type: Tensor`<float>`)* 2D matrix of size (K x N).")
95  .Output(
96  0,
97  "Y",
98  "*(type: Tensor`<float>`)* 2D matrix of size (M x N).")
99  .Arg(
100  "axis_a",
101  "*(type: int; default: 1)* Exclusive axis that divides the first and "
102  "second dimension of matrix `A`.")
103  .Arg(
104  "axis_b",
105  "*(type: int; default: 1)* Exclusive axis that divides the first and "
106  "second dimension of matrix `B`.")
107  .Arg(
108  "trans_a",
109  "*(type: int; default: 0)* Pass 1 to transpose `A` before multiplication and "
110  "after the dimension adjustment using `axis_a`.")
111  .Arg(
112  "trans_b",
113  "*(type: int; default: 0)* Pass 1 to transpose `B` before multiplication and "
114  "after the dimension adjustment using `axis_b`.");
116 class GetMatMulGradient : public GradientMakerBase {
117  using GradientMakerBase::GradientMakerBase;
118  vector<OperatorDef> GetGradientDefs() override {
119  CAFFE_ENFORCE(def_.input_size() == 2 || def_.input_size() == 3);
121  bool axis_a = 1;
122  bool axis_b = 1;
123  bool trans_a = 0;
124  bool trans_b = 0;
126  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
127  trans_a = GetArgument(Def(), "trans_a").i();
128  }
129  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
130  trans_b = GetArgument(Def(), "trans_b").i();
131  }
132  if (ArgumentHelper::HasArgument(Def(), "axis_a")) {
133  axis_a = GetArgument(Def(), "axis_a").i();
134  }
135  if (ArgumentHelper::HasArgument(Def(), "axis_b")) {
136  axis_b = GetArgument(Def(), "axis_b").i();
137  }
139  if (trans_a) {
140  if (trans_b) {
141  // A'B':
142  // dA = B'G', dB = G'A'
143  return vector<OperatorDef>{
144  CreateOperatorDef(
145  "MatMul",
146  "",
147  vector<string>{I(1), GO(0), I(0)},
148  vector<string>{GI(0)},
149  vector<Argument>{MakeArgument<int>("trans_a", 1),
150  MakeArgument<int>("trans_b", 1),
151  MakeArgument<int>("axis_a", axis_b)}),
152  CreateOperatorDef(
153  "MatMul",
154  "",
155  vector<string>{GO(0), I(0), I(1)},
156  vector<string>{GI(1)},
157  vector<Argument>{MakeArgument<int>("trans_a", 1),
158  MakeArgument<int>("trans_b", 1),
159  MakeArgument<int>("axis_b", axis_a)})};
160  } else {
161  // A'B:
162  // dA = BG', dB = AG
163  return vector<OperatorDef>{
164  CreateOperatorDef(
165  "MatMul",
166  "",
167  vector<string>{I(1), GO(0), I(0)},
168  vector<string>{GI(0)},
169  vector<Argument>{MakeArgument<int>("trans_b", 1),
170  MakeArgument<int>("axis_a", axis_b)}),
171  CreateOperatorDef(
172  "MatMul",
173  "",
174  vector<string>{I(0), GO(0), I(1)},
175  vector<string>{GI(1)},
176  vector<Argument>{MakeArgument<int>("axis_a", axis_a)})};
177  }
178  } else {
179  if (trans_b) {
180  // AB':
181  // dA = GB, dB = G'A
182  return vector<OperatorDef>{
183  CreateOperatorDef(
184  "MatMul",
185  "",
186  vector<string>{GO(0), I(1), I(0)},
187  vector<string>{GI(0)},
188  vector<Argument>{MakeArgument<int>("axis_b", axis_b)}),
189  CreateOperatorDef(
190  "MatMul",
191  "",
192  vector<string>{GO(0), I(0), I(1)},
193  vector<string>{GI(1)},
194  vector<Argument>{MakeArgument<int>("trans_a", 1),
195  MakeArgument<int>("axis_b", axis_a)})};
196  } else {
197  // AB:
198  // dA = GB', dB = A'G
199  return vector<OperatorDef>{
200  CreateOperatorDef(
201  "MatMul",
202  "",
203  vector<string>{GO(0), I(1), I(0)},
204  vector<string>{GI(0)},
205  vector<Argument>{MakeArgument<int>("trans_b", 1),
206  MakeArgument<int>("axis_b", axis_b)}),
207  CreateOperatorDef(
208  "MatMul",
209  "",
210  vector<string>{I(0), GO(0), I(1)},
211  vector<string>{GI(1)},
212  vector<Argument>{MakeArgument<int>("trans_a", 1),
213  MakeArgument<int>("axis_a", axis_a)})};
214  }
215  }
216  }
218  bool CopyArguments() const override {
219  return false;
220  }
221 };
// Registers GetMatMulGradient as the gradient maker for the MatMul operator.
REGISTER_GRADIENT(MatMul, GetMatMulGradient);
225 } // namespace caffe2
Definition: any.cpp:108
int64_t size_from_dim_(int k, IntArrayRef dims)
Return product of all dimensions starting from k.
Definition: TensorImpl.h:53
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13