Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_matmul_op.h
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
2 #define CAFFE2_OPERATORS_MATMUL_OP_H_
3 
4 #include <sstream>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
// Batched matrix multiplication operator.
//
// Computes Y = A @ B where A and B carry arbitrary leading batch
// dimensions, with optional transposition of the two trailing (matrix)
// dimensions of each input via the "trans_a" / "trans_b" arguments, and
// optional NumPy-style batch broadcasting via the "broadcast" argument.
// Only float is dispatched (see RunOnDevice); the heavy lifting is
// delegated to math::GemmStridedBatched.
template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Forwards construction to Operator<Context> and caches the three
  // operator arguments (all default to 0 / false).
  template <class... Args>
  explicit BatchMatMulOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
        trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
        broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {}

  ~BatchMatMulOp() {}

  // Dispatches on the dtype of the first input; only float is supported.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  // Type-specialized implementation. Validates shapes, computes the
  // output shape (with broadcasting if enabled), and runs one strided
  // batched GEMM per "outer" broadcast batch.
  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);

    auto ndims_A = A.dim();
    auto dims_A = A.sizes().vec();
    auto ndims_B = B.dim();
    auto dims_B = B.sizes().vec();

    // Error text for shape mismatches that broadcast=1 would have allowed.
    auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
      std::stringstream ss;
      ss << "Inputs with dimensions A = ";
      ss << dim1;
      ss << " and B = ";
      ss << dim2;
      ss << " is not supported with broadcast=0. Did you forget to set the "
            "broadcast flag?";
      return ss.str();
    };

    // These should all be false if we're not broadcasting.
    // Without broadcast, A and B must have identical rank and be at
    // least 2-D (matrices, possibly batched).
    bool dimMismatch = ndims_A != ndims_B;
    bool dimsLessThan1D = ndims_A < 2;
    CAFFE_ENFORCE(
        broadcast_ || (!dimMismatch && !dimsLessThan1D),
        noBroadcastErrorMsg(ndims_A, ndims_B));

    auto* data_A = A.template data<T>();
    auto* data_B = B.template data<T>();

    // Error text for an inner-dimension (K) or batch-dimension mismatch.
    auto dimMismatchErrorString = [](size_t dimnum1,
                                     size_t dim1,
                                     size_t dimnum2,
                                     size_t dim2,
                                     bool trans_a,
                                     bool trans_b) {
      std::stringstream ss;
      ss << "Expected dimension ";
      ss << dimnum1;
      ss << " of tensor A with value ";
      ss << dim1;
      ss << " to match dimension ";
      ss << dimnum2;
      ss << " of tensor B with value ";
      ss << dim2;
      ss << ". trans_a = ";
      ss << trans_a;
      ss << " trans_b = ";
      ss << trans_b;
      return ss.str();
    };

    if (ndims_A == 1 && ndims_B == 1) {
      // vector-vector: dot product, output is a single-element tensor.
      CAFFE_ENFORCE_EQ(
          dims_A[0],
          dims_B[0],
          "Vector-vector product requires each of the vectors to "
          "be the same size.");
      auto* Y = Output(0, {1}, at::dtype<T>());
      math::Dot<T, Context>(
          dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
    } else {
      // Promote 1-D inputs to 2-D matrices so the GEMM path below can
      // treat everything uniformly: a 1-D A becomes a 1xK row vector, a
      // 1-D B becomes a Kx1 column vector. The inserted 1 is stripped
      // from the output shape again at the end.
      bool A_broadcasted = false, B_broadcasted = false;
      if (ndims_A == 1) {
        dims_A.insert(dims_A.begin(), 1);
        ndims_A = 2;
        A_broadcasted = true;
      }
      if (ndims_B == 1) {
        dims_B.push_back(1);
        ndims_B = 2;
        B_broadcasted = true;
      }
      // matrix-matrix with batches
      // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
      // In the event that A or B are one-dimensional, the trailing or leading
      // 1 is not added to the output tensor's size.

      // First step: partition the tensors into inner and outer blocks.
      // Ignoring the last two dimensions of A and B, ensure that one of the
      // tensors' dimensions is a suffix of the other. For example,
      // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
      // dimensions of size 2 and 3 will be broadcasted, so we partition into
      // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
      //
      // num_inner_dims counts the shared trailing dimensions *including*
      // the two matrix dimensions; the loop below starts at i = 2 so it
      // only compares the shared batch dimensions (from the end).
      size_t num_inner_dims = std::min(ndims_A, ndims_B);
      for (size_t i = 2; i < num_inner_dims; ++i) {
        auto first_r_itr = dims_A.rbegin();
        auto second_r_itr = dims_B.rbegin();
        CAFFE_ENFORCE_EQ(
            *(first_r_itr + i),
            *(second_r_itr + i),
            dimMismatchErrorString(
                ndims_A - i - 1,
                *(first_r_itr + i),
                ndims_B - i - 1,
                *(second_r_itr + i),
                trans_a_,
                trans_b_));
      }
      // Leading dims of the longer tensor; these are the broadcast
      // ("outer") batch dims the shorter tensor is reused across.
      size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

      // Standard M, N, and K parameters respecting GEMM API and transpose
      // flags. K_dim records which dimension of A holds K, purely for
      // error reporting.
      size_t M, N, K, K_dim;
      if (trans_a_) {
        M = dims_A[ndims_A - 1];
        K = dims_A[ndims_A - 2];
        K_dim = ndims_A - 2;
      } else {
        M = dims_A[ndims_A - 2];
        K = dims_A[ndims_A - 1];
        K_dim = ndims_A - 1;
      }
      // N comes from B; also enforce that B's K dimension matches A's K.
      if (trans_b_) {
        N = dims_B[ndims_B - 2];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 1],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 1,
                dims_B[ndims_B - 1],
                trans_a_,
                trans_b_));
      } else {
        N = dims_B[ndims_B - 1];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 2],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 2,
                dims_B[ndims_B - 2],
                trans_a_,
                trans_b_));
      }

      // Calculate output tensor shapes [B..., (M), (N)]
      // Batch dimensions will be broadcasted out to those of the longer tensor
      // A or B. Either M or N are optional if A or B, respectively are 1-D.
      // A placeholder 1 is pushed for a broadcasted M/N so that the stride
      // computation below can rely on the trailing layout; it is erased
      // before the output is allocated.
      std::vector<int64_t> new_dims;
      if (ndims_A >= ndims_B) {
        new_dims.assign(dims_A.begin(), dims_A.end() - 2);
      } else {
        new_dims.assign(dims_B.begin(), dims_B.end() - 2);
      }
      if (!A_broadcasted) {
        new_dims.push_back(M);
      } else {
        new_dims.push_back(1);
      }
      if (!B_broadcasted) {
        new_dims.push_back(N);
      } else {
        new_dims.push_back(1);
      }

      // Calculate strides. Continuing our example above,
      //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
      // We calculate this as follows:
      //   1) Treat the outer batch dimensions as flattened, i.e. view the B
      //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea-
      //      soning is analogous for the case where # dims A >= # dims B.
      //   2) Perform this operation:
      //        for i in range(6):
      //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
      // The shorter tensor's stride is 0 so the same data is reused for
      // every outer iteration; the longer tensor (and Y) advance by the
      // element count of one full inner block.
      size_t A_stride = 1; // How far to increment A pointer each itr
      size_t B_stride = 1; // How far to increment B pointer each itr
      size_t Y_stride = 1; // How far to increment Y pointer each itr
      // How many "inner batches" we have. That is, the product of sizes for
      // the slices excluding M, K, and N, for their respective matrices.
      size_t num_sub_batches = 1;
      if (ndims_A >= ndims_B) {
        auto first_r_itr = dims_A.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          A_stride *= *(first_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            // Skip the two matrix dims; only batch dims count here.
            num_sub_batches *= *(first_r_itr + i);
          }
        }
        B_stride = 0;
      } else {
        A_stride = 0;
        auto second_r_itr = dims_B.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          B_stride *= *(second_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            num_sub_batches *= *(second_r_itr + i);
          }
        }
      }

      // Total number of outer (broadcast) iterations.
      size_t num_outer_batches = 1;
      for (size_t i = 0; i < num_outer_dims; ++i) {
        num_outer_batches *= new_dims[i];
      }

      // Mutually exclusive since otherwise we would've taken the vector-vector
      // path above. Drop the placeholder 1 inserted for a 1-D input so the
      // output shape matches the documented convention.
      if (A_broadcasted) {
        new_dims.erase(new_dims.end() - 2);
      } else if (B_broadcasted) {
        new_dims.erase(new_dims.end() - 1);
      }

      // Allocate output tensor
      auto* Y = Output(0, new_dims, at::dtype<T>());
      auto* Y_data = Y->template mutable_data<T>();

      // Zero batch dimension indicates no elements
      if (num_sub_batches == 0 || num_outer_batches == 0) {
        return true;
      }

      // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
      // One strided batched GEMM per outer batch; within it, the
      // num_sub_batches matrices are laid out contiguously with batch
      // strides M*K, K*N and M*N respectively.
      for (size_t p = 0; p < num_outer_batches; ++p) {
        math::GemmStridedBatched<T, Context, Engine>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            num_sub_batches,
            M,
            N,
            K,
            1.0f,
            data_A + p * A_stride,
            M * K,
            data_B + p * B_stride,
            K * N,
            0.0f,
            Y_data + p * Y_stride,
            M * N,
            &context_);
      }
    }
    return true;
  }

 protected:
  bool trans_a_;    // transpose the last two dims of A before multiply
  bool trans_b_;    // transpose the last two dims of B before multiply
  bool broadcast_;  // allow NumPy-style batch broadcasting of A against B
};
279 
280 } // namespace caffe2
281 
282 #endif /* CAFFE2_OPERATORS_MATMUL_OP_H_ */
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58