17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ 18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 29 using Shape = std::array<int, N>;
32 const std::vector<int64_t>& shape(Shape<N> vs) {
33 static thread_local std::vector<int64_t> cache;
34 cache.resize(vs.size());
35 for (
auto i = 0; i < vs.size(); ++i) {
41 inline const std::vector<int64_t>& shape(
int i) {
42 return shape<1>(Shape<1>({i}));
45 inline const std::vector<int64_t>& shape(
int i,
int j) {
46 return shape<2>(Shape<2>({i, j}));
// Apply a 0/1 mask to a dense row-major M x N matrix:
// mat[i] = mask[i] ? mat[i] : 0.
template <typename T, class Context>
void MaskMatrix(const T* mask, T* mat, int M, int N);

// Write `target` into `mat` at the first `seq_len` flat indices listed in
// `mask_seq` (indices are stored as T values and truncated to int).
template <typename T, class Context>
void MaskMatrix_Inc(
    T* mask_seq,
    T* mat,
    int M,
    int N,
    int seq_len,
    T target);

// Element-wise accumulate: ag_dw += dw over all N*K entries.
template <typename T, class Context>
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);

// Scan the M x N matrix and record into `mask_seq` the flat index of every
// entry that is nonzero but has |value| < thres; returns the number of
// indices written.
// FIX: restored the missing `template <typename T>` header — this primary
// template is required by the `MatrixCompare_LT<float>` specialization and
// the `MatrixCompare_LT<T>` call sites below.
template <typename T>
int MatrixCompare_LT(const T* mat, float thres, T* mask_seq, int M, int N);
70 void MaskMatrix<float, CPUContext>(
71 const float* mask,
float* mat,
int M,
int N) {
73 for (
int i = 0; i < M; ++i) {
74 for (
int j = 0; j < N; ++j) {
75 mat[offset] = mask[offset]? mat[offset] : 0;
82 void MaskMatrix_Inc<float, CPUContext>(
89 for (
int i = 0; i < seq_len; ++i) {
93 mat[
static_cast<int>(mask_seq[i])] = target;
98 void AggrDW<float, CPUContext>(
99 float* ag_dw,
const float* dw,
100 int N,
int K, CPUContext* context) {
101 math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
105 int MatrixCompare_LT<float>(
106 const float* mat,
float thres,
107 float* mask_seq,
int M,
int N) {
110 for (
int i = 0 ; i < M; ++i) {
111 for (
int j = 0; j < N; ++j) {
112 if (mat[offset] != 0 &&
113 (mat[offset] < thres && mat[offset] > -thres)) {
114 mask_seq[seq_len++] =
static_cast<float>(offset);
125 template <
typename T,
class Context,
class Engine=DefaultEngine>
128 USE_OPERATOR_CONTEXT_FUNCTIONS;
133 bool RunOnDevice()
override {
134 const auto& X =
Input(0);
135 const auto& W =
Input(1);
136 const auto& Mask =
Input(2);
137 const auto& b =
Input(3);
139 CAFFE_ENFORCE_GE(X.dim(), 1);
140 CAFFE_ENFORCE_GE(W.dim(), 2);
141 if (X.dim() > 2 || W.dim() > 2) {
142 VLOG(1) <<
"Using legacy support for arbitrary input and weight " 145 CAFFE_ENFORCE_EQ(b.dim(), 1);
147 int M = X.dim() > 1 ? X.dim32(0) : 1;
149 int K = X.numel() / M;
152 CAFFE_ENFORCE_EQ(K, W.numel() / W.dim32(0));
153 CAFFE_ENFORCE_EQ(N, b.dim32(0));
154 std::vector<int64_t> dims;
160 auto* Y = Output(0, dims, at::dtype<T>());
162 math::Gemm<T, Context, Engine>(
163 CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
164 W.template data<T>(), 0, Y->template mutable_data<T>(),
167 if (bias_multiplier_.numel() != M) {
170 bias_multiplier_.Resize(M);
171 math::Set<T, Context>(
172 M,
static_cast<T>(1),
173 bias_multiplier_.template mutable_data<T>(),
176 math::Gemm<T, Context, Engine>(
177 CblasNoTrans, CblasNoTrans, M, N, 1, 1,
178 bias_multiplier_.template data<T>(), b.template data<T>(), 1,
179 Y->template mutable_data<T>(), &context_);
180 if (OutputSize() == 2){
181 auto* Comp_rate = Output(1, vector<int64_t>(), at::dtype<T>());
182 T* comp_data = Comp_rate->template mutable_data<T>();
183 math::Sum<T, Context>(
184 Mask.numel(), Mask.template data<T>(), comp_data, &context_);
185 math::Scale<float, T, Context>(
187 static_cast<T>(1.) / Mask.numel(),
196 Tensor bias_multiplier_{Context::GetDeviceType()};
199 template <
typename T,
class Context,
class Engine=DefaultEngine>
204 USE_OPERATOR_CONTEXT_FUNCTIONS;
206 (
const OperatorDef& operator_def,
Workspace* ws)
210 bool RunOnDevice()
override {
211 const auto& X =
Input(0);
213 auto* W_ptr = Output(2);
216 auto* Mask_ptr = Output(3);
217 auto& Mask = *Mask_ptr;
218 const auto& dY =
Input(3);
220 auto* Ag_dW_ptr = Output(4);
221 auto& Ag_dW = *Ag_dW_ptr;
225 auto& thres =
Input(6);
227 auto& comp_lb =
Input(7);
228 DCHECK_GE(X.dim(), 1);
229 DCHECK_GE(W.dim(), 2);
230 DCHECK_LE(dY.dim(), 2);
232 int M = X.dim() > 1 ? X.dim32(0) : 1;
234 int K = X.numel() / M;
238 int window_size = 100;
242 DCHECK_EQ(Mask.dim32(0), W.dim32(0));
243 DCHECK_EQ(Mask.dim32(1), W.dim32(1));
244 DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
245 DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
246 DCHECK_EQ(K, W.numel() / W.dim32(0));
248 DCHECK_EQ(M, dY.dim32(0));
249 DCHECK_EQ(N, dY.dim32(1));
251 DCHECK_EQ(X.dim(), 1);
252 DCHECK_EQ(N, dY.numel());
255 auto* dW = Output(0, W.sizes(), at::dtype<T>());
256 auto* db = Output(1, {N}, at::dtype<T>());
259 math::Gemm<T, Context, Engine>(
260 CblasTrans, CblasNoTrans, N, K, M, 1,
261 dY.template data<T>(), X.template data<T>(),
262 0, dW->template mutable_data<T>(),
265 comp_r_buf_.Resize(vector<int64_t>());
266 T* comp_data = comp_r_buf_.template mutable_data<T>();
267 math::Sum<T, Context>(
268 Mask.numel(), Mask.template data<T>(), comp_data, &context_);
269 math::Scale<float, T, Context>(
271 static_cast<T>(1.) / Mask.numel(),
281 MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
282 dW->template mutable_data<T>(), N, K);
283 if(*comp_data > *(comp_lb.template data<T>())){
285 if (iter_offset % window_size == 0) {
287 sum_buffer_.ResizeLike(W);
288 math::Add<T, Context>(
290 W.template mutable_data<T>(),
291 Ag_dW.template mutable_data<T>(),
292 sum_buffer_.template mutable_data<T>(),
294 auto* mask_seq_auto = Output(5, W.sizes(), at::dtype<T>());
295 T* mask_seq = mask_seq_auto->template mutable_data<T>();
296 math::Set<T, Context>(N*K,
static_cast<T>(0),
297 mask_seq_auto->template mutable_data<T>(), &context_);
299 int seq_len = MatrixCompare_LT<T>(
300 Ag_dW_ptr->template mutable_data<T>(),
301 *thres.template data<T>(), mask_seq, N, K);
303 MaskMatrix_Inc<T, Context>(mask_seq,
304 dW->template mutable_data<T>(),
306 MaskMatrix_Inc<T, Context>(mask_seq,
307 W.template mutable_data<T>(),
309 MaskMatrix_Inc<T, Context>(mask_seq,
310 Mask.template mutable_data<T>(),
312 math::Set<T, Context>(N*K,
static_cast<T>(0),
313 Ag_dW.template mutable_data<T>(),
318 Ag_dW.template mutable_data<T>(),
319 dW->template mutable_data<T>(),
323 if (bias_multiplier_.numel() != M) {
326 bias_multiplier_.Resize(M);
327 math::Set<T, Context>(
328 M,
static_cast<T>(1),
329 bias_multiplier_.template mutable_data<T>(),
333 math::Gemv<T, Context>(
334 CblasTrans, M, N, 1, dY.template data<T>(),
335 bias_multiplier_.template data<T>(), 0,
336 db->template mutable_data<T>(),
339 if (OutputSize() == 7) {
340 auto* dX = Output(6, X.sizes(), at::dtype<T>());
341 math::Gemm<T, Context, Engine>(
342 CblasNoTrans, CblasNoTrans, M, K, N, 1,
343 dY.template data<T>(), W.template data<T>(),
344 0, dX->template mutable_data<T>(),
352 Tensor bias_multiplier_{Context::GetDeviceType()};
353 Tensor sum_buffer_{Context::GetDeviceType()};
354 Tensor comp_r_buf_{Context::GetDeviceType()};
359 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
// NOTE(review): the following lines were editor/doxygen tooltip residue
// accidentally pasted after the include guard; preserved here as comments.
// Workspace: a class that holds all the related objects created during
// runtime — all blobs, nets, etc.
// Operator::Input(int idx, DeviceType type = Context::GetDeviceType()):
// retrieves a non-owning reference to the input at position `idx` for
// this operator.
// (A global dictionary holds information about which Caffe2 modules have
// been loaded in the current process.)