// NOTE(review): this excerpt is a lossy scrape of matmul_op.h — original line
// numbers are fused into the text and several original lines are missing
// (e.g. the class declaration and constructor signature), so it does not
// compile as shown.
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ 2 #define CAFFE2_OPERATORS_MATMUL_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
// T: element type; used for the output dtype and as math::Gemm's data type.
typename T,
// Context: device context type, forwarded to math::Gemm.
class Context,
// Engine: math backend selector, forwarded to math::Gemm (defaults to DefaultEngine).
class Engine = DefaultEngine>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor: reads four integer operator arguments.
//   axis_a  (default 1) - axis at which input A is split into a 2-D view
//   axis_b  (default 1) - axis at which input B is split into a 2-D view
//   trans_a (default 0) - nonzero => use A transposed in the product
//   trans_b (default 0) - nonzero => use B transposed in the product
14 template <
class... Args>
// NOTE(review): the constructor signature and base-class initializer
// (original lines 15-16) are missing from this excerpt.
17 axis_a_(this->
template GetSingleArgument<int>(
"axis_a", 1)),
18 axis_b_(this->
template GetSingleArgument<int>(
"axis_b", 1)),
19 trans_a_(this->
template GetSingleArgument<int>(
"trans_a", 0)),
20 trans_b_(this->
template GetSingleArgument<int>(
"trans_b", 0)) {}
// Multiplies 2-D views of inputs A and B with math::Gemm. Each operand is
// transposed when trans_a_ / trans_b_ is set. Output Y is written with shape
// {a_dim0, b_dim1}. Dimension mismatches fail through CAFFE_ENFORCE.
// NOTE(review): the function's return statement (original lines 88-94) is not
// visible in this excerpt.
23 bool RunOnDevice()
override {
// NOTE(review): the bindings of A and B (presumably Input(0) / Input(1),
// original lines 24-26) are missing from this excerpt.
// Resolve the user-supplied split axes against each tensor's rank.
27 const auto canonical_axis_a =
A.canonical_axis_index(axis_a_);
28 const auto canonical_axis_b =
B.canonical_axis_index(axis_b_);
// View each input as a 2-D matrix split at its canonical axis:
// size_to_dim / size_from_dim are presumably the products of the leading /
// trailing dimensions — confirm against the Tensor API.
// NOTE(review): stored as int; if these helpers return 64-bit counts, very
// large tensors would be narrowed here — consider int64_t.
29 int A_dim0 =
A.size_to_dim(canonical_axis_a);
30 int A_dim1 =
A.size_from_dim(canonical_axis_a);
31 int B_dim0 =
B.size_to_dim(canonical_axis_b);
32 int B_dim1 =
B.size_from_dim(canonical_axis_b);
// Effective (post-transpose) matrix dimensions of the two operands.
34 int a_dim0, a_dim1, b_dim0, b_dim1;
// NOTE(review): the assignments to a_dim*/b_dim* (original lines 35-51) are
// missing here. The error message below suggests they are A_dim*/B_dim*,
// swapped when trans_a_/trans_b_ is set — confirm.
// Builds the "Dimension mismatch" message passed to the CAFFE_ENFORCE checks
// below. The formatting call and the printed dimension values (original lines
// 53, 56-58, 60-64) are not visible.
52 auto dimErrorString = [&]() {
54 "Dimension mismatch: ",
55 trans_a_ ?
"trans(A): " :
"A: ",
59 trans_b_ ?
", trans(B): " :
", B: ",
// Inner dimensions of the two operands must agree for the product.
65 CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString());
// Output is 2-D: a_dim0 rows x b_dim1 columns, built in the reusable member.
67 Y_shape_cache_[0] = a_dim0;
68 Y_shape_cache_[1] = b_dim1;
69 auto* Y = Output(0, Y_shape_cache_, at::dtype<T>());
// Sanity check: the output element count matches the requested shape.
70 CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->numel(), dimErrorString());
// Matrix multiply into Y. NOTE(review): the remaining Gemm arguments (sizes,
// scalars, input data, context; original lines 75-81, 83-84) are missing here.
72 math::Gemm<T, Context, Engine>(
73 trans_a_ ? CblasTrans : CblasNoTrans,
74 trans_b_ ? CblasTrans : CblasNoTrans,
82 Y->template mutable_data<T>(),
// When a third input is supplied, Y is resized to match Input(2). The reason
// and the rest of this branch are not visible in this excerpt.
85 if (InputSize() == 3) {
87 Y->ResizeLike(
Input(2));
// NOTE(review): declarations of axis_a_, axis_b_, trans_a_, trans_b_ and the
// class's closing lines are missing from this excerpt.
// Reusable 2-element output shape buffer ({rows, cols}), initialized to {0, 0}.
95 vector<int64_t> Y_shape_cache_{0, 0};
104 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...