Caffe2 - C++ API
A deep learning, cross-platform ML framework
matmul_op.cc
1 
17 #include "caffe2/operators/matmul_op.h"
18 
19 namespace caffe2 {
20 
// Registers the CPU implementation of the MatMul operator, instantiated for
// float tensors (MatMulOp<float, CPUContext>).
21 REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
22 
23 OPERATOR_SCHEMA(MatMul)
24  .NumInputs(2, 3)
25  .NumOutputs(1)
26  .TensorInferenceFunction([](const OperatorDef& def,
27  const vector<TensorShape>& in) {
28  vector<TensorShape> out(1);
29  out[0].set_data_type(in[0].data_type());
30  ArgumentHelper arg_helper(def);
31  int axis_a = arg_helper.GetSingleArgument<int>("axis_a", 1);
32  int axis_b = arg_helper.GetSingleArgument<int>("axis_b", 1);
33  int trans_a = arg_helper.GetSingleArgument<bool>("trans_a", false);
34  int trans_b = arg_helper.GetSingleArgument<bool>("trans_b", false);
35  int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
36  int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size());
37 
38  int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
39  int N = size_from_dim_(canonical_axis_b, GetDimsVector(in[1]));
40  if (trans_a) {
41  M = size_from_dim_(canonical_axis_a, GetDimsVector(in[0]));
42  }
43  if (trans_b) {
44  N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
45  }
46 
47  out[0].add_dims(M);
48  out[0].add_dims(N);
49 
50  return out;
51  })
52  .SetDoc(R"DOC(
53 Matrix multiplication Y = A * B, where A has size (M x K), B has size (K x N),
54 and Y will have a size (M x N).
55 )DOC")
56  .Input(0, "A", "2D matrix of size (M x K)")
57  .Input(1, "B", "2D matrix of size (K x N)")
58  .Output(0, "Y", "2D matrix of size (M x N)")
59  .Arg(
60  "axis_a",
61  "Exclusive axis that divides the first and second dimension \
62 of matrix A, default to 1")
63  .Arg(
64  "axis_b",
65  "Exclusive axis that divides the first and second dimension \
66 of matrix B, default to 1")
67  .Arg(
68  "trans_a",
69  "Pass 1 to transpose A before multiplication and after the \
70 dimension adjustment using axis_a")
71  .Arg(
72  "trans_b",
73  "Pass 1 to transpose B before multiplication and after the \
74 dimension adjustment using axis_b");
75 
76 class GetMatMulGradient : public GradientMakerBase {
77  using GradientMakerBase::GradientMakerBase;
78  vector<OperatorDef> GetGradientDefs() override {
79  CAFFE_ENFORCE_EQ(def_.input_size(), 2);
80 
81  bool axis_a = 1;
82  bool axis_b = 1;
83  bool trans_a = 0;
84  bool trans_b = 0;
85 
86  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
87  trans_a = GetArgument(Def(), "trans_a").i();
88  }
89  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
90  trans_b = GetArgument(Def(), "trans_b").i();
91  }
92  if (ArgumentHelper::HasArgument(Def(), "axis_a")) {
93  axis_a = GetArgument(Def(), "axis_a").i();
94  }
95  if (ArgumentHelper::HasArgument(Def(), "axis_b")) {
96  axis_b = GetArgument(Def(), "axis_b").i();
97  }
98 
99  if (trans_a) {
100  if (trans_b) {
101  // A'B':
102  // dA = B'G', dB = G'A'
103  return vector<OperatorDef>{
104  CreateOperatorDef(
105  "MatMul",
106  "",
107  vector<string>{I(1), GO(0), I(0)},
108  vector<string>{GI(0)},
109  vector<Argument>{MakeArgument<int>("trans_a", 1),
110  MakeArgument<int>("trans_b", 1),
111  MakeArgument<int>("axis_a", axis_b)}),
112  CreateOperatorDef(
113  "MatMul",
114  "",
115  vector<string>{GO(0), I(0), I(1)},
116  vector<string>{GI(1)},
117  vector<Argument>{MakeArgument<int>("trans_a", 1),
118  MakeArgument<int>("trans_b", 1),
119  MakeArgument<int>("axis_b", axis_a)})};
120  } else {
121  // A'B:
122  // dA = BG', dB = AG
123  return vector<OperatorDef>{
124  CreateOperatorDef(
125  "MatMul",
126  "",
127  vector<string>{I(1), GO(0), I(0)},
128  vector<string>{GI(0)},
129  vector<Argument>{MakeArgument<int>("trans_b", 1),
130  MakeArgument<int>("axis_a", axis_b)}),
131  CreateOperatorDef(
132  "MatMul",
133  "",
134  vector<string>{I(0), GO(0), I(1)},
135  vector<string>{GI(1)},
136  vector<Argument>{MakeArgument<int>("axis_a", axis_a)})};
137  }
138  } else {
139  if (trans_b) {
140  // AB':
141  // dA = GB, dB = G'A
142  return vector<OperatorDef>{
143  CreateOperatorDef(
144  "MatMul",
145  "",
146  vector<string>{GO(0), I(1), I(0)},
147  vector<string>{GI(0)},
148  vector<Argument>{MakeArgument<int>("axis_b", axis_b)}),
149  CreateOperatorDef(
150  "MatMul",
151  "",
152  vector<string>{GO(0), I(0), I(1)},
153  vector<string>{GI(1)},
154  vector<Argument>{MakeArgument<int>("trans_a", 1),
155  MakeArgument<int>("axis_b", axis_a)})};
156  } else {
157  // AB:
158  // dA = GB', dB = A'G
159  return vector<OperatorDef>{
160  CreateOperatorDef(
161  "MatMul",
162  "",
163  vector<string>{GO(0), I(1), I(0)},
164  vector<string>{GI(0)},
165  vector<Argument>{MakeArgument<int>("trans_b", 1),
166  MakeArgument<int>("axis_b", axis_b)}),
167  CreateOperatorDef(
168  "MatMul",
169  "",
170  vector<string>{I(0), GO(0), I(1)},
171  vector<string>{GI(1)},
172  vector<Argument>{MakeArgument<int>("trans_a", 1),
173  MakeArgument<int>("axis_a", axis_a)})};
174  }
175  }
176  }
177 
178  bool CopyArguments() const override {
179  return false;
180  }
181 };
182 
// Registers GetMatMulGradient as the gradient maker for the MatMul operator.
183 REGISTER_GRADIENT(MatMul, GetMatMulGradient);
184 
185 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
TIndex size_from_dim_(int k, vector< TIndex > dims)
Returns the product of all dimensions starting from index k.
Definition: tensor.h:56