// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// fully_connected_op_prune.h
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

// Fixed-size shape descriptor: an array of N dimension extents.
template<int N>
using Shape = std::array<int, N>;

// Converts a fixed-size Shape<N> into the vector<int64_t> form used for
// tensor dimensions.  The returned reference points into a thread-local
// cache, so it is only valid until the next shape() call on the same thread.
template<int N>
const std::vector<int64_t>& shape(Shape<N> vs) {
  static thread_local std::vector<int64_t> cache;
  // assign() replaces the original resize + element-by-element copy loop,
  // which compared a signed `int` index against the unsigned vs.size().
  cache.assign(vs.begin(), vs.end());
  return cache;
}

41  inline const std::vector<int64_t>& shape(int i) {
42  return shape<1>(Shape<1>({i}));
43  }
44 
45  inline const std::vector<int64_t>& shape(int i, int j) {
46  return shape<2>(Shape<2>({i, j}));
47  }
48 
// Zeroes every element of the M x N matrix `mat` whose corresponding entry
// in `mask` is zero (see the CPUContext specialization below).
template <typename T, class Context>
void MaskMatrix(const T* mask, T* mat,
    int M, int N);

// Writes `target` into `mat` at each of the seq_len flat indices listed in
// mask_seq (indices are stored as T values and cast back to int).
template <typename T, class Context>
void MaskMatrix_Inc(T* mask_seq, T* mat,
    int M, int N, int seq_len, T target);

// Elementwise accumulation over N*K entries: ag_dw += dw.
template <typename T, class Context>
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);

// Records into mask_seq the flat index of every nonzero element of `mat`
// whose value lies strictly inside (-thres, thres); returns the count.
template <typename T>
int MatrixCompare_LT(const T* mat, float thres,
    T* mask_seq, int M, int N);

64  // TODO(wyiming): write an incremental Mask
65  // Incremental Mask: only give the new mask positions;
66  // Assuming that weights masked will not be mask again;
67  // The incremental mask can also be used to update mask matrix;
68  // But this will include template for bool and float;
69  template <>
70  void MaskMatrix<float, CPUContext>(
71  const float* mask, float* mat, int M, int N) {
72  int offset = 0;
73  for (int i = 0; i < M; ++i) {
74  for (int j = 0; j < N; ++j) {
75  mat[offset] = mask[offset]? mat[offset] : 0;
76  offset++;
77  }
78  }
79  }
80 
81  template <>
82  void MaskMatrix_Inc<float, CPUContext>(
83  float* mask_seq,
84  float* mat,
85  int /*M*/,
86  int /*N*/,
87  int seq_len,
88  float target) {
89  for (int i = 0; i < seq_len; ++i) {
90  // assume that the mask_seq is smaller than size
91  // Although it seems that random access gets bad performance,
92  // we make sure that seq is in order;
93  mat[static_cast<int>(mask_seq[i])] = target;
94  }
95  }
96 
97  template <>
98  void AggrDW<float, CPUContext>(
99  float* ag_dw, const float* dw,
100  int N, int K, CPUContext* context) {
101  math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
102  }
103 
104  template <>
105  int MatrixCompare_LT<float>(
106  const float* mat, float thres,
107  float* mask_seq, int M, int N) {
108  int seq_len = 0;
109  int offset = 0;
110  for (int i = 0 ; i < M; ++i) {
111  for (int j = 0; j < N; ++j) {
112  if (mat[offset] != 0 &&
113  (mat[offset] < thres && mat[offset] > -thres)) {
114  mask_seq[seq_len++] = static_cast<float>(offset);
115  }
116  offset++;
117  }
118  }
119  return seq_len;
120  }
121 
}  // namespace

// This is Caffe's InnerProductOp, with a name that fits its purpose better.
//
// Forward fully-connected layer with pruning support:
//   Inputs : X(0) data, W(1) weights (N x K), Mask(2) pruning mask,
//            b(3) bias (N).
//   Outputs: Y(0) = X * W^T + b; optionally Comp_rate(1), the mean of the
//            mask entries (sum(Mask) / Mask.numel()).
template <typename T, class Context, class Engine=DefaultEngine>
class FullyConnectedOpPrune final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOpPrune(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& Mask = Input(2);
    const auto& b = Input(3);

    CAFFE_ENFORCE_GE(X.dim(), 1);
    CAFFE_ENFORCE_GE(W.dim(), 2);
    if (X.dim() > 2 || W.dim() > 2) {
      // Higher-rank inputs are flattened into (M, K) below.
      VLOG(1) << "Using legacy support for arbitrary input and weight "
                 "dimensions.";
    }
    CAFFE_ENFORCE_EQ(b.dim(), 1);
    // batch size
    int M = X.dim() > 1 ? X.dim32(0) : 1;
    // Feature dimension
    int K = X.numel() / M;
    // number of outputs.
    int N = W.dim32(0);
    // W must factor as N x K and the bias must have one entry per output.
    CAFFE_ENFORCE_EQ(K, W.numel() / W.dim32(0));
    CAFFE_ENFORCE_EQ(N, b.dim32(0));
    std::vector<int64_t> dims;
    if (X.dim() > 1) {
      dims = {M, N};
    } else {
      // 1-D input produces a 1-D output of length N.
      dims = {N};
    }
    auto* Y = Output(0, dims, at::dtype<T>());
    // W * x : Y(M x N) = X(M x K) * W^T(K x N)
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
        W.template data<T>(), 0, Y->template mutable_data<T>(),
        &context_);
    // Add bias term
    if (bias_multiplier_.numel() != M) {
      // If the helper bias multiplier is not M,
      // reshape and fill it with one.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1),
          bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    // Y += ones(M x 1) * b(1 x N): broadcasts the bias over the batch.
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
        Y->template mutable_data<T>(), &context_);
    if (OutputSize() == 2){
      // Second output: scalar compression rate = sum(Mask) / Mask.numel().
      auto* Comp_rate = Output(1, vector<int64_t>(), at::dtype<T>());
      T* comp_data = Comp_rate->template mutable_data<T>();
      math::Sum<T, Context>(
          Mask.numel(), Mask.template data<T>(), comp_data, &context_);
      math::Scale<float, T, Context>(
          1,
          static_cast<T>(1.) / Mask.numel(),
          comp_data,
          comp_data,
          &context_);
    }
    return true;
  }

 protected:
  // Column vector of ones, length M, reused to broadcast the bias via Gemm.
  Tensor bias_multiplier_{Context::GetDeviceType()};
};

199  template <typename T, class Context, class Engine=DefaultEngine>
200  class FullyConnectedPruneGradientOp : public Operator<Context> {
201  public:
202  int iter_offset;
203  public:
204  USE_OPERATOR_CONTEXT_FUNCTIONS;
206  (const OperatorDef& operator_def, Workspace* ws)
207  : Operator<Context>(operator_def, ws) { iter_offset = 0; }
209 
210  bool RunOnDevice() override {
211  const auto& X = Input(0);
212  //const auto& W = Input(1);
213  auto* W_ptr = Output(2);
214  auto& W = *W_ptr;
215  //const auto& Mask = Input(2);
216  auto* Mask_ptr = Output(3);
217  auto& Mask = *Mask_ptr;
218  const auto& dY = Input(3);
219  //const auto& Ag_dW = Input(4);
220  auto* Ag_dW_ptr = Output(4);
221  auto& Ag_dW = *Ag_dW_ptr;
222  // it is also the Input(5)
223 
224  // how about get threshold
225  auto& thres = Input(6);
226  //TODO(wyiming): check comp_lb is a float
227  auto& comp_lb = Input(7);
228  DCHECK_GE(X.dim(), 1);
229  DCHECK_GE(W.dim(), 2);
230  DCHECK_LE(dY.dim(), 2);
231  // batch size
232  int M = X.dim() > 1 ? X.dim32(0) : 1;
233  // Feature dimension
234  int K = X.numel() / M;
235  // number of outputs.
236  int N = W.dim32(0);
237  // TODO(wyiming): add this window_size to workspace?
238  int window_size = 100;
239  // TODO(wyiming): this threshold should be
240  // based on distribution of the layer weight
241  float thr = 0.01;
242  DCHECK_EQ(Mask.dim32(0), W.dim32(0));
243  DCHECK_EQ(Mask.dim32(1), W.dim32(1));
244  DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
245  DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
246  DCHECK_EQ(K, W.numel() / W.dim32(0));
247  if (dY.dim() > 1) {
248  DCHECK_EQ(M, dY.dim32(0));
249  DCHECK_EQ(N, dY.dim32(1));
250  } else {
251  DCHECK_EQ(X.dim(), 1);
252  DCHECK_EQ(N, dY.numel());
253  }
254 
255  auto* dW = Output(0, W.sizes(), at::dtype<T>());
256  auto* db = Output(1, {N}, at::dtype<T>());
257 
258  // Compute dW
259  math::Gemm<T, Context, Engine>(
260  CblasTrans, CblasNoTrans, N, K, M, 1,
261  dY.template data<T>(), X.template data<T>(),
262  0, dW->template mutable_data<T>(),
263  &context_);
264 
265  comp_r_buf_.Resize(vector<int64_t>());
266  T* comp_data = comp_r_buf_.template mutable_data<T>();
267  math::Sum<T, Context>(
268  Mask.numel(), Mask.template data<T>(), comp_data, &context_);
269  math::Scale<float, T, Context>(
270  1,
271  static_cast<T>(1.) / Mask.numel(),
272  comp_data,
273  comp_data,
274  &context_);
275  // update W size window
276  // Notice here we need to maintain state in OP.
277  // This is new in Caffe2.
278  // And this is something we might need to discuss in the future.
279  // at most mask half of the matrix at time
280  // 1. mask dw with previous mask
281  MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
282  dW->template mutable_data<T>(), N, K);
283  if(*comp_data > *(comp_lb.template data<T>())){
284  iter_offset++;
285  if (iter_offset % window_size == 0) {
286  // TODO(wyiming):do the prune here;
287  sum_buffer_.ResizeLike(W);
288  math::Add<T, Context>(
289  W.numel(),
290  W.template mutable_data<T>(),
291  Ag_dW.template mutable_data<T>(),
292  sum_buffer_.template mutable_data<T>(),
293  &context_);
294  auto* mask_seq_auto = Output(5, W.sizes(), at::dtype<T>());
295  T* mask_seq = mask_seq_auto->template mutable_data<T>();
296  math::Set<T, Context>(N*K, static_cast<T>(0),
297  mask_seq_auto->template mutable_data<T>(), &context_);
298  // 2. find dw below thres but not eq 0
299  int seq_len = MatrixCompare_LT<T>(
300  Ag_dW_ptr->template mutable_data<T>(),
301  *thres.template data<T>(), mask_seq, N, K);
302  // 3. use the mask_seq to update W and dw
303  MaskMatrix_Inc<T, Context>(mask_seq,
304  dW->template mutable_data<T>(),
305  N, K, seq_len, 0);
306  MaskMatrix_Inc<T, Context>(mask_seq,
307  W.template mutable_data<T>(),
308  N, K, seq_len, 0);
309  MaskMatrix_Inc<T, Context>(mask_seq,
310  Mask.template mutable_data<T>(),
311  N, K, seq_len, 0);
312  math::Set<T, Context>(N*K, static_cast<T>(0),
313  Ag_dW.template mutable_data<T>(),
314  &context_);
315  } else {
316  // add dW to Aggregate dW.
317  AggrDW<T, Context>(
318  Ag_dW.template mutable_data<T>(),
319  dW->template mutable_data<T>(),
320  N, K, &context_);
321  }
322  }
323  if (bias_multiplier_.numel() != M) {
324  // If the helper bias multiplier is not M,
325  // reshape and fill it with one.
326  bias_multiplier_.Resize(M);
327  math::Set<T, Context>(
328  M, static_cast<T>(1),
329  bias_multiplier_.template mutable_data<T>(),
330  &context_);
331  }
332  // Compute dB
333  math::Gemv<T, Context>(
334  CblasTrans, M, N, 1, dY.template data<T>(),
335  bias_multiplier_.template data<T>(), 0,
336  db->template mutable_data<T>(),
337  &context_);
338  // Compute dX if necessary.
339  if (OutputSize() == 7) {
340  auto* dX = Output(6, X.sizes(), at::dtype<T>());
341  math::Gemm<T, Context, Engine>(
342  CblasNoTrans, CblasNoTrans, M, K, N, 1,
343  dY.template data<T>(), W.template data<T>(),
344  0, dX->template mutable_data<T>(),
345  &context_);
346  }
347 
348  return true;
349  }
350 
351  protected:
352  Tensor bias_multiplier_{Context::GetDeviceType()};
353  Tensor sum_buffer_{Context::GetDeviceType()};
354  Tensor comp_r_buf_{Context::GetDeviceType()};
355  };
356 
} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
// (Doxygen cross-reference residue from the documentation export, kept as
// comments so the header remains compilable:)
// Definition: any.cpp:108
// Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
// Definition: workspace.h:47
// const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
// Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
// Definition: operator.h:702
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
// Definition: blob.h:13