fully_connected_op_decomposition.h
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
/*
 * Although a decomposed FC is equivalent to two smaller FC layers chained
 * together, keeping it as a single op makes later analysis easier. Simply
 * stacking two FC ops, each with its own bias, would also not reproduce a
 * single FC layer.
 * TODO(wyiming): decompose the layer into 2 matrices:
 * W(N * K) = U(N * middle) * trans(V(K * middle))
 */
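/*
 * For reference, the forward pass below computes
 *   Y = X * V * trans(U) + b,
 * where X is (M x K), V is (K x middle), U is (N x middle), and b has
 * length N, so Y is (M x N). E.g. with M = 32, K = 1024, N = 1024 and
 * middle = 64, the intermediate buffer X * V is only (32 x 64).
 */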
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedOpDecomp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOpDecomp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& U = Input(1);
    const auto& V = Input(2);
    const auto& b = Input(3);

    //auto* buffer_ptr = Output(1);
    // Size M * middle;
    //auto& multi_buffer_ = *buffer_ptr;
    CAFFE_ENFORCE_GE(X.dim(), 1);
    CAFFE_ENFORCE_GE(U.dim(), 2);
    CAFFE_ENFORCE_GE(V.dim(), 2);
    if (X.dim() > 2 || U.dim() > 2 || V.dim() > 2) {
      VLOG(1) << "Using legacy support for arbitrary input and weight "
                 "dimensions.";
    }
    CAFFE_ENFORCE_EQ(b.dim(), 1);
    // batch size
    int M = X.dim() > 1 ? X.dim32(0) : 1;
    // Feature dimension
    int K = X.numel() / M;
    // number of outputs.
    int N = U.dim32(0);
    // inner (decomposition) dimension; U is N x middle, so this is dim 1.
    int middle = U.dim32(1);
    CAFFE_ENFORCE_EQ(K, V.dim32(0));
    CAFFE_ENFORCE_EQ(N, b.dim32(0));
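    // Shapes implied by the checks above: X is (M x K), U is (N x middle),
    // V is (K x middle), and b has length N.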
    std::vector<int64_t> dims;
    if (X.dim() > 1) {
      dims = {M, N};
      multi_buffer_.Resize(M, middle);
    } else {
      dims = {N};
      multi_buffer_.Resize(middle);
    }
    auto* Y = Output(0, dims, at::dtype<T>());
    // multi_buffer_ holds the intermediate product X * V, of size M * middle.
    // multi_buffer_.Resize(M, middle);
    T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
    // Y = X * V * trans(U)
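    // Step 1: multi_buffer = X * V        -> (M x middle)
    // Step 2: Y = multi_buffer * trans(U) -> (M x N)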
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
        V.template data<T>(), 0, multi_buffer_data,
        &context_);
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
        U.template data<T>(), 0, Y->template mutable_data<T>(),
        &context_);
    // Add bias term
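    // The bias is broadcast over the batch as a rank-1 update,
    // Y += ones(M, 1) * trans(b), using a GEMM against a vector of ones.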
    if (bias_multiplier_.numel() != M) {
      // If the helper bias multiplier is not of length M, reshape and fill
      // it with ones.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
        Y->template mutable_data<T>(), &context_);
    return true;
  }

 protected:
  Tensor bias_multiplier_{Context::GetDeviceType()};
  Tensor multi_buffer_{Context::GetDeviceType()};
};

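/*
 * For reference, with Y = X * V * trans(U) + b the gradients computed by
 * the op below follow as:
 *   dU = trans(dY) * (X * V)   -> (N x middle)
 *   dV = trans(X) * (dY * U)   -> (K x middle)
 *   db = trans(dY) * ones(M)   -> (N)
 *   dX = (dY * U) * trans(V)   -> (M x K), only if requested
 */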
template <typename T, class Context, class Engine = DefaultEngine>
class FullyConnectedDecompGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedDecompGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& U = Input(1);
    const auto& V = Input(2);
    const auto& dY = Input(3);
    DCHECK_GE(X.dim(), 1);
    DCHECK_GE(U.dim(), 2);
    DCHECK_GE(V.dim(), 2);
    DCHECK_LE(dY.dim(), 2);
    // batch size
    int M = X.dim() > 1 ? X.dim32(0) : 1;
    // Feature dimension
    int K = X.numel() / M;
    // number of outputs.
    int N = U.dim32(0);
    int middle = U.dim32(1);
    DCHECK_EQ(K, V.dim32(0));
    if (dY.dim() > 1) {
      DCHECK_EQ(M, dY.dim32(0));
      DCHECK_EQ(N, dY.dim32(1));
    } else {
      DCHECK_EQ(X.dim(), 1);
      DCHECK_EQ(N, dY.numel());
    }

    auto* dU = Output(0, U.sizes(), at::dtype<T>());
    auto* dV = Output(1, V.sizes(), at::dtype<T>());
    auto* db = Output(2, {N}, at::dtype<T>());

    // Compute dU
    // first compute X * V
    // du_buffer_ holds X * V, which is (M x middle).
    du_buffer_.Resize(M, middle);
    T* du_buffer_data = du_buffer_.template mutable_data<T>();
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, middle, K, 1,
        X.template data<T>(), V.template data<T>(),
        0, du_buffer_data,
        &context_);
    math::Gemm<T, Context, Engine>(
        CblasTrans, CblasNoTrans, N, middle, M, 1,
        dY.template data<T>(), du_buffer_data,
        0, dU->template mutable_data<T>(),
        &context_);
    // Compute dV
    // first compute dY * U
    dv_buffer_.Resize(M, middle);
    T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, middle, N, 1,
        dY.template data<T>(), U.template data<T>(),
        0, dv_buffer_data,
        &context_);
    // dV = trans(X) * (dY * U); the operands here are X (M x K) and
    // dv_buffer (M x middle), matching what this GEMM expects.
    math::Gemm<T, Context, Engine>(
        CblasTrans, CblasNoTrans, K, middle, M, 1,
        X.template data<T>(), dv_buffer_data,
        0, dV->template mutable_data<T>(),
        &context_);
    if (bias_multiplier_.numel() != M) {
      // If the helper bias multiplier is not of length M, reshape and fill
      // it with ones.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1),
          bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    // Compute dB
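    // db is the column-wise sum of dY: db = trans(dY) * ones(M), computed
    // with a GEMV against the ones vector.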
    math::Gemv<T, Context>(
        CblasTrans, M, N, 1, dY.template data<T>(),
        bias_multiplier_.template data<T>(), 0,
        db->template mutable_data<T>(),
        &context_);
    // Compute dX if necessary.
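    // dX = (dY * U) * trans(V), of shape (M x K); dx_buffer_ holds dY * U.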
    if (OutputSize() == 4) {
      auto* dX = Output(3, X.sizes(), at::dtype<T>());
      dx_buffer_.Resize(M, middle);
      T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
      math::Gemm<T, Context, Engine>(
          CblasNoTrans, CblasNoTrans, M, middle, N, 1,
          dY.template data<T>(), U.template data<T>(),
          0, dx_buffer_data,
          &context_);
      math::Gemm<T, Context, Engine>(
          CblasNoTrans, CblasTrans, M, K, middle, 1,
          dx_buffer_data, V.template data<T>(),
          0, dX->template mutable_data<T>(),
          &context_);
    }

    return true;
  }

 protected:
  Tensor bias_multiplier_{Context::GetDeviceType()};
  Tensor du_buffer_{Context::GetDeviceType()};
  Tensor dv_buffer_{Context::GetDeviceType()};
  Tensor dx_buffer_{Context::GetDeviceType()};
};
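
/*
 * Note: a standard FC layer stores N * K weights, while the decomposed form
 * stores middle * (N + K). For example, with N = K = 1024 and middle = 64
 * this is roughly 1.05M vs. 131K parameters, at the cost of an extra GEMM
 * through the (M x middle) buffer.
 */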

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_