Caffe2 - C++ API
A deep learning, cross platform ML framework
matmul_op.h
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
2 #define CAFFE2_OPERATORS_MATMUL_OP_H_
3 
#include <cstdint>
#include <limits>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context, class Engine = DefaultEngine>
11 class MatMulOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  template <class... Args>
15  explicit MatMulOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  axis_a_(this->template GetSingleArgument<int>("axis_a", 1)),
18  axis_b_(this->template GetSingleArgument<int>("axis_b", 1)),
19  trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
20  trans_b_(this->template GetSingleArgument<int>("trans_b", 0)) {}
21  ~MatMulOp() {}
22 
23  bool RunOnDevice() override {
24  const auto& A = Input(0);
25  const auto& B = Input(1);
26 
27  const auto canonical_axis_a = A.canonical_axis_index(axis_a_);
28  const auto canonical_axis_b = B.canonical_axis_index(axis_b_);
29  int A_dim0 = A.size_to_dim(canonical_axis_a);
30  int A_dim1 = A.size_from_dim(canonical_axis_a);
31  int B_dim0 = B.size_to_dim(canonical_axis_b);
32  int B_dim1 = B.size_from_dim(canonical_axis_b);
33 
34  int a_dim0, a_dim1, b_dim0, b_dim1;
35 
36  if (trans_a_) {
37  a_dim0 = A_dim1;
38  a_dim1 = A_dim0;
39  } else {
40  a_dim0 = A_dim0;
41  a_dim1 = A_dim1;
42  }
43 
44  if (trans_b_) {
45  b_dim0 = B_dim1;
46  b_dim1 = B_dim0;
47  } else {
48  b_dim0 = B_dim0;
49  b_dim1 = B_dim1;
50  }
51 
52  auto dimErrorString = [&]() {
53  return c10::str(
54  "Dimension mismatch: ",
55  trans_a_ ? "trans(A): " : "A: ",
56  a_dim0,
57  " ",
58  a_dim1,
59  trans_b_ ? ", trans(B): " : ", B: ",
60  b_dim0,
61  " ",
62  b_dim1);
63  };
64  // Error checking
65  CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString());
66 
67  Y_shape_cache_[0] = a_dim0;
68  Y_shape_cache_[1] = b_dim1;
69  auto* Y = Output(0, Y_shape_cache_, at::dtype<T>());
70  CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->numel(), dimErrorString());
71  // Y = A * B
72  math::Gemm<T, Context, Engine>(
73  trans_a_ ? CblasTrans : CblasNoTrans,
74  trans_b_ ? CblasTrans : CblasNoTrans,
75  a_dim0,
76  b_dim1,
77  a_dim1,
78  1,
79  A.template data<T>(),
80  B.template data<T>(),
81  0,
82  Y->template mutable_data<T>(),
83  &context_);
84 
85  if (InputSize() == 3) {
86  // In gradient op, resize to input
87  Y->ResizeLike(Input(2));
88  }
89  return true;
90  }
91 
92  protected:
93  // A local vector to cache the output shape so we don't need to recreate
94  // a vector object every time we run Run().
95  vector<int64_t> Y_shape_cache_{0, 0};
96  int axis_a_{1};
97  int axis_b_{1};
98  bool trans_a_;
99  bool trans_b_;
100 };
101 
102 } // namespace caffe2
103 
104 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58