Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_op_prune.cc
1 
17 #include "caffe2/experiments/operators/fully_connected_op_prune.h"
18 
19 namespace caffe2 {
20 namespace {
21 
22 REGISTER_CPU_OPERATOR(FC_Prune, FullyConnectedOpPrune<float, CPUContext>);
23 REGISTER_CPU_OPERATOR(FCGradient_Prune,
24  FullyConnectedPruneGradientOp<float, CPUContext>);
25 /* 8 Inputs:
26  * X W Mask bias Ag_dw Mask_seq thres comp_lb
27  * */
28 OPERATOR_SCHEMA(FC_Prune).NumInputs(8).NumOutputs(1, 2);
29 OPERATOR_SCHEMA(FCGradient_Prune).NumInputs(8).NumOutputs(6, 7)
30  .AllowInplace({{1, 2}, {2, 3}, {4, 4}, {5, 5}});
31 
32 class GetFCPruneGradient : public GradientMakerBase {
33  using GradientMakerBase::GradientMakerBase;
34  vector<OperatorDef> GetGradientDefs() override {
35  CAFFE_ENFORCE_EQ(def_.input_size(), 8);
36  return SingleGradientDef(
37  "FCGradient_Prune", "",
38  vector<string>{I(0), I(1), I(2), GO(0), I(4), I(5), I(6), I(7)},
39  vector<string>{GI(1), GI(3), I(1), I(2), I(4), I(5), GI(0)});
40  }
41 };
42 REGISTER_GRADIENT(FC_Prune, GetFCPruneGradient);
43 } // namespace
44 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.