Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_prune.h
1 
17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26  namespace {
27 
// Fixed-size list of dimension extents; a lightweight stand-in for a
// tensor shape with a compile-time-known rank N.
template<int N>
using Shape = std::array<int, N>;
30 
31  template<int N>
32  const std::vector<TIndex>& shape(Shape<N> vs) {
33  static thread_local std::vector<TIndex> cache;
34  cache.resize(vs.size());
35  for (auto i = 0; i < vs.size(); ++i) {
36  cache[i] = vs[i];
37  }
38  return cache;
39  }
40 
41  inline const std::vector<TIndex>& shape(int i) {
42  return shape<1>(Shape<1>({i}));
43  }
44 
45  inline const std::vector<TIndex>& shape(int i, int j) {
46  return shape<2>(Shape<2>({i, j}));
47  }
48 
// Zero out every entry of `mat` (M x N, row-major) whose corresponding
// entry in `mask` is zero. Only the <float, CPUContext> specialization
// below is defined.
template <typename T, class Context>
void MaskMatrix(const T* mask, T* mat,
    int M, int N);

// Set mat[mask_seq[i]] = target for the first `seq_len` entries of
// `mask_seq`; each entry is a flat index stored as a T.
template <typename T, class Context>
void MaskMatrix_Inc(T* mask_seq, T* mat,
    int M, int N, int seq_len, T target);

// Element-wise accumulate the current gradient into the aggregate
// buffer: ag_dw += dw over all N*K entries.
template <typename T, class Context>
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);

// Record (into `mask_seq`, as T) the flat index of every nonzero entry
// of `mat` whose magnitude is strictly below `thres`; returns the count.
template <typename T>
int MatrixCompare_LT(const T* mat, float thres,
    T* mask_seq, int M, int N);
63 
64  // TODO(wyiming): write an incremental Mask
65  // Incremental Mask: only give the new mask positions;
66  // Assuming that weights masked will not be mask again;
67  // The incremental mask can also be used to update mask matrix;
68  // But this will include template for bool and float;
69  template <>
70  void MaskMatrix<float, CPUContext>(
71  const float* mask, float* mat, int M, int N) {
72  int offset = 0;
73  for (int i = 0; i < M; ++i) {
74  for (int j = 0; j < N; ++j) {
75  mat[offset] = mask[offset]? mat[offset] : 0;
76  offset++;
77  }
78  }
79  }
80 
81  template <>
82  void MaskMatrix_Inc<float, CPUContext>(
83  float* mask_seq,
84  float* mat,
85  int /*M*/,
86  int /*N*/,
87  int seq_len,
88  float target) {
89  for (int i = 0; i < seq_len; ++i) {
90  // assume that the mask_seq is smaller than size
91  // Although it seems that random access gets bad performance,
92  // we make sure that seq is in order;
93  mat[static_cast<int>(mask_seq[i])] = target;
94  }
95  }
96 
// CPU float implementation: accumulate the step gradient into the
// aggregate-gradient buffer, ag_dw += dw, over all N*K elements.
template <>
void AggrDW<float, CPUContext>(
    float* ag_dw, const float* dw,
    int N, int K, CPUContext* context) {
  math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
}
103 
104  template <>
105  int MatrixCompare_LT<float>(
106  const float* mat, float thres,
107  float* mask_seq, int M, int N) {
108  int seq_len = 0;
109  int offset = 0;
110  for (int i = 0 ; i < M; ++i) {
111  for (int j = 0; j < N; ++j) {
112  if (mat[offset] != 0 &&
113  (mat[offset] < thres && mat[offset] > -thres)) {
114  mask_seq[seq_len++] = static_cast<float>(offset);
115  }
116  offset++;
117  }
118  }
119  return seq_len;
120  }
121 
122  }
123 
124  // This is Caffe's InnerProductOp, with a name that fits its purpose better.
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
//
// Forward pass of a pruned fully-connected layer: Y = X * W^T + b.
// Inputs:
//   0: X    - input, (M x K) or 1-D of size K (then M == 1)
//   1: W    - weights, (N x K)
//   2: Mask - same shape as W; used only for the compression-rate output
//   3: b    - bias, (N)
// Outputs:
//   0: Y          - (M x N), or 1-D of size N when X is 1-D
//   1 (optional): scalar compression rate = mean(Mask)
// NOTE(review): the Mask is NOT applied to W here; presumably W has
// already been masked by the gradient/update op — confirm with callers.
template <typename T, class Context, class Engine=DefaultEngine>
class FullyConnectedOpPrune final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOpPrune(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& Mask = Input(2);
    const auto& b = Input(3);
    auto* Y = Output(0);
    CAFFE_ENFORCE_GE(X.ndim(), 1);
    CAFFE_ENFORCE_GE(W.ndim(), 2);
    if (X.ndim() > 2 || W.ndim() > 2) {
      VLOG(1) << "Using legacy support for arbitrary input and weight "
                 "dimensions.";
    }
    CAFFE_ENFORCE_EQ(b.ndim(), 1);
    // batch size
    int M = X.ndim() > 1 ? X.dim32(0) : 1;
    // Feature dimension
    int K = X.size() / M;
    // number of outputs.
    int N = W.dim32(0);
    CAFFE_ENFORCE_EQ(K, W.size() / W.dim32(0));
    CAFFE_ENFORCE_EQ(N, b.dim32(0));
    if (X.ndim() > 1) {
      Y->Resize(M, N);
    } else {
      Y->Resize(N);
    }
    // W * x: computed as X (M x K) times W^T (K x N), hence CblasTrans
    // on the second operand.
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
        W.template data<T>(), 0, Y->template mutable_data<T>(),
        &context_);
    // Add bias term
    if (bias_multiplier_.size() != M) {
      // If the helper bias multiplier is not M,
      // reshape and fill it with one.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1),
          bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    // Y += ones(M,1) * b(1,N): broadcasts the bias across the batch.
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
        Y->template mutable_data<T>(), &context_);
    if (OutputSize() == 2){
      // Optional second output: mean of the mask = fraction of weights
      // still active (the compression rate).
      auto* Comp_rate = Output(1);
      Comp_rate->Resize(vector<TIndex>());
      T* comp_data = Comp_rate->template mutable_data<T>();
      math::Sum<T, Context>(
          Mask.size(), Mask.template data<T>(), comp_data, &context_);
      math::Scale<T, Context>(
          1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
          &context_);
    }
    return true;
  }

 protected:
  // Cached (M)-vector of ones used to broadcast the bias via GEMM.
  Tensor<Context> bias_multiplier_;
};
194 
195  template <typename T, class Context, class Engine=DefaultEngine>
196  class FullyConnectedPruneGradientOp : public Operator<Context> {
197  public:
198  int iter_offset;
199  public:
200  USE_OPERATOR_CONTEXT_FUNCTIONS;
202  (const OperatorDef& operator_def, Workspace* ws)
203  : Operator<Context>(operator_def, ws) { iter_offset = 0; }
205 
206  bool RunOnDevice() override {
207  const auto& X = Input(0);
208  //const auto& W = Input(1);
209  auto* W_ptr = Output(2);
210  auto& W = *W_ptr;
211  //const auto& Mask = Input(2);
212  auto* Mask_ptr = Output(3);
213  auto& Mask = *Mask_ptr;
214  const auto& dY = Input(3);
215  //const auto& Ag_dW = Input(4);
216  auto* Ag_dW_ptr = Output(4);
217  auto& Ag_dW = *Ag_dW_ptr;
218  // it is also the Input(5)
219  auto* mask_seq_auto = Output(5);
220  // how about get threshold
221  auto& thres = Input(6);
222  //TODO(wyiming): check comp_lb is a float
223  auto& comp_lb = Input(7);
224  DCHECK_GE(X.ndim(), 1);
225  DCHECK_GE(W.ndim(), 2);
226  DCHECK_LE(dY.ndim(), 2);
227  // batch size
228  int M = X.ndim() > 1 ? X.dim32(0) : 1;
229  // Feature dimension
230  int K = X.size() / M;
231  // number of outputs.
232  int N = W.dim32(0);
233  // TODO(wyiming): add this window_size to workspace?
234  int window_size = 100;
235  // TODO(wyiming): this threshold should be
236  // based on distribution of the layer weight
237  float thr = 0.01;
238  DCHECK_EQ(Mask.dim32(0), W.dim32(0));
239  DCHECK_EQ(Mask.dim32(1), W.dim32(1));
240  DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
241  DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
242  DCHECK_EQ(K, W.size() / W.dim32(0));
243  if (dY.ndim() > 1) {
244  DCHECK_EQ(M, dY.dim32(0));
245  DCHECK_EQ(N, dY.dim32(1));
246  } else {
247  DCHECK_EQ(X.ndim(), 1);
248  DCHECK_EQ(N, dY.size());
249  }
250  auto* dW = Output(0);
251  auto* db = Output(1);
252  dW->ResizeLike(W);
253  db->Resize(N);
254 
255  // Compute dW
256  math::Gemm<T, Context, Engine>(
257  CblasTrans, CblasNoTrans, N, K, M, 1,
258  dY.template data<T>(), X.template data<T>(),
259  0, dW->template mutable_data<T>(),
260  &context_);
261 
262  comp_r_buf_.Resize(vector<TIndex>());
263  T* comp_data = comp_r_buf_.template mutable_data<T>();
264  math::Sum<T, Context>(
265  Mask.size(), Mask.template data<T>(), comp_data, &context_);
266  math::Scale<T, Context>(
267  1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
268  &context_);
269  // update W size window
270  // Notice here we need to maintain state in OP.
271  // This is new in Caffe2.
272  // And this is something we might need to discuss in the future.
273  // at most mask half of the matrix at time
274  // 1. mask dw with previous mask
275  MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
276  dW->template mutable_data<T>(), N, K);
277  if(*comp_data > *(comp_lb.template data<T>())){
278  iter_offset++;
279  if (iter_offset % window_size == 0) {
280  // TODO(wyiming):do the prune here;
281  sum_buffer_.ResizeLike(W);
282  math::Add<T, Context>(W.size(),
283  W.template mutable_data<T>(),
284  Ag_dW.template mutable_data<T>(),
285  sum_buffer_.template mutable_data<T>(),
286  &context_);
287  mask_seq_auto->ResizeLike(W);
288  T* mask_seq = mask_seq_auto->template mutable_data<T>();
289  math::Set<T, Context>(N*K, static_cast<T>(0),
290  mask_seq_auto->template mutable_data<T>(), &context_);
291  // 2. find dw below thres but not eq 0
292  int seq_len = MatrixCompare_LT<T>(
293  Ag_dW_ptr->template mutable_data<T>(),
294  *thres.template data<T>(), mask_seq, N, K);
295  // 3. use the mask_seq to update W and dw
296  MaskMatrix_Inc<T, Context>(mask_seq,
297  dW->template mutable_data<T>(),
298  N, K, seq_len, 0);
299  MaskMatrix_Inc<T, Context>(mask_seq,
300  W.template mutable_data<T>(),
301  N, K, seq_len, 0);
302  MaskMatrix_Inc<T, Context>(mask_seq,
303  Mask.template mutable_data<T>(),
304  N, K, seq_len, 0);
305  math::Set<T, Context>(N*K, static_cast<T>(0),
306  Ag_dW.template mutable_data<T>(),
307  &context_);
308  } else {
309  // add dW to Aggregate dW.
310  AggrDW<T, Context>(
311  Ag_dW.template mutable_data<T>(),
312  dW->template mutable_data<T>(),
313  N, K, &context_);
314  }
315  }
316  if (bias_multiplier_.size() != M) {
317  // If the helper bias multiplier is not M,
318  // reshape and fill it with one.
319  bias_multiplier_.Resize(M);
320  math::Set<T, Context>(
321  M, static_cast<T>(1),
322  bias_multiplier_.template mutable_data<T>(),
323  &context_);
324  }
325  // Compute dB
326  math::Gemv<T, Context>(
327  CblasTrans, M, N, 1, dY.template data<T>(),
328  bias_multiplier_.template data<T>(), 0,
329  db->template mutable_data<T>(),
330  &context_);
331  // Compute dX if necessary.
332  if (OutputSize() == 7) {
333  auto* dX = Output(6);
334  dX->ResizeLike(X);
335  math::Gemm<T, Context, Engine>(
336  CblasNoTrans, CblasNoTrans, M, K, N, 1,
337  dY.template data<T>(), W.template data<T>(),
338  0, dX->template mutable_data<T>(),
339  &context_);
340  }
341 
342  return true;
343  }
344 
345  protected:
346  Tensor<Context> bias_multiplier_;
347  Tensor<Context> sum_buffer_;
348  Tensor<Context> comp_r_buf_;
349  };
350 
351 } // namespace caffe2
352 
#endif  // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.