1 #include "caffe2/operators/matmul_op.h" 5 REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
7 OPERATOR_SCHEMA(MatMul)
10 .TensorInferenceFunction([](
const OperatorDef& def,
11 const vector<TensorShape>& in) {
12 vector<TensorShape> out(1);
13 out[0].set_data_type(in[0].data_type());
14 ArgumentHelper arg_helper(def);
15 int axis_a = arg_helper.GetSingleArgument<
int>(
"axis_a", 1);
16 int axis_b = arg_helper.GetSingleArgument<
int>(
"axis_b", 1);
17 int trans_a = arg_helper.GetSingleArgument<
bool>(
"trans_a",
false);
18 int trans_b = arg_helper.GetSingleArgument<
bool>(
"trans_b",
false);
19 int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
20 int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size());
22 int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
28 N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
37 Matrix multiplication $Y = A * B$, where `A` has size (M x K), `B` has size 38 (K x N), and `Y` will have a size (M x N). To transpose `A` or `B` before 39 multiplication, pass 1 to the `trans_a` and/or `trans_b` arguments, which 40 separate the first and second dimensions of the respective matrices using 41 `axis_a` and `axis_b`. 45 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/matmul_op.cc 49 <summary> <b>Example</b> </summary> 54 workspace.ResetWorkspace() 56 op = core.CreateOperator( 62 workspace.FeedBlob("A", np.random.randint(10, size=(3,3)).astype(np.float32)) 63 workspace.FeedBlob("B", np.random.randint(10, size=(3,3)).astype(np.float32)) 64 print("A:", workspace.FetchBlob("A")) 65 print("B:", workspace.FetchBlob("B")) 66 workspace.RunOperatorOnce(op) 67 print("Y:", workspace.FetchBlob("Y")) 90 "*(type: Tensor`<float>`)* 2D matrix of size (M x K).")
94 "*(type: Tensor`<float>`)* 2D matrix of size (K x N).")
98 "*(type: Tensor`<float>`)* 2D matrix of size (M x N).")
101 "*(type: int; default: 1)* Exclusive axis that divides the first and " 102 "second dimension of matrix `A`.")
105 "*(type: int; default: 1)* Exclusive axis that divides the first and " 106 "second dimension of matrix `B`.")
109 "*(type: int; default: 0)* Pass 1 to transpose `A` before multiplication and " 110 "after the dimension adjustment using `axis_a`.")
113 "*(type: int; default: 0)* Pass 1 to transpose `B` before multiplication and " 114 "after the dimension adjustment using `axis_b`.");
117 using GradientMakerBase::GradientMakerBase;
118 vector<OperatorDef> GetGradientDefs()
override {
119 CAFFE_ENFORCE(def_.input_size() == 2 || def_.input_size() == 3);
126 if (ArgumentHelper::HasArgument(Def(),
"trans_a")) {
127 trans_a = GetArgument(Def(),
"trans_a").i();
129 if (ArgumentHelper::HasArgument(Def(),
"trans_b")) {
130 trans_b = GetArgument(Def(),
"trans_b").i();
132 if (ArgumentHelper::HasArgument(Def(),
"axis_a")) {
133 axis_a = GetArgument(Def(),
"axis_a").i();
135 if (ArgumentHelper::HasArgument(Def(),
"axis_b")) {
136 axis_b = GetArgument(Def(),
"axis_b").i();
143 return vector<OperatorDef>{
147 vector<string>{I(1), GO(0), I(0)},
148 vector<string>{GI(0)},
149 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
150 MakeArgument<int>(
"trans_b", 1),
151 MakeArgument<int>(
"axis_a", axis_b)}),
155 vector<string>{GO(0), I(0), I(1)},
156 vector<string>{GI(1)},
157 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
158 MakeArgument<int>(
"trans_b", 1),
159 MakeArgument<int>(
"axis_b", axis_a)})};
163 return vector<OperatorDef>{
167 vector<string>{I(1), GO(0), I(0)},
168 vector<string>{GI(0)},
169 vector<Argument>{MakeArgument<int>(
"trans_b", 1),
170 MakeArgument<int>(
"axis_a", axis_b)}),
174 vector<string>{I(0), GO(0), I(1)},
175 vector<string>{GI(1)},
176 vector<Argument>{MakeArgument<int>(
"axis_a", axis_a)})};
182 return vector<OperatorDef>{
186 vector<string>{GO(0), I(1), I(0)},
187 vector<string>{GI(0)},
188 vector<Argument>{MakeArgument<int>(
"axis_b", axis_b)}),
192 vector<string>{GO(0), I(0), I(1)},
193 vector<string>{GI(1)},
194 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
195 MakeArgument<int>(
"axis_b", axis_a)})};
199 return vector<OperatorDef>{
203 vector<string>{GO(0), I(1), I(0)},
204 vector<string>{GI(0)},
205 vector<Argument>{MakeArgument<int>(
"trans_b", 1),
206 MakeArgument<int>(
"axis_b", axis_b)}),
210 vector<string>{I(0), GO(0), I(1)},
211 vector<string>{GI(1)},
212 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
213 MakeArgument<int>(
"axis_a", axis_a)})};
218 bool CopyArguments()
const override {
int64_t size_from_dim_(int k, IntArrayRef dims)
Return product of all dimensions starting from k.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...