Caffe2 - C++ API
A deep learning, cross platform ML framework
matmul_op.h
1 
17 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
18 #define CAFFE2_OPERATORS_MATMUL_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = DefaultEngine>
27 class MatMulOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  MatMulOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  axis_a_(OperatorBase::GetSingleArgument<int>("axis_a", 1)),
33  axis_b_(OperatorBase::GetSingleArgument<int>("axis_b", 1)),
34  trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
35  trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)) {}
36  ~MatMulOp() {}
37 
38  bool RunOnDevice() override {
39  const auto& A = Input(0);
40  const auto& B = Input(1);
41  auto* Y = Output(0);
42 
43  const auto canonical_axis_a = A.canonical_axis_index(axis_a_);
44  const auto canonical_axis_b = B.canonical_axis_index(axis_b_);
45  int A_dim0 = A.size_to_dim(canonical_axis_a);
46  int A_dim1 = A.size_from_dim(canonical_axis_a);
47  int B_dim0 = B.size_to_dim(canonical_axis_b);
48  int B_dim1 = B.size_from_dim(canonical_axis_b);
49 
50  int a_dim0, a_dim1, b_dim0, b_dim1;
51 
52  if (trans_a_) {
53  a_dim0 = A_dim1;
54  a_dim1 = A_dim0;
55  } else {
56  a_dim0 = A_dim0;
57  a_dim1 = A_dim1;
58  }
59 
60  if (trans_b_) {
61  b_dim0 = B_dim1;
62  b_dim1 = B_dim0;
63  } else {
64  b_dim0 = B_dim0;
65  b_dim1 = B_dim1;
66  }
67 
68  auto dimErrorString = [&]() {
69  return MakeString(
70  "Dimension mismatch: ",
71  trans_a_ ? "trans(A): " : "A: ",
72  a_dim0,
73  " ",
74  a_dim1,
75  trans_b_ ? ", trans(B): " : ", B: ",
76  b_dim0,
77  " ",
78  b_dim1);
79  };
80  // Error checking
81  CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString());
82 
83  Y_shape_cache_[0] = a_dim0;
84  Y_shape_cache_[1] = b_dim1;
85  Y->Resize(Y_shape_cache_);
86  CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->size(), dimErrorString());
87  // Y = A * B
88  math::Gemm<T, Context, Engine>(
89  trans_a_ ? CblasTrans : CblasNoTrans,
90  trans_b_ ? CblasTrans : CblasNoTrans,
91  a_dim0,
92  b_dim1,
93  a_dim1,
94  1,
95  A.template data<T>(),
96  B.template data<T>(),
97  0,
98  Y->template mutable_data<T>(),
99  &context_);
100 
101  if (InputSize() == 3) {
102  // In gradient op, resize to input
103  Y->ResizeLike(Input(2));
104  }
105  return true;
106  }
107 
108  protected:
109  // A local vector to cache the output shape so we don't need to recreate
110  // a vector object every time we run Run().
111  vector<TIndex> Y_shape_cache_{0, 0};
112  int axis_a_{1};
113  int axis_b_{1};
114  bool trans_a_;
115  bool trans_b_;
116 };
117 
118 } // namespace caffe2
119 
120 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.