// Caffe2 operator header: include guard plus core context/operator/math
// headers. NOTE(review): this chunk is an extraction artifact — the embedded
// numbers (1, 2, 6, 7, 8, 12, ...) are the ORIGINAL file's line numbers, and
// gaps between them mark source lines elided from this view.
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ 2 #define CAFFE2_OPERATORS_MATMUL_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
// Template parameters of the operator class. The class declaration itself
// (original line ~13-14) is not visible here — presumably
// `class BatchMatMulOp final : public Operator<Context>`; confirm upstream.
class Context,
class Engine = DefaultEngine>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Perfect-forwarding constructor fragment: the initializer list reads the
// three operator arguments controlling the matmul — "trans_a", "trans_b",
// "broadcast" — each defaulting to 0 (false). NOTE(review): the constructor
// signature and base-class forwarding (original lines 17-18) are elided
// from this view.
16 template <
class... Args>
// `this->template` is required because GetSingleArgument is a dependent
// name inside a class template.
19 trans_a_(this->
template GetSingleArgument<int>(
"trans_a", 0)),
20 trans_b_(this->
template GetSingleArgument<int>(
"trans_b", 0)),
21 broadcast_(this->
template GetSingleArgument<int>(
"broadcast", 0)) {}
// Entry point: presumably dispatches to DoRunWithType<T> based on the dtype
// of Input(0) — the dispatch body (original lines ~26-28) is elided from
// this view; confirm against upstream batch_matmul_op.h.
25 bool RunOnDevice()
override {
// Batched matrix multiply: Y = op(A) @ op(B), where op() optionally
// transposes the last two dims (trans_a_/trans_b_), with NumPy-style
// broadcasting of the leading (batch) dims when broadcast_ is set.
// NOTE(review): every jump in the embedded original line numbers (e.g.
// 45->51, 73->82, 95->102, 141->145) marks elided source lines —
// CAFFE_ENFORCE calls, if/else guards, and closing braces. The code below
// is left byte-identical; comments about elided guards are assumptions to
// confirm against upstream caffe2/operators/batch_matmul_op.h.
30 bool DoRunWithType() {
// Shapes of the two inputs. A and B are presumably Input(0)/Input(1);
// their declarations (lines 31-33) are elided.
34 auto ndims_A =
A.dim();
35 auto dims_A =
A.sizes().vec();
36 auto ndims_B =
B.dim();
37 auto dims_B =
B.sizes().vec();
// Builds the error text used when broadcast=0 but the shapes would need
// broadcasting (stringstream declaration elided).
39 auto noBroadcastErrorMsg = [](
size_t dim1,
size_t dim2) {
41 ss <<
"Inputs with dimensions A = ";
45 ss <<
" is not supported with broadcast=0. Did you forget to set the " 51 bool dimMismatch = ndims_A != ndims_B;
// NOTE(review): misleading name — true when A has FEWER than 2 dims
// (i.e. A is 0-D or 1-D), not "less than 1-D".
52 bool dimsLessThan1D = ndims_A < 2;
// Without broadcast, ranks must match and A must be at least 2-D; the
// enclosing CAFFE_ENFORCE call (original line 53) is elided.
54 broadcast_ || (!dimMismatch && !dimsLessThan1D),
55 noBroadcastErrorMsg(ndims_A, ndims_B));
57 auto* data_A =
A.template data<T>();
58 auto* data_B =
B.template data<T>();
// Builds "Expected dimension i of tensor A ... to match dimension j of
// tensor B ..." for the per-dimension shape checks below.
60 auto dimMismatchErrorString = [](
size_t dimnum1,
67 ss <<
"Expected dimension ";
69 ss <<
" of tensor A with value ";
71 ss <<
" to match dimension ";
73 ss <<
" of tensor B with value ";
// Special case: 1-D x 1-D is a dot product producing a single element.
82 if (ndims_A == 1 && ndims_B == 1) {
87 "Vector-vector product requires each of the vectors to " 89 auto* Y = Output(0, {1}, at::dtype<T>());
90 math::Dot<T, Context>(
91 dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
// 1-D inputs are promoted to 2-D by inserting a size-1 dim, and flagged
// so the padded dim can be squeezed out of the output shape later
// (the symmetric A/B promotion branches are partially elided).
93 bool A_broadcasted =
false, B_broadcasted =
false;
95 dims_A.insert(dims_A.begin(), 1);
102 B_broadcasted =
true;
// Batch dims common to both inputs (everything except the two trailing
// matrix dims) must agree elementwise; the CAFFE_ENFORCE wrapping the
// comparison is elided.
115 size_t num_inner_dims = std::min(ndims_A, ndims_B);
116 for (
size_t i = 2; i < num_inner_dims; ++i) {
117 auto first_r_itr = dims_A.rbegin();
118 auto second_r_itr = dims_B.rbegin();
122 dimMismatchErrorString(
// Leading dims present only on the higher-rank input.
130 size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
// GEMM shapes: Y(M x N) = A(M x K) * B(K x N). Which input dims supply
// M/K/N depends on the transpose flags; the if/else guards are elided.
// K_dim presumably records which dim of A supplied K, for error text.
134 size_t M, N, K, K_dim;
// presumably the trans_a_ == true branch — confirm:
136 M = dims_A[ndims_A - 1];
137 K = dims_A[ndims_A - 2];
// presumably the trans_a_ == false branch — confirm:
140 M = dims_A[ndims_A - 2];
141 K = dims_A[ndims_A - 1];
// presumably the trans_b_ == true branch, with a K-consistency check:
145 N = dims_B[ndims_B - 2];
149 dimMismatchErrorString(
// presumably the trans_b_ == false branch:
157 N = dims_B[ndims_B - 1];
161 dimMismatchErrorString(
// Output shape: batch dims of the higher-rank input followed by M and N.
// A placeholder 1 is pushed where the dim came from a broadcast-promoted
// 1-D input; those placeholders are erased again further down.
173 std::vector<int64_t> new_dims;
174 if (ndims_A >= ndims_B) {
175 new_dims.assign(dims_A.begin(), dims_A.end() - 2);
177 new_dims.assign(dims_B.begin(), dims_B.end() - 2);
179 if (!A_broadcasted) {
180 new_dims.push_back(M);
182 new_dims.push_back(1);
184 if (!B_broadcasted) {
185 new_dims.push_back(N);
187 new_dims.push_back(1);
// Strides between consecutive outer batches and the number of GEMMs per
// strided-batched call. NOTE(review): the A_stride/B_stride/Y_stride
// declarations/initializers (original lines ~190-203) are elided.
204 size_t num_sub_batches = 1;
205 if (ndims_A >= ndims_B) {
206 auto first_r_itr = dims_A.rbegin();
207 auto output_r_itr = new_dims.rbegin();
208 for (
size_t i = 0; i < num_inner_dims; ++i) {
209 A_stride *= *(first_r_itr + i);
210 Y_stride *= *(output_r_itr + i);
212 num_sub_batches *= *(first_r_itr + i);
// presumably the else branch (B has higher rank) — confirm:
218 auto second_r_itr = dims_B.rbegin();
219 auto output_r_itr = new_dims.rbegin();
220 for (
size_t i = 0; i < num_inner_dims; ++i) {
221 B_stride *= *(second_r_itr + i);
222 Y_stride *= *(output_r_itr + i);
224 num_sub_batches *= *(second_r_itr + i);
// Total outer batches = product of the leading dims unique to the
// higher-rank input.
229 size_t num_outer_batches = 1;
230 for (
size_t i = 0; i < num_outer_dims; ++i) {
231 num_outer_batches *= new_dims[i];
// Squeeze the placeholder dims added for broadcast-promoted 1-D inputs;
// the `if (A_broadcasted ...)` guard for the first erase is elided.
237 new_dims.erase(new_dims.end() - 2);
238 }
else if (B_broadcasted) {
239 new_dims.erase(new_dims.end() - 1);
243 auto* Y = Output(0, new_dims, at::dtype<T>());
244 auto* Y_data = Y->template mutable_data<T>();
// Empty-batch fast path: nothing to compute (the early return body is
// elided).
247 if (num_sub_batches == 0 || num_outer_batches == 0) {
// One strided-batched GEMM per outer batch; the sub-batches within each
// outer batch are covered by the stride arguments. The M/N/K, alpha/beta
// and sub-batch-stride arguments (lines 256-260, 262, 264-265, 267+)
// are elided from this view.
252 for (
size_t p = 0; p < num_outer_batches; ++p) {
253 math::GemmStridedBatched<T, Context, Engine>(
254 trans_a_ ? CblasTrans : CblasNoTrans,
255 trans_b_ ? CblasTrans : CblasNoTrans,
261 data_A + p * A_stride,
263 data_B + p * B_stride,
266 Y_data + p * Y_stride,
const Tensor& Input(int idx, DeviceType type = Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.