Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op.h
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
3 
4 #include <c10/util/Optional.h>
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/conversions.h"
8 #include "caffe2/utils/math.h"
9 
10 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
11 #include <chrono>
12 #endif
13 
14 namespace caffe2 {
15 
16 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
17 template <
18  class Context,
19  class Engine = DefaultEngine,
20  bool TransposeWeight = true>
21 class FullyConnectedOp final : public Operator<Context> {
22  public:
23  USE_OPERATOR_CONTEXT_FUNCTIONS;
24  template <class... Args>
25  explicit FullyConnectedOp(Args&&... args)
26  : Operator<Context>(std::forward<Args>(args)...),
27  axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
28  axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
29  float16_compute_(
30  this->template GetSingleArgument<bool>("float16_compute", false)) {}
31  ~FullyConnectedOp() {}
32 
33  template <
34  typename T_X,
35  typename T_W,
36  typename T_B,
37  typename T_Y,
38  typename MATH>
39  bool DoRunWithType() {
40 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
41  std::chrono::time_point<std::chrono::system_clock> t_very_begin, t_begin,
42  t_end;
43  /* if (VLOG_IS_ON(3)) */
44  {
45  t_begin = std::chrono::system_clock::now();
46  t_very_begin = t_begin;
47  }
48 #endif
49 
50  const auto& X = Input(0);
51  const auto& W = Input(1);
52  const auto& b = Input(2);
53 
54  CAFFE_ENFORCE(b.dim() == 1, b.dim());
55  // batch size
56  const auto canonical_axis = X.canonical_axis_index(axis_);
57  const auto M = X.size_to_dim(canonical_axis);
58  const auto K = X.size_from_dim(canonical_axis);
59  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
60  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
61  : W.size_from_dim(canonical_axis_w);
62 
63  auto dimErrorString = [&]() {
64  return c10::str(
65  "Dimension mismatch: ",
66  "X: ",
67  X.sizes(),
68  ", W: ",
69  W.sizes(),
70  ", b: ",
71  b.sizes(),
72  ", axis: ",
73  axis_,
74  ", M: ",
75  M,
76  ", N: ",
77  N,
78  ", K: ",
79  K);
80  };
81 
82  // Error checking
83  CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
84  CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
85  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
86  CAFFE_ENFORCE(N == b.numel(), dimErrorString());
87 
88  Y_shape_cache_ = X.sizes().vec();
89  // This is an invariant of canonical_axis, so we can DCHECK.
90  DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
91  Y_shape_cache_.resize(canonical_axis + 1);
92  Y_shape_cache_[canonical_axis] = N;
93  auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
94  CAFFE_ENFORCE(M * N == Y->numel(), dimErrorString());
95 
96  if (X.numel() == 0) {
97  // skip the rest of the computation if X is empty
98  Y->template mutable_data<T_Y>();
99  return true;
100  }
101 
102  // default to FLOAT as math.h does.
103  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
104  if (fp16_type<MATH>()) {
105  math_type = TensorProto_DataType_FLOAT16;
106  }
107 
108 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
109  /* if (VLOG_IS_ON(3)) */
110  {
111  t_end = std::chrono::system_clock::now();
112  double dt = std::chrono::duration<double>(t_end - t_begin).count();
113  LOG(INFO) << "@PERF this=" << this << " before_gemm: " << dt * 1e3
114  << " ms";
115  t_begin = std::chrono::system_clock::now();
116  }
117 #endif
118  // W * x
119  math::Gemm<T_X, Context, Engine>(
120  CblasNoTrans,
121  TransposeWeight ? CblasTrans : CblasNoTrans,
122  M,
123  N,
124  K,
125  1,
126  X.template data<T_X>(),
127  W.template data<T_W>(),
128  0,
129  Y->template mutable_data<T_Y>(),
130  &context_,
131  math_type);
132 
133 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
134  /* if (VLOG_IS_ON(3)) */
135  {
136  t_end = std::chrono::system_clock::now();
137  double dt = std::chrono::duration<double>(t_end - t_begin).count();
138  LOG(INFO) << "@PERF this=" << this << " gemm: " << dt * 1e3 << " ms";
139  t_begin = std::chrono::system_clock::now();
140  }
141 #endif
142  // Add bias term
143  if (!bias_multiplier_.has_value()) {
144  bias_multiplier_ =
145  caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
146  math::Set<T_B, Context>(
147  M,
148  convert::To<float, T_B>(1),
149  bias_multiplier_->template mutable_data<T_B>(),
150  &context_);
151  } else if (bias_multiplier_->numel() != M) {
152  bias_multiplier_->Resize(M);
153  math::Set<T_B, Context>(
154  M,
155  convert::To<float, T_B>(1),
156  bias_multiplier_->template mutable_data<T_B>(),
157  &context_);
158  }
159 
160  math::Gemm<T_B, Context, Engine>(
161  CblasNoTrans,
162  CblasNoTrans,
163  M,
164  N,
165  1,
166  1,
167  bias_multiplier_->template data<T_B>(),
168  b.template data<T_B>(),
169  1,
170  Y->template mutable_data<T_Y>(),
171  &context_,
172  math_type);
173 
174 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
175  /* if (VLOG_IS_ON(3)) */
176  {
177  t_end = std::chrono::system_clock::now();
178  double dt = std::chrono::duration<double>(t_end - t_begin).count();
179  LOG(INFO) << "@PERF this=" << this << " add_bias : " << dt * 1e3 << " ms";
180  t_begin = std::chrono::system_clock::now();
181  }
182 #endif
183  return true;
184  }
185 
186  bool RunOnDevice() override {
187  return DoRunWithType<
188  float, // X
189  float, // W
190  float, // B
191  float, // Y
192  float>(); // Math
193  }
194 
195  protected:
196  size_t axis_{1};
197  size_t axis_w_{1};
198  // A local vector to cache the output shape so we don't need to recreate
199  // a vector object every time we run Run().
200  vector<int64_t> Y_shape_cache_;
201  c10::optional<Tensor> bias_multiplier_;
202 
203  bool float16_compute_;
204 };
205 
// Gradient of FullyConnectedOp.
//
// Inputs:  0 = X, 1 = W, 2 = dY.
// Outputs: 0 = dW, 1 = db, and (when a third output is requested) 2 = dX.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FullyConnectedGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            this->template GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedGradientOp() {}

  // Computes dW, db and optionally dX for one concrete set of element
  // types; MATH selects the accumulation type (FLOAT16 when
  // fp16_type<MATH>() is true).
  template <
      typename T_X,
      typename T_W,
      typename T_DY,
      typename T_B,
      typename T_DX,
      typename T_DW,
      typename T_DB,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& dY = Input(2);
    // batch size: M = prod of dims of X before `axis`; K = prod of the rest.
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const int M = X.size_to_dim(canonical_axis);
    const int K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    // Builds a human-readable description of all shapes for enforce errors.
    auto dimErrorString = [&]() {
      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.sizes(),
          ", W: ",
          W.sizes(),
          ", dY: ",
          dY.sizes(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
    CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());

    auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
    auto* db = Output(1, {N}, at::dtype<T_DB>());

    if (X.numel() == 0) {
      // generate a zero blob for db and dW when X is empty
      math::Set<T_DB, Context>(
          db->numel(),
          convert::To<float, T_DB>(0),
          db->template mutable_data<T_DB>(),
          &context_);
      math::Set<T_DW, Context>(
          dW->numel(),
          convert::To<float, T_DW>(0),
          dW->template mutable_data<T_DW>(),
          &context_);

      // dX (if requested) keeps X's (empty) shape; no data to fill.
      if (OutputSize() == 3) {
        Output(2, X.sizes(), at::dtype<T_DX>());
      }

      return true;
    }

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // Compute dW: dY^T * X when TransposeWeight (dW is N x K, like W);
    // otherwise X^T * dY (dW is K x N). Both operands and both output
    // dimensions swap together with the flag.
    math::Gemm<T_DY, Context, Engine>(
        CblasTrans,
        CblasNoTrans,
        TransposeWeight ? N : K,
        TransposeWeight ? K : N,
        M,
        1,
        TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
        TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
        0,
        dW->template mutable_data<T_DW>(),
        &context_,
        math_type);
    // Lazily (re)build a length-M vector of ones; refilled only when first
    // created or when M changed.
    if (!bias_multiplier_.has_value()) {
      bias_multiplier_ =
          caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    } else if (bias_multiplier_->numel() != M) {
      bias_multiplier_->Resize(M);
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_->template mutable_data<T_B>(),
          &context_);
    }
    // Compute dB: db[N] = dY^T * ones(M), i.e. the column sums of dY.
    math::Gemv<T_DY, Context>(
        CblasTrans,
        M,
        N,
        1,
        dY.template data<T_DY>(),
        bias_multiplier_->template data<T_B>(),
        0,
        db->template mutable_data<T_DB>(),
        &context_);

    // Compute dX: dX[M x K] = dY[M x N] * op(W), where W is already N x K
    // when TransposeWeight and must be transposed otherwise.
    if (OutputSize() == 3) {
      auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
      math::Gemm<T_DX, Context, Engine>(
          CblasNoTrans,
          TransposeWeight ? CblasNoTrans : CblasTrans,
          M,
          K,
          N,
          1,
          dY.template data<T_DY>(),
          W.template data<T_W>(),
          0,
          dX->template mutable_data<T_DX>(),
          &context_,
          math_type);
    }
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // dY
        float, // B
        float, // dX
        float, // dW
        float, // dB
        float>(); // Math
  }

 protected:
  // Axis of X at which the flattening to [M, K] happens (default 1).
  size_t axis_{1};
  // Axis of W separating the N and K factors (default 1).
  size_t axis_w_{1};
  // Lazily created length-M tensor of ones used when reducing dY into db.
  c10::optional<Tensor> bias_multiplier_;
  // Set from the "float16_compute" argument; not read in this generic
  // implementation — presumably consumed by device-specific code paths
  // (TODO confirm against the CUDA/engine specializations).
  bool float16_compute_;
};
373 
374 } // namespace caffe2
375 
376 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13