Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_matmul_cpu.cc
#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/batch_matmul.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"

using caffe2::Tensor;
using std::vector;
namespace math = caffe2::math;

namespace caffe2 {
namespace {

struct Cache final : public c10::KernelCache {
  std::shared_ptr<at::Tensor> scratch;
};

template <class T, class Context>
void batch_matmul_op_cpu_impl(
    const at::Tensor& A_,
    const at::Tensor& B_,
    const at::Tensor& Y_,
    int64_t trans_a,
    int64_t trans_b,
    int64_t broadcast,
    Cache* cache) {
  Tensor A{C10Tensor(A_)};
  Tensor B{C10Tensor(B_)};
  Tensor Y{C10Tensor(Y_)};
  CPUContext context;
  using Engine = caffe2::DefaultEngine;

  auto ndims_A = A.dim();
  auto dims_A = A.sizes().vec();
  auto ndims_B = B.dim();
  auto dims_B = B.sizes().vec();

  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
    std::stringstream ss;
    ss << "Inputs with dimensions A = ";
    ss << dim1;
    ss << " and B = ";
    ss << dim2;
    ss << " is not supported with broadcast=0. Did you forget to set the "
          "broadcast flag?";
    return ss.str();
  };

  // These should all be false if we're not broadcasting.
  bool dimMismatch = ndims_A != ndims_B;
  bool dimsLessThan1D = ndims_A < 2;
  CAFFE_ENFORCE(
      broadcast || (!dimMismatch && !dimsLessThan1D),
      noBroadcastErrorMsg(ndims_A, ndims_B));

  auto* data_A = A.template data<T>();
  auto* data_B = B.template data<T>();

  auto dimMismatchErrorString = [](size_t dimnum1,
                                   size_t dim1,
                                   size_t dimnum2,
                                   size_t dim2,
                                   bool trans_a_,
                                   bool trans_b_) {
    std::stringstream ss;
    ss << "Expected dimension ";
    ss << dimnum1;
    ss << " of tensor A with value ";
    ss << dim1;
    ss << " to match dimension ";
    ss << dimnum2;
    ss << " of tensor B with value ";
    ss << dim2;
    ss << ". trans_a = ";
    ss << trans_a_;
    ss << " trans_b = ";
    ss << trans_b_;
    return ss.str();
  };

  if (ndims_A == 1 && ndims_B == 1) {
    // vector-vector
    CAFFE_ENFORCE_EQ(
        dims_A[0],
        dims_B[0],
        "Vector-vector product requires each of the vectors to "
        "be the same size.");
    Y.Resize(1);
    math::Dot<T, Context>(
        dims_A[0],
        data_A,
        data_B,
        Y.template mutable_data<T>(),
        static_cast<Context*>(&context));
  } else {
    bool A_broadcasted = false, B_broadcasted = false;
    if (ndims_A == 1) {
      dims_A.insert(dims_A.begin(), 1);
      ndims_A = 2;
      A_broadcasted = true;
    }
    if (ndims_B == 1) {
      dims_B.push_back(1);
      ndims_B = 2;
      B_broadcasted = true;
    }
    // matrix-matrix with batches
    // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
    // If A or B is one-dimensional, the corresponding singleton dimension
    // inserted above is not included in the output tensor's shape.

    // First step: partition the tensors into inner and outer blocks.
    // Ignoring the last two dimensions of A and B, ensure that one of the
    // tensors' dimensions is a suffix of the other. For example,
    // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
    // dimensions of size 2 and 3 will be broadcasted, so we partition into
    // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
    size_t num_inner_dims = std::min(ndims_A, ndims_B);
    for (size_t i = 2; i < num_inner_dims; ++i) {
      auto first_r_itr = dims_A.rbegin();
      auto second_r_itr = dims_B.rbegin();
      CAFFE_ENFORCE_EQ(
          *(first_r_itr + i),
          *(second_r_itr + i),
          dimMismatchErrorString(
              ndims_A - i - 1,
              *(first_r_itr + i),
              ndims_B - i - 1,
              *(second_r_itr + i),
              trans_a,
              trans_b));
    }
    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

    // Standard M, N, and K parameters respecting GEMM API and transpose
    // flags
    size_t M, N, K, K_dim;
    if (trans_a) {
      M = dims_A[ndims_A - 1];
      K = dims_A[ndims_A - 2];
      K_dim = ndims_A - 2;
    } else {
      M = dims_A[ndims_A - 2];
      K = dims_A[ndims_A - 1];
      K_dim = ndims_A - 1;
    }
    if (trans_b) {
      N = dims_B[ndims_B - 2];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 1],
          dimMismatchErrorString(
              K_dim,
              K,
              ndims_B - 1,
              dims_B[ndims_B - 1],
              trans_a,
              trans_b));
    } else {
      N = dims_B[ndims_B - 1];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 2],
          dimMismatchErrorString(
              K_dim,
              K,
              ndims_B - 2,
              dims_B[ndims_B - 2],
              trans_a,
              trans_b));
    }

    // Calculate the output tensor shape [B..., (M), (N)].
    // Batch dimensions are broadcasted out to those of the longer of the two
    // tensors, A or B. M or N is omitted if A or B, respectively, is 1-D.
    std::vector<int64_t> new_dims;
    if (ndims_A >= ndims_B) {
      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
    } else {
      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
    }
    if (!A_broadcasted) {
      new_dims.push_back(M);
    } else {
      new_dims.push_back(1);
    }
    if (!B_broadcasted) {
      new_dims.push_back(N);
    } else {
      new_dims.push_back(1);
    }

    // Calculate strides. Continuing our example above,
    //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
    // We calculate this as follows:
    //   1) Treat the outer batch dimensions as flattened, i.e. view the B
    //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same
    //      reasoning is analogous for the case where # dims A >= # dims B.
    //   2) Perform this operation:
    //        for i in range(6):
    //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
    size_t A_stride = 1; // How far to increment A pointer each itr
    size_t B_stride = 1; // How far to increment B pointer each itr
    size_t Y_stride = 1; // How far to increment Y pointer each itr
    // How many "inner batches" we have. That is, the product of sizes for
    // the slices excluding M, K, and N, for their respective matrices.
    size_t num_sub_batches = 1;
    if (ndims_A >= ndims_B) {
      auto first_r_itr = dims_A.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        A_stride *= *(first_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(first_r_itr + i);
        }
      }
      B_stride = 0;
    } else {
      A_stride = 0;
      auto second_r_itr = dims_B.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        B_stride *= *(second_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(second_r_itr + i);
        }
      }
    }

    size_t num_outer_batches = 1;
    for (size_t i = 0; i < num_outer_dims; ++i) {
      num_outer_batches *= new_dims[i];
    }

    // Mutually exclusive since otherwise we would've taken the vector-vector
    // path above
    if (A_broadcasted) {
      new_dims.erase(new_dims.end() - 2);
    } else if (B_broadcasted) {
      new_dims.erase(new_dims.end() - 1);
    }

    // Allocate output tensor
    Y.Resize(new_dims);
    auto* Y_data = Y.template mutable_data<T>();

    // Zero batch dimension indicates no elements
    if (num_sub_batches == 0 || num_outer_batches == 0) {
      return;
    }

    // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
    for (size_t p = 0; p < num_outer_batches; ++p) {
      math::GemmStridedBatched<T, Context, Engine>(
          trans_a ? CblasTrans : CblasNoTrans,
          trans_b ? CblasTrans : CblasNoTrans,
          num_sub_batches,
          M,
          N,
          K,
          1.0f,
          data_A + p * A_stride,
          M * K,
          data_B + p * B_stride,
          K * N,
          0.0f,
          Y_data + p * Y_stride,
          M * N,
          static_cast<Context*>(&context));
    }
  }
}
} // namespace
} // namespace caffe2

namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
    .withCache<caffe2::Cache>()
    .kernel<
        decltype(caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>),
        &caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
    .dispatchKey(CPUTensorId());
} // namespace c10
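For readers tracing the shape logic in the broadcasting path above, the following standalone sketch reproduces the output-shape rule for inputs of rank >= 2 with trans_a = trans_b = 0. It is illustrative only and not part of Caffe2; the helper name batch_matmul_output_shape is made up for this example.

// Minimal sketch (assumption: rank >= 2 inputs, no transposes) of the
// output-shape rule implemented by the kernel above.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper, for illustration only.
std::vector<int64_t> batch_matmul_output_shape(
    const std::vector<int64_t>& dims_A,
    const std::vector<int64_t>& dims_B) {
  assert(dims_A.size() >= 2 && dims_B.size() >= 2);
  const int64_t M = dims_A[dims_A.size() - 2];
  const int64_t K = dims_A[dims_A.size() - 1];
  assert(K == dims_B[dims_B.size() - 2]); // inner dimensions must agree
  const int64_t N = dims_B[dims_B.size() - 1];
  // Batch dimensions come from the higher-rank input; the kernel above
  // additionally enforces that the lower-rank input's batch dimensions are
  // a suffix of the higher-rank input's.
  std::vector<int64_t> out = dims_A.size() >= dims_B.size()
      ? std::vector<int64_t>(dims_A.begin(), dims_A.end() - 2)
      : std::vector<int64_t>(dims_B.begin(), dims_B.end() - 2);
  out.push_back(M);
  out.push_back(N);
  return out;
}

int main() {
  // [4, M, K] x [2, 3, 4, K, N] -> [2, 3, 4, M, N], matching the example in
  // the stride comments above (here M = 5, K = 6, N = 7).
  for (int64_t d : batch_matmul_output_shape({4, 5, 6}, {2, 3, 4, 6, 7})) {
    std::cout << d << ' '; // prints: 2 3 4 5 7
  }
  std::cout << '\n';
}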
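Similarly, each math::GemmStridedBatched call above computes, for one outer batch p, a batch of num_sub_batches independent matrix products with alpha = 1 and beta = 0. As a rough reading aid only (not Caffe2's implementation, which dispatches to an optimized BLAS), the no-transpose, row-major case is equivalent to plain loops like these:

// Reference loops (illustrative only) for one strided-batched GEMM call with
// alpha = 1.0f, beta = 0.0f and no transposes, assuming row-major storage.
// strideA/strideB/strideY are the offsets between consecutive matrices in the
// batch; the kernel above passes M*K, K*N and M*N respectively.
void gemm_strided_batched_reference(
    int batch, int M, int N, int K,
    const float* A, int strideA, // A: batch x M x K
    const float* B, int strideB, // B: batch x K x N
    float* Y, int strideY) {     // Y: batch x M x N
  for (int b = 0; b < batch; ++b) {
    const float* a = A + b * strideA;
    const float* bmat = B + b * strideB;
    float* y = Y + b * strideY;
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += a[m * K + k] * bmat[k * N + n];
        }
        y[m * N + n] = acc; // beta = 0: output is overwritten, not accumulated
      }
    }
  }
}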