Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_decomposition.h
1 
17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
/*
 * Although a decomposed FC is mathematically equivalent to two small FC
 * layers, it is better to keep it as a single op for future analysis;
 * note that naively stacking two FC ops (each with a bias) would not be
 * correct, since the bias must be added only once.
 * TODO(wyiming): decompose the layer into 2 matrices
 * W(N * K) = U(N * middle) * trans(V(K * middle))
 * */
// Low-rank decomposed variant of the FullyConnected (InnerProduct) op:
// the weight W is factored into U and V and applied as X * V * trans(U) + b.
33 template <typename T, class Context, class Engine=DefaultEngine>
34 class FullyConnectedOpDecomp final : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  FullyConnectedOpDecomp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws) {}
40 
41  bool RunOnDevice() override {
42  const auto& X = Input(0);
43  const auto& U = Input(1);
44  const auto& V = Input(2);
45  const auto& b = Input(3);
46  auto* Y = Output(0);
47  //auto* buffer_ptr = Output(1);
48  // Size M * middle;
49  //auto& multi_buffer_ = *buffer_ptr;
50  CAFFE_ENFORCE_GE(X.ndim(), 1);
51  CAFFE_ENFORCE_GE(U.ndim(), 2);
52  CAFFE_ENFORCE_GE(V.ndim(), 2);
53  if (X.ndim() > 2 || U.ndim() > 2 || V.ndim() > 2) {
54  VLOG(1) << "Using legacy support for arbitrary input and weight "
55  "dimensions.";
56  }
57  CAFFE_ENFORCE_EQ(b.ndim(), 1);
58  // batch size
59  int M = X.ndim() > 1 ? X.dim32(0) : 1;
60  // Feature dimension
61  int K = X.size() / M;
62  // number of outputs.
63  int N = U.dim32(0);
64  int middle = U.dim32(0);
65  CAFFE_ENFORCE_EQ(K, V.dim32(0));
66  CAFFE_ENFORCE_EQ(N, b.dim32(0));
67  if (X.ndim() > 1) {
68  Y->Resize(M, N);
69  multi_buffer_.Resize(M, middle);
70  } else {
71  Y->Resize(N);
72  multi_buffer_.Resize(middle);
73  }
74  // The col buffer is stored in CHW order as well - kernel_dim, and the height
75  // and width.
76  // multi_buffer_.Resize(M, middle);
77  T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
78  // X * V * tans(U)
79  math::Gemm<T, Context, Engine>(
80  CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
81  V.template data<T>(), 0, multi_buffer_data,
82  &context_);
83  math::Gemm<T, Context, Engine>(
84  CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
85  U.template data<T>(), 0, Y->template mutable_data<T>(),
86  &context_);
87  // Add bias term
88  if (bias_multiplier_.size() != M) {
89  // If the helper bias multiplier is not M, reshape and fill it with one.
90  bias_multiplier_.Resize(M);
91  math::Set<T, Context>(
92  M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
93  &context_);
94  }
95  math::Gemm<T, Context, Engine>(
96  CblasNoTrans, CblasNoTrans, M, N, 1, 1,
97  bias_multiplier_.template data<T>(), b.template data<T>(), 1,
98  Y->template mutable_data<T>(), &context_);
99  return true;
100  }
101 
102  protected:
103  Tensor<Context> bias_multiplier_;
104  Tensor<Context> multi_buffer_;
105 };
106 
107 template <typename T, class Context, class Engine=DefaultEngine>
108 class FullyConnectedDecompGradientOp : public Operator<Context> {
109  public:
110  USE_OPERATOR_CONTEXT_FUNCTIONS;
111  FullyConnectedDecompGradientOp(const OperatorDef& operator_def, Workspace* ws)
112  : Operator<Context>(operator_def, ws) {}
114 
115  bool RunOnDevice() override {
116  const auto& X = Input(0);
117  const auto& U = Input(1);
118  const auto& V = Input(2);
119  const auto& dY = Input(3);
120  DCHECK_GE(X.ndim(), 1);
121  DCHECK_GE(U.ndim(), 2);
122  DCHECK_GE(V.ndim(), 2);
123  DCHECK_LE(dY.ndim(), 2);
124  // batch size
125  int M = X.ndim() > 1 ? X.dim32(0) : 1;
126  // Feature dimension
127  int K = X.size() / M;
128  // number of outputs.
129  int N = U.dim32(0);
130  int middle = U.dim32(1);
131  DCHECK_EQ(K, V.dim32(0));
132  if (dY.ndim() > 1) {
133  DCHECK_EQ(M, dY.dim32(0));
134  DCHECK_EQ(N, dY.dim32(1));
135  } else {
136  DCHECK_EQ(X.ndim(), 1);
137  DCHECK_EQ(N, dY.size());
138  }
139  auto* dU = Output(0);
140  auto* dV = Output(1);
141  auto* db = Output(2);
142  dU->ResizeLike(U);
143  dV->ResizeLike(V);
144  db->Resize(N);
145 
146  // Compute dU
147  // first compute X * V
148  du_buffer_.Resize(N, middle);
149  T* du_buffer_data = du_buffer_.template mutable_data<T>();
150  math::Gemm<T, Context, Engine>(
151  CblasNoTrans, CblasNoTrans, M, middle, K, 1,
152  X.template data<T>(), V.template data<T>(),
153  0, du_buffer_data,
154  &context_);
155  math::Gemm<T, Context, Engine>(
156  CblasTrans, CblasNoTrans, N, middle, M, 1,
157  dY.template data<T>(), du_buffer_data,
158  0, dU->template mutable_data<T>(),
159  &context_);
160  // Compute dV
161  // first compute dY * U
162  dv_buffer_.Resize(M, middle);
163  T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
164  math::Gemm<T, Context, Engine>(
165  CblasNoTrans, CblasNoTrans, M, middle, N, 1,
166  dY.template data<T>(), U.template data<T>(),
167  0, dv_buffer_data,
168  &context_);
169  math::Gemm<T, Context, Engine>(
170  CblasTrans, CblasNoTrans, K, middle, M, 1,
171  dY.template data<T>(), du_buffer_data,
172  0, dV->template mutable_data<T>(),
173  &context_);
174  if (bias_multiplier_.size() != M) {
175  // If the helper bias multiplier is not M, reshape and fill it with one.
176  bias_multiplier_.Resize(M);
177  math::Set<T, Context>(
178  M, static_cast<T>(1),
179  bias_multiplier_.template mutable_data<T>(),
180  &context_);
181  }
182  // Compute dB
183  math::Gemv<T, Context>(
184  CblasTrans, M, N, 1, dY.template data<T>(),
185  bias_multiplier_.template data<T>(), 0,
186  db->template mutable_data<T>(),
187  &context_);
188  // Compute dX if necessary.
189  if (OutputSize() == 4) {
190  auto* dX = Output(3);
191  dX->ResizeLike(X);
192  dx_buffer_.Resize(M, middle);
193  T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
194  math::Gemm<T, Context, Engine>(
195  CblasNoTrans, CblasNoTrans, M, middle, N, 1,
196  dY.template data<T>(), U.template data<T>(),
197  0, dx_buffer_data,
198  &context_);
199  math::Gemm<T, Context, Engine>(
200  CblasNoTrans, CblasTrans, M, K, middle, 1,
201  dx_buffer_data, V.template data<T>(),
202  0, dX->template mutable_data<T>(),
203  &context_);
204  }
205 
206  return true;
207  }
208 
209  protected:
210  Tensor<Context> bias_multiplier_;
211  Tensor<Context> du_buffer_;
212  Tensor<Context> dv_buffer_;
213  Tensor<Context> dx_buffer_;
214 };
215 
216 } // namespace caffe2
217 
#endif  // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.