Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_fake_lowp_op.cc
#include <functional>

#include "fully_connected_fake_lowp_op.h"

namespace caffe2 {

constexpr int nlines_log = 10000;

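// Template parameters, as used by the registrations at the bottom of this
// file: Q points to a function that rewrites a buffer of fp32 values through
// a lower-precision format (e.g. fp32_to_fp16, fp32_to_bfp16); TransposeWeight
// selects whether W is laid out as [N, K] (the usual FC convention) or [K, N].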
template <
    void (*Q)(const float*, size_t, float*),
    class Context,
    class Engine,
    bool TransposeWeight>
template <typename T_X, typename T_W, typename T_B, typename T_Y, typename MATH>
bool FullyConnectedFakeLowpFPOp<Q, Context, Engine, TransposeWeight>::
    DoRunWithType() {
  const auto& X = Input(0);
  const auto& W = Input(1);
  const auto& b = Input(2);

  CAFFE_ENFORCE(b.dim() == 1, b.dim());
  // batch size
  const auto canonical_axis = X.canonical_axis_index(axis_);
  const auto M = X.size_to_dim(canonical_axis);
  const auto K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                : W.size_from_dim(canonical_axis_w);
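  // X is treated as an [M x K] matrix, W as [N x K] (TransposeWeight) or
  // [K x N], and b as a length-N vector, so the output Y is [M x N].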

  auto dimErrorString = [&]() {
    return c10::str(
        "Dimension mismatch: ",
        "X: ",
        X.sizes(),
        ", W: ",
        W.sizes(),
        ", b: ",
        b.sizes(),
        ", axis: ",
        axis_,
        ", M: ",
        M,
        ", N: ",
        N,
        ", K: ",
        K);
  };

  // Error checking
  CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
  CAFFE_ENFORCE(K == W.size() / N, dimErrorString());
  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
  CAFFE_ENFORCE(N == b.size(), dimErrorString());

  // Log once every nlines_log invocations (the original increment was inside
  // the branch, so the message was only ever printed on the first call).
  static int log_occurrences = 0;
  if (log_occurrences++ % nlines_log == 0) {
    LOG(INFO) << "FAKE_FP16 fc running";
  }

  Y_shape_cache_ = X.sizes().vec();
  // This is an invariant of canonical_axis, so we can DCHECK.
  DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
  Y_shape_cache_.resize(canonical_axis + 1);
  Y_shape_cache_[canonical_axis] = N;
  auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
  CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());

  if (X.size() == 0) {
    // skip the rest of the computation if X is empty
    Y->template mutable_data<T_Y>();
    return true;
  }

  // default to FLOAT as math.h does.
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }

  // Y = W * X + b
  // Quantize W, X, b
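  // Q round-trips each fp32 value through the low-precision format and back,
  // so the GEMMs below still operate on fp32 buffers, just restricted to
  // values representable in the target format ("fake" low precision).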
  auto type = Context::GetDeviceType();
  Tensor Xh(type);
  Xh.ResizeLike(X);
  Q(X.template data<T_X>(), Xh.size(), Xh.template mutable_data<T_X>());

  Tensor Wh(type);
  Wh.ResizeLike(W);
  Q(W.template data<T_W>(), Wh.size(), Wh.template mutable_data<T_W>());

  Tensor bh(type);
  bh.ResizeLike(b);
  Q(b.template data<T_B>(), bh.size(), bh.template mutable_data<T_B>());

  // W * x
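  // Y[M x N] = Xh[M x K] * Wh^T; Wh is [N x K] when TransposeWeight, hence
  // the CblasTrans flag, and [K x N] otherwise.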
  math::Gemm<T_X, Context, Engine>(
      CblasNoTrans,
      TransposeWeight ? CblasTrans : CblasNoTrans,
      M,
      N,
      K,
      1,
      Xh.template data<T_X>(),
      Wh.template data<T_W>(),
      0,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);
  // Add bias term
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not M, reshape and fill it with one.
    ReinitializeTensor(
        &bias_multiplier_,
        {M},
        at::dtype<T_B>().device(Context::GetDeviceType()));
    math::Set<T_B, Context>(
        M,
        convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(),
        &context_);
  }
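  // Rank-1 update: Y += ones[M x 1] * bh[1 x N], i.e. add the (quantized)
  // bias to every row of Y (beta = 1 below, so Y is accumulated into).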
  math::Gemm<T_B, Context, Engine>(
      CblasNoTrans,
      CblasNoTrans,
      M,
      N,
      1,
      1,
      bias_multiplier_.template data<T_B>(),
      bh.template data<T_B>(),
      1,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);

  return true;
}

template <
    void (*Q)(const float*, size_t, float*),
    class Context,
    class Engine,
    bool TransposeWeight>
template <
    typename T_X,
    typename T_W,
    typename T_DY,
    typename T_B,
    typename T_DX,
    typename T_DW,
    typename T_DB,
    typename MATH>
bool FullyConnectedGradientFakeLowpFPOp<Q, Context, Engine, TransposeWeight>::
    DoRunWithType() {
  const auto& X = Input(0);
  const auto& W = Input(1);
  const auto& dY = Input(2);
  // batch size
  const auto canonical_axis = X.canonical_axis_index(axis_);
  const int M = X.size_to_dim(canonical_axis);
  const int K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                : W.size_from_dim(canonical_axis_w);
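  // Same shapes as in the forward pass: X is [M x K], dY is [M x N], and W
  // is [N x K] or [K x N] depending on TransposeWeight.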
  CAFFE_ENFORCE(M * K == X.size());
  CAFFE_ENFORCE(K * N == W.size());

  auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
  auto* db = Output(1, {N}, at::dtype<T_DB>());

  if (X.size() == 0) {
    // generate a zero blob for db and dW when X is empty
    math::Set<T_DB, Context>(
        db->size(),
        convert::To<float, T_DB>(0),
        db->template mutable_data<T_DB>(),
        &context_);
    math::Set<T_DW, Context>(
        dW->size(),
        convert::To<float, T_DW>(0),
        dW->template mutable_data<T_DW>(),
        &context_);

    if (OutputSize() == 3) {
      Output(2, X.sizes(), at::dtype<T_DX>());
    }

    return true;
  }

  // default to FLOAT as math.h does.
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }

  auto type = Context::GetDeviceType();
  // Quantize: W, X, dY
  Tensor Xh(type);
  Xh.ResizeLike(X);
  Q(X.template data<T_X>(), Xh.size(), Xh.template mutable_data<T_X>());

  Tensor Wh(type);
  Wh.ResizeLike(W);
  Q(W.template data<T_W>(), Wh.size(), Wh.template mutable_data<T_W>());

  Tensor dYh(type);
  dYh.ResizeLike(dY);
  Q(dY.template data<T_DY>(), dYh.size(), dYh.template mutable_data<T_DY>());

  // Log once every nlines_log invocations (same fix as in the forward op:
  // the increment must run on every call, not only when we log).
  static int log_occurrences = 0;
  if (log_occurrences++ % nlines_log == 0) {
    LOG(INFO) << "FAKE_FP16 fc grad running";
  }

  // Compute dW
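  // dW = dYh^T * Xh ([N x K]) when TransposeWeight, otherwise
  // dW = Xh^T * dYh ([K x N]); either way it matches the layout of W.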
  math::Gemm<T_DY, Context, Engine>(
      CblasTrans,
      CblasNoTrans,
      TransposeWeight ? N : K,
      TransposeWeight ? K : N,
      M,
      1,
      TransposeWeight ? dYh.template data<T_DY>() : Xh.template data<T_X>(),
      TransposeWeight ? Xh.template data<T_X>() : dYh.template data<T_DY>(),
      0,
      dW->template mutable_data<T_DW>(),
      &context_,
      math_type);
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not M, reshape and fill it
    // with one.
    ReinitializeTensor(
        &bias_multiplier_,
        {M},
        at::dtype<T_B>().device(Context::GetDeviceType()));
    math::Set<T_B, Context>(
        M,
        convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(),
        &context_);
  }
  // Compute dB
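  // db = dYh^T * ones[M], i.e. dY summed over the batch dimension.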
  math::Gemv<T_DY, Context>(
      CblasTrans,
      M,
      N,
      1,
      dYh.template data<T_DY>(),
      bias_multiplier_.template data<T_B>(),
      0,
      db->template mutable_data<T_DB>(),
      &context_);

  // Compute dX
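  // dX[M x K] = dYh * Wh, computed only when a third output is requested.
  // Under TransposeWeight, Wh is already [N x K], so no transpose is needed;
  // otherwise Wh is [K x N] and we pass CblasTrans.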
  if (OutputSize() == 3) {
    auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
    math::Gemm<T_DX, Context, Engine>(
        CblasNoTrans,
        TransposeWeight ? CblasNoTrans : CblasTrans,
        M,
        K,
        N,
        1,
        dYh.template data<T_DY>(),
        Wh.template data<T_W>(),
        0,
        dX->template mutable_data<T_DX>(),
        &context_,
        math_type);
  }

  return true;
}

// IEEE FP16
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_FP16,
    FullyConnectedFakeLowpFPOp<fp32_to_fp16, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_FP16,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_fp16, CPUContext>);

// BFLOAT 16
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_BFP_16,
    FullyConnectedFakeLowpFPOp<fp32_to_bfp16, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_16,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp16, CPUContext>);

// BFLOAT 24 (chop the least significant 8 bits)
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_BFP_24,
    FullyConnectedFakeLowpFPOp<fp32_to_bfp24, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_24,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp24, CPUContext>);

// BFLOAT 14 (chop 2 extra bits from BFLOAT 16)
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_BFP_14,
    FullyConnectedFakeLowpFPOp<fp32_to_bfp14, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_14,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp14, CPUContext>);

// BFLOAT16 with rounding
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_BFP_16_ROUND,
    FullyConnectedFakeLowpFPOp<fp32_to_bfp16_round, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_16_ROUND,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp16_round, CPUContext>);

} // namespace caffe2
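
The fp32_to_fp16 / fp32_to_bfp16 / fp32_to_bfp24 / fp32_to_bfp14 converters plugged in above are declared in fully_connected_fake_lowp_op.h. As the "chop" comments suggest, the truncating variants drop low-order mantissa bits of the IEEE-754 fp32 encoding. A minimal sketch of that idea, assuming only the (const float*, size_t, float*) signature this file uses (an illustration, not the shipped implementation):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Keep the top 16 bits of each fp32 value (sign, 8 exponent bits, and the
// 7 high mantissa bits, i.e. the bfloat16 subset) and zero the rest; the
// result is still a valid fp32 number. A bfp24-style converter would mask
// only the low 8 bits, and a "_round" variant would round to nearest
// before truncating.
void fp32_to_bfp16_sketch(const float* in, size_t n, float* out) {
  for (size_t i = 0; i < n; ++i) {
    std::uint32_t bits;
    std::memcpy(&bits, &in[i], sizeof(bits));
    bits &= 0xFFFF0000u;
    std::memcpy(&out[i], &bits, sizeof(bits));
  }
}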