Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_matmul_op.h
// Copyright (c) 2016-present, Facebook, Inc.

#ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
#define CAFFE2_OPERATORS_MATMUL_OP_H_

#include <sstream>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
        trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)),
        broadcast_(OperatorBase::GetSingleArgument<int>("broadcast", 0)),
        use_scratch_(OperatorBase::GetSingleArgument<int>("use_scratch", 0)) {
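    // If use_scratch is set, allocate a scratch tensor that is shared across
    // the GemmBatched calls below (assumption: some GemmBatched
    // implementations, e.g. on GPU, can use it as temporary workspace
    // instead of allocating per call).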
    if (use_scratch_) {
      scratch_ = std::make_shared<Tensor<Context>>();
    }
  }

  ~BatchMatMulOp() {}

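  // Dispatch on the element type of the first input; only float is listed,
  // so only float tensors are supported by this operator.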
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);
    auto* Y = Output(0);

    auto ndims_A = A.ndim();
    auto dims_A = A.dims();
    auto ndims_B = B.ndim();
    auto dims_B = B.dims();

    auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
      std::stringstream ss;
      ss << "Inputs with dimensions A = ";
      ss << dim1;
      ss << " and B = ";
      ss << dim2;
      ss << " are not supported with broadcast=0. Did you forget to set the "
            "broadcast flag?";
      return ss.str();
    };

    // These should all be false if we're not broadcasting.
    bool dimMismatch = ndims_A != ndims_B;
    bool dimsLessThan1D = ndims_A < 2;
    CAFFE_ENFORCE(
        broadcast_ || (!dimMismatch && !dimsLessThan1D),
        noBroadcastErrorMsg(ndims_A, ndims_B));

    auto* data_A = A.template data<T>();
    auto* data_B = B.template data<T>();

    auto dimMismatchErrorString = [](size_t dimnum1,
                                     size_t dim1,
                                     size_t dimnum2,
                                     size_t dim2,
                                     bool trans_a,
                                     bool trans_b) {
      std::stringstream ss;
      ss << "Expected dimension ";
      ss << dimnum1;
      ss << " of tensor A with value ";
      ss << dim1;
      ss << " to match dimension ";
      ss << dimnum2;
      ss << " of tensor B with value ";
      ss << dim2;
      ss << ". trans_a = ";
      ss << trans_a;
      ss << " trans_b = ";
      ss << trans_b;
      return ss.str();
    };

    if (ndims_A == 1 && ndims_B == 1) {
      // vector-vector
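      // e.g. A = [n] and B = [n] produce Y of shape [1] holding their
      // dot product.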
      CAFFE_ENFORCE_EQ(
          dims_A[0],
          dims_B[0],
          "Vector-vector product requires each of the vectors to "
          "be the same size.");
      Y->Resize(1);
      math::Dot<T, Context>(
          dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
    } else {
      bool A_broadcasted = false, B_broadcasted = false;
      if (ndims_A == 1) {
        dims_A.insert(dims_A.begin(), 1);
        ndims_A = 2;
        A_broadcasted = true;
      }
      if (ndims_B == 1) {
        dims_B.push_back(1);
        ndims_B = 2;
        B_broadcasted = true;
      }
      // matrix-matrix with batches
      // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
      // If A or B is one-dimensional, the leading (for A) or trailing (for B)
      // 1 inserted above is not added to the output tensor's size.

      // First step: partition the tensors into inner and outer blocks.
      // Ignoring the last two dimensions of A and B, ensure that one of the
      // tensors' dimensions is a suffix of the other. For example,
      // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
      // dimensions of size 2 and 3 will be broadcasted, so we partition into
      // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
      size_t num_inner_dims = std::min(ndims_A, ndims_B);
      for (size_t i = 2; i < num_inner_dims; ++i) {
        auto first_r_itr = dims_A.rbegin();
        auto second_r_itr = dims_B.rbegin();
        CAFFE_ENFORCE_EQ(
            *(first_r_itr + i),
            *(second_r_itr + i),
            dimMismatchErrorString(
                ndims_A - i - 1,
                *(first_r_itr + i),
                ndims_B - i - 1,
                *(second_r_itr + i),
                trans_a_,
                trans_b_));
      }
      size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

      // Standard M, N, and K parameters respecting GEMM API and transpose
      // flags
      size_t M, N, K, K_dim;
      if (trans_a_) {
        M = dims_A[ndims_A - 1];
        K = dims_A[ndims_A - 2];
        K_dim = ndims_A - 2;
      } else {
        M = dims_A[ndims_A - 2];
        K = dims_A[ndims_A - 1];
        K_dim = ndims_A - 1;
      }
      if (trans_b_) {
        N = dims_B[ndims_B - 2];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 1],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 1,
                dims_B[ndims_B - 1],
                trans_a_,
                trans_b_));
      } else {
        N = dims_B[ndims_B - 1];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 2],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 2,
                dims_B[ndims_B - 2],
                trans_a_,
                trans_b_));
      }
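      // Example: with trans_a = 1 and trans_b = 0, A is read as [..., K, M]
      // and B as [..., K, N]; an A of shape [2, 5, 3] gives M = 3 and K = 5,
      // and the shared K dimension of B is checked above.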

      // Calculate output tensor shapes [B..., (M), (N)]
      // Batch dimensions will be broadcasted out to those of the longer
      // tensor, A or B. M or N is omitted from the output shape if A or B,
      // respectively, is 1-D.
      std::vector<TIndex> new_dims;
      if (ndims_A >= ndims_B) {
        new_dims.assign(dims_A.begin(), dims_A.end() - 2);
      } else {
        new_dims.assign(dims_B.begin(), dims_B.end() - 2);
      }
      if (!A_broadcasted) {
        new_dims.push_back(M);
      } else {
        new_dims.push_back(1);
      }
      if (!B_broadcasted) {
        new_dims.push_back(N);
      } else {
        new_dims.push_back(1);
      }
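      // Example: A = [K] (A_broadcasted) and B = [2, K, N] give
      // new_dims = [2, 1, N] at this point; the placeholder 1 is erased
      // below, so the final output shape is [2, N].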

      // Calculate strides. Continuing our example above,
      // [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
      // We calculate this as follows:
      // 1) Treat the outer batch dimensions as flattened, i.e. view the B
      //    tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The reasoning
      //    is analogous for the case where # dims A >= # dims B.
      // 2) Perform this operation:
      //    for i in range(6):
      //      Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
      size_t A_stride = 1; // How far to increment A pointer each itr
      size_t B_stride = 1; // How far to increment B pointer each itr
      size_t Y_stride = 1; // How far to increment Y pointer each itr
      // How many "inner batches" we have. That is, the product of sizes for
      // the slices excluding M, K, and N, for their respective matrices.
      size_t num_sub_batches = 1;
      if (ndims_A >= ndims_B) {
        auto first_r_itr = dims_A.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          A_stride *= *(first_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            num_sub_batches *= *(first_r_itr + i);
          }
        }
        B_stride = 0;
      } else {
        A_stride = 0;
        auto second_r_itr = dims_B.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          B_stride *= *(second_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            num_sub_batches *= *(second_r_itr + i);
          }
        }
      }
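      // Continuing the example: ndims_A (3) < ndims_B (5), so A_stride = 0
      // (the whole of A is reused for every outer batch), B_stride = 4*K*N,
      // Y_stride = 4*M*N, and num_sub_batches = 4. num_outer_batches,
      // computed next, is 2*3 = 6, so the GemmBatched loop below runs 6
      // times.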

      size_t num_outer_batches = 1;
      for (size_t i = 0; i < num_outer_dims; ++i) {
        num_outer_batches *= new_dims[i];
      }

      // A_broadcasted and B_broadcasted are mutually exclusive here, since
      // otherwise we would've taken the vector-vector path above.
      if (A_broadcasted) {
        new_dims.erase(new_dims.end() - 2);
      } else if (B_broadcasted) {
        new_dims.erase(new_dims.end() - 1);
      }

      // Allocate output tensor
      Y->Resize(new_dims);
      auto* Y_data = Y->template mutable_data<T>();

      // A zero batch dimension indicates there are no elements to compute.
      if (num_sub_batches == 0 || num_outer_batches == 0) {
        return true;
      }

      // TODO(T23893772): doing this in a loop is likely going to be slow on
      // GPU
      for (size_t p = 0; p < num_outer_batches; ++p) {
        math::GemmBatched<T, Context, Engine>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            num_sub_batches,
            M,
            N,
            K,
            1.0f,
            data_A + p * A_stride,
            data_B + p * B_stride,
            0.0f,
            Y_data + p * Y_stride,
            &context_,
            use_scratch_ ? scratch_.get() : nullptr);
      }
    }
    return true;
  }

 protected:
  bool trans_a_;
  bool trans_b_;
  bool broadcast_;

  bool use_scratch_;
  std::shared_ptr<Tensor<Context>> scratch_;
};

} // namespace caffe2

#endif /* CAFFE2_OPERATORS_MATMUL_OP_H_ */
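For reference, here is a minimal sketch of driving this operator end to end through a Caffe2 workspace, using the running example shapes from the comments above ([4, M, K] x [2, 3, 4, K, N] with broadcast = 1). It assumes the CPU operator is registered under the name "BatchMatMul" (as in batch_matmul_op.cc) and that the program links against the Caffe2 core and operator libraries; the blob names and concrete sizes are illustrative only.

#include <algorithm>
#include <iostream>

#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

using namespace caffe2;

int main(int argc, char** argv) {
  GlobalInit(&argc, &argv);
  Workspace ws;

  // A: [4, M, K] with M = 2, K = 3; B: [2, 3, 4, K, N] with N = 5.
  auto* A = ws.CreateBlob("A")->GetMutable<TensorCPU>();
  A->Resize(4, 2, 3);
  auto* B = ws.CreateBlob("B")->GetMutable<TensorCPU>();
  B->Resize(2, 3, 4, 3, 5);
  // Fill both inputs with ones so every output element equals K = 3.
  std::fill(
      A->mutable_data<float>(), A->mutable_data<float>() + A->size(), 1.f);
  std::fill(
      B->mutable_data<float>(), B->mutable_data<float>() + B->size(), 1.f);

  // Describe the op: Y = BatchMatMul(A, B) with broadcasting enabled.
  OperatorDef def;
  def.set_type("BatchMatMul");
  def.add_input("A");
  def.add_input("B");
  def.add_output("Y");
  auto* arg = def.add_arg();
  arg->set_name("broadcast");
  arg->set_i(1);

  CAFFE_ENFORCE(ws.RunOperatorOnce(def));

  // Expected output shape: [2, 3, 4, 2, 5].
  const auto& Y = ws.GetBlob("Y")->Get<TensorCPU>();
  std::cout << Y.DebugString() << std::endl;
  return 0;
}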