// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// fully_connected_op.h
17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/conversions.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
//
// Inputs:  X (data), W (weights), b (bias, rank-1 of length N).
// Output:  Y = X_2d * W' + b, where X is flattened to an M x K matrix around
//          `axis`, and W is interpreted as N x K when TransposeWeight == true
//          (the default) or K x N when TransposeWeight == false.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        // `axis` / `axis_w` choose where X and W are flattened into 2D.
        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedOp() {}

  // Typed implementation. T_X/T_W/T_B/T_Y are the element types of the
  // input, weight, bias and output tensors; MATH selects the accumulation
  // type forwarded to math::Gemm via `math_type`.
  template <
      typename T_X,
      typename T_W,
      typename T_B,
      typename T_Y,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& b = Input(2);
    auto* Y = Output(0);
    // Bias must be a vector.
    CAFFE_ENFORCE(b.ndim() == 1, b.ndim());
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    // M = product of dims before `axis` (rows), K = product from `axis` on.
    const auto M = X.size_to_dim(canonical_axis);
    const auto K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    // N (output width) sits on opposite sides of axis_w depending on the
    // weight layout selected by TransposeWeight.
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    // Lazily builds a full shape dump for the enforce messages below, so the
    // string is only constructed on failure.
    auto dimErrorString = [&]() {
      return MakeString(
          "Dimension mismatch: ",
          "X: ",
          X.dims(),
          ", W: ",
          W.dims(),
          ", b: ",
          b.dims(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    // Error checking
    CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
    CAFFE_ENFORCE(K == W.size() / N, dimErrorString());
    CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
    CAFFE_ENFORCE(N == b.size(), dimErrorString());

    // Output shape: X's dims up to `axis`, with N appended.
    Y_shape_cache_ = X.dims();
    // This is an invariant of canonical_axis, so we can DCHECK.
    DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
    Y_shape_cache_.resize(canonical_axis + 1);
    Y_shape_cache_[canonical_axis] = N;
    Y->Resize(Y_shape_cache_);
    CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());

    if (X.size() == 0) {
      // skip the rest of the computation if X is empty; still materialize
      // the (empty) output buffer so consumers see a typed tensor.
      Y->template mutable_data<T_Y>();
      return true;
    }

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // W * x: Y(M,N) = X(M,K) * W', transposing W iff it is stored N x K.
    math::Gemm<T_X, Context, Engine>(
        CblasNoTrans,
        TransposeWeight ? CblasTrans : CblasNoTrans,
        M,
        N,
        K,
        1,
        X.template data<T_X>(),
        W.template data<T_W>(),
        0,
        Y->template mutable_data<T_Y>(),
        &context_,
        math_type);
    // Add bias term
    if (bias_multiplier_.size() != M) {
      // If the helper bias multiplier is not M, reshape and fill it with one.
      bias_multiplier_.Resize(M);
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_.template mutable_data<T_B>(),
          &context_);
    }
    // Y += ones(M,1) * b(1,N): rank-1 update that broadcasts the bias row
    // onto every example (note beta == 1 to accumulate into Y).
    math::Gemm<T_B, Context, Engine>(
        CblasNoTrans,
        CblasNoTrans,
        M,
        N,
        1,
        1,
        bias_multiplier_.template data<T_B>(),
        b.template data<T_B>(),
        1,
        Y->template mutable_data<T_Y>(),
        &context_,
        math_type);
    return true;
  }

  // Default dispatch: all tensors and the accumulator are fp32.
  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // B
        float, // Y
        float>(); // Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  // A local vector to cache the output shape so we don't need to recreate
  // a vector object every time we run Run().
  vector<TIndex> Y_shape_cache_;
  // Cached column of ones, resized lazily to M, used to broadcast the bias.
  Tensor<Context> bias_multiplier_;

  // Set from the "float16_compute" argument. Not read in this header's code
  // paths — presumably consumed by device-specific specializations; confirm
  // against the .cc/.cu implementations.
  bool float16_compute_;
};
167 
// Gradient operator for FullyConnectedOp.
//
// Inputs:  X, W, dY.  Outputs: dW (shape of W), db (length N), and — only
// when a third output is requested — dX (shape of X).  Shapes follow the
// same M/K/N flattening rules as the forward op, including the
// TransposeWeight weight layout.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}

  // Typed implementation; template parameters name the element types of
  // each input/output tensor, and MATH the accumulation type for Gemm/Gemv.
  template <
      typename T_X,
      typename T_W,
      typename T_DY,
      typename T_B,
      typename T_DX,
      typename T_DW,
      typename T_DB,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& dY = Input(2);
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    // M x K view of X; N is the forward op's output width.
    const int M = X.size_to_dim(canonical_axis);
    const int K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);
    CAFFE_ENFORCE(M * K == X.size());
    CAFFE_ENFORCE(K * N == W.size());

    auto* dW = Output(0);
    auto* db = Output(1);
    dW->ResizeLike(W);
    db->Resize(N);

    if (X.size() == 0) {
      // generate a zero blob for db and dW when X is empty
      math::Set<T_DB, Context>(
          db->size(),
          convert::To<float, T_DB>(0),
          db->template mutable_data<T_DB>(),
          &context_);
      math::Set<T_DW, Context>(
          dW->size(),
          convert::To<float, T_DW>(0),
          dW->template mutable_data<T_DW>(),
          &context_);

      // dX is optional; materialize an empty typed buffer if requested.
      if (OutputSize() == 3) {
        auto* dX = Output(2);
        dX->ResizeLike(X);
        dX->template mutable_data<T_DX>();
      }

      return true;
    }

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // Compute dW: dY'(N,M) * X(M,K) when W is N x K, or the mirrored
    // X'(K,M) * dY(M,N) when W is K x N — operands and output dims swap
    // together with TransposeWeight so dW always matches W's layout.
    math::Gemm<T_DY, Context, Engine>(
        CblasTrans,
        CblasNoTrans,
        TransposeWeight ? N : K,
        TransposeWeight ? K : N,
        M,
        1,
        TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
        TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
        0,
        dW->template mutable_data<T_DW>(),
        &context_,
        math_type);
    if (bias_multiplier_.size() != M) {
      // If the helper bias multiplier is not M, reshape and fill it
      // with one.
      bias_multiplier_.Resize(M);
      math::Set<T_B, Context>(
          M,
          convert::To<float, T_B>(1),
          bias_multiplier_.template mutable_data<T_B>(),
          &context_);
    }
    // Compute dB: db = dY'(N,M) * ones(M), i.e. column sums of dY.
    math::Gemv<T_DY, Context>(
        CblasTrans,
        M,
        N,
        1,
        dY.template data<T_DY>(),
        bias_multiplier_.template data<T_B>(),
        0,
        db->template mutable_data<T_DB>(),
        &context_);

    // Compute dX (optional third output): dX(M,K) = dY(M,N) * W, with W
    // transposed iff it is stored K x N.
    if (OutputSize() == 3) {
      auto* dX = Output(2);
      dX->ResizeLike(X);
      math::Gemm<T_DX, Context, Engine>(
          CblasNoTrans,
          TransposeWeight ? CblasNoTrans : CblasTrans,
          M,
          K,
          N,
          1,
          dY.template data<T_DY>(),
          W.template data<T_W>(),
          0,
          dX->template mutable_data<T_DX>(),
          &context_,
          math_type);
    }
    return true;
  }

  // Default dispatch: every tensor and the accumulator are fp32.
  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // dY
        float, // B
        float, // dX
        float, // dW
        float, // dB
        float>(); // Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  // Cached column of ones (length M) used to reduce dY into db.
  Tensor<Context> bias_multiplier_;
  // Set from the "float16_compute" argument; not read in this header's code
  // paths — presumably used by device-specific specializations.
  bool float16_compute_;
};
314 
315 } // namespace caffe2
316 
317 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
// Tensor is the basic class in Caffe2 that stores a contiguous memory region
// with its shape information (see tensor.h:109).
// Workspace is a class that holds all the related objects created during
// runtime, e.g. all blobs (see workspace.h:63).
// Copyright (c) 2016-present, Facebook, Inc.