#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/batch_matmul.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <sstream>
#include <vector>

namespace caffe2 {
namespace {

// A kernel can keep a cache around to get better performance when it is
// called multiple times; this kernel keeps a scratch tensor.
struct Cache final : public c10::KernelCache {
  std::shared_ptr<at::Tensor> scratch;
};
template <class T, class Context>
void batch_matmul_op_cpu_impl(
    const at::Tensor& A_,
    const at::Tensor& B_,
    const at::Tensor& Y_,
    int64_t trans_a,
    int64_t trans_b,
    int64_t broadcast,
    Cache* cache) {
  Tensor A(A_);
  Tensor B(B_);
  Tensor Y(Y_);
  CPUContext context;
  using Engine = caffe2::DefaultEngine;
  auto ndims_A = A.dim();
  auto dims_A = A.sizes().vec();
  auto ndims_B = B.dim();
  auto dims_B = B.sizes().vec();
  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
    std::stringstream ss;
    ss << "Inputs with dimensions A = " << dim1 << " and B = " << dim2
       << " are not supported with broadcast=0. Did you forget to set the "
          "broadcast flag?";
    return ss.str();
  };

  // These should both be false if we're not broadcasting.
  bool dimMismatch = ndims_A != ndims_B;
  bool dimsLessThan1D = ndims_A < 2;
  CAFFE_ENFORCE(
      broadcast || (!dimMismatch && !dimsLessThan1D),
      noBroadcastErrorMsg(ndims_A, ndims_B));
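  // For illustration (hypothetical shapes): with broadcast = 0, a 3-D A
  // paired with a 2-D B fails this check, as does any 1-D input; passing
  // broadcast = 1 enables the broadcasting rules handled below.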
  auto* data_A = A.template data<T>();
  auto* data_B = B.template data<T>();
  auto dimMismatchErrorString = [](size_t dimnum1, size_t dim1,
                                   size_t dimnum2, size_t dim2,
                                   bool trans_a, bool trans_b) {
    std::stringstream ss;
    ss << "Expected dimension " << dimnum1 << " of tensor A with value "
       << dim1 << " to match dimension " << dimnum2
       << " of tensor B with value " << dim2 << ". trans_a = " << trans_a
       << " trans_b = " << trans_b;
    return ss.str();
  };
  if (ndims_A == 1 && ndims_B == 1) {
    // vector-vector
    CAFFE_ENFORCE_EQ(
        dims_A[0], dims_B[0],
        "Vector-vector product requires each of the vectors to "
        "be the same size.");
    Y.Resize(1);
    math::Dot<T, Context>(
        dims_A[0], data_A, data_B, Y.template mutable_data<T>(),
        static_cast<Context*>(&context));
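    // For example (hypothetical shapes): A = [3] and B = [3] yield a
    // single-element Y containing their dot product.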
  } else {
    // If either input is 1-D, pad it to 2-D and remember that the extra
    // size-1 dimension must be dropped from the output shape later.
    bool A_broadcasted = false, B_broadcasted = false;
    if (ndims_A == 1) {
      dims_A.insert(dims_A.begin(), 1);
      ndims_A = 2;
      A_broadcasted = true;
    }
    if (ndims_B == 1) {
      dims_B.push_back(1);
      ndims_B = 2;
      B_broadcasted = true;
    }
    // First step: partition the tensors into inner and outer blocks.
    // Ignoring the last two dimensions of A and B, one tensor's batch
    // dimensions must be a suffix of the other's, e.g. [4, x, x] is a
    // suffix of [2, 3, 4, x, x]; the leading 2 and 3 are broadcast.
    size_t num_inner_dims = std::min(ndims_A, ndims_B);
    for (size_t i = 2; i < num_inner_dims; ++i) {
      auto first_r_itr = dims_A.rbegin();
      auto second_r_itr = dims_B.rbegin();
      CAFFE_ENFORCE_EQ(
          *(first_r_itr + i), *(second_r_itr + i),
          dimMismatchErrorString(
              ndims_A - i - 1, *(first_r_itr + i),
              ndims_B - i - 1, *(second_r_itr + i),
              trans_a, trans_b));
    }
    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
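    // An illustrative trace (shapes are hypothetical, extending the
    // example above): for A = [4, M, K] and B = [2, 3, 4, K, N],
    // num_inner_dims = min(3, 5) = 3, the i = 2 iteration checks A's batch
    // dimension 4 against B's, and num_outer_dims = 5 - 3 = 2, covering
    // the broadcast dimensions of size 2 and 3.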
    // Standard M, N, and K parameters, respecting the GEMM API and the
    // transpose flags.
    size_t M, N, K, K_dim;
    if (trans_a) {
      M = dims_A[ndims_A - 1];
      K = dims_A[ndims_A - 2];
      K_dim = ndims_A - 2;
    } else {
      M = dims_A[ndims_A - 2];
      K = dims_A[ndims_A - 1];
      K_dim = ndims_A - 1;
    }
    if (trans_b) {
      N = dims_B[ndims_B - 2];
      CAFFE_ENFORCE_EQ(
          K, dims_B[ndims_B - 1],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 1, dims_B[ndims_B - 1], trans_a, trans_b));
    } else {
      N = dims_B[ndims_B - 1];
      CAFFE_ENFORCE_EQ(
          K, dims_B[ndims_B - 2],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 2, dims_B[ndims_B - 2], trans_a, trans_b));
    }
    // Compute the output shape [B..., (M), (N)]. Batch dimensions are
    // broadcast out to those of the higher-rank input; M or N is dropped
    // later if A or B, respectively, was originally 1-D.
    std::vector<int64_t> new_dims;
    if (ndims_A >= ndims_B) {
      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
    } else {
      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
    }
    if (!A_broadcasted) {
      new_dims.push_back(M);
    } else {
      new_dims.push_back(1);
    }
    if (!B_broadcasted) {
      new_dims.push_back(N);
    } else {
      new_dims.push_back(1);
    }
    // Calculate strides. Continuing the example above,
    //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]:
    // 1) treat the outer batch dimensions as flattened, i.e. view B as
    //    [6, 4, K, N] and Y as [6, 4, M, N];
    // 2) for each outer batch p, compute
    //    Y[p, :, :, :] = BatchMatMul(A, B[p, :, :, :]).
    // The lower-rank tensor gets stride 0, so it is reused for every
    // outer batch.
    size_t A_stride = 1; // number of elements in one outer batch of A
    size_t B_stride = 1; // number of elements in one outer batch of B
    size_t Y_stride = 1; // number of elements in one outer batch of Y
    size_t num_sub_batches = 1;
    if (ndims_A >= ndims_B) {
      auto first_r_itr = dims_A.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        A_stride *= *(first_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(first_r_itr + i);
        }
      }
      B_stride = 0;
    } else {
      A_stride = 0;
      auto second_r_itr = dims_B.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        B_stride *= *(second_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(second_r_itr + i);
        }
      }
    }
    size_t num_outer_batches = 1;
    for (size_t i = 0; i < num_outer_dims; ++i) {
      num_outer_batches *= new_dims[i];
    }
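    // Continuing the illustrative trace: for A = [4, M, K] and
    // B = [2, 3, 4, K, N], we get A_stride = 0 (A is reused for every
    // outer batch), B_stride = 4 * K * N, Y_stride = 4 * M * N,
    // num_sub_batches = 4, and num_outer_batches = 2 * 3 = 6, so the GEMM
    // loop further below runs 6 strided-batched multiplications of 4
    // sub-batches each.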
    // If A or B was originally 1-D, drop the corresponding size-1 output
    // dimension. The cases are mutually exclusive: if both were 1-D we
    // would have taken the vector-vector path above.
    if (A_broadcasted) {
      new_dims.erase(new_dims.end() - 2);
    } else if (B_broadcasted) {
      new_dims.erase(new_dims.end() - 1);
    }

    // Allocate the output tensor.
    Y.Resize(new_dims);
    auto* Y_data = Y.template mutable_data<T>();
    // A zero batch dimension means there is nothing to compute.
    if (num_sub_batches == 0 || num_outer_batches == 0) {
      return;
    }

    for (size_t p = 0; p < num_outer_batches; ++p) {
      math::GemmStridedBatched<T, Context, Engine>(
          trans_a ? CblasTrans : CblasNoTrans,
          trans_b ? CblasTrans : CblasNoTrans,
          num_sub_batches, M, N, K,
          1.0f,
          data_A + p * A_stride, M * K,
          data_B + p * B_stride, K * N,
          0.0f,
          Y_data + p * Y_stride, M * N,
          static_cast<Context*>(&context));
    }
  }
}

} // namespace
} // namespace caffe2
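// A minimal standalone sketch of the broadcast=1 output-shape rule
// implemented above, assuming trans_a = trans_b = 0 and inputs that are
// already at least 2-D. The function name and the example shapes are
// hypothetical, for illustration only:
//   batch_matmul_output_shape_sketch({4, 7, 5}, {2, 3, 4, 5, 6})
// returns {2, 3, 4, 7, 6}.
namespace {
inline std::vector<int64_t> batch_matmul_output_shape_sketch(
    const std::vector<int64_t>& dims_A,
    const std::vector<int64_t>& dims_B) {
  // Batch dimensions come from the higher-rank input; the kernel above
  // enforces that the lower-rank input's batch dims are a matching suffix.
  std::vector<int64_t> out = dims_A.size() >= dims_B.size()
      ? std::vector<int64_t>(dims_A.begin(), dims_A.end() - 2)
      : std::vector<int64_t>(dims_B.begin(), dims_B.end() - 2);
  out.push_back(dims_A[dims_A.size() - 2]); // M: rows of each A matrix
  out.push_back(dims_B[dims_B.size() - 1]); // N: columns of each B matrix
  return out;
}
} // namespace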
namespace c10 {
// Register the float CPU kernel for the BatchMatmul operator, together
// with its cache type and dispatch key.
C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
    .withCache<caffe2::Cache>()
    .kernel<
        decltype(caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>),
        &caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
    .dispatchKey(CPUTensorId());
} // namespace c10
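// For reference, the general registration pattern (per the c10 kernel
// registration docs): register your kernel for an operator in one (!)
// cpp file, as
//   C10_REGISTER_KERNEL(OperatorHandle)
//       .kernel<decltype(&kernel_func), &kernel_func>()
//       .dispatchKey(dispatch_key);
// where OperatorHandle, kernel_func, and dispatch_key stand for your
// operator schema, kernel function, and dispatch key.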