Caffe2 - C++ API
A deep learning, cross platform ML framework
mobile.cc
1 #include "caffe2/opt/mobile.h"
2 #include "caffe2/core/logging.h"
3 #include "caffe2/opt/converter.h"
4 #include "caffe2/opt/fusion.h"
5 #include "caffe2/opt/passes.h"
6 
7 namespace caffe2 {
8 namespace opt {
9 
10 using namespace nom;
11 
12 void addNNPACK(repr::NNModule* nn, bool low_memory) {
13  for (auto node : nn->dataFlow.getMutableNodes()) {
14  // Skip blobs.
15  NOM_REQUIRE_OR_CONT(repr::nn::is<repr::NeuralNetOperator>(node));
16 
17  // Check if it is a convolution.
18  auto nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
19  NOM_REQUIRE_OR_CONT(isa<nom::repr::Conv>(nnOp));
20 
21  // Requires X, W, b for NNPACK
22  NOM_REQUIRE_OR_CONT(node->getInEdges().size() >= 3);
23 
24  std::string engine = "NNPACK";
25 
26  // Now do some specific checks to see if an NNPACK engine is correct.
27  bool validTransformCandidate = true;
28  auto conv = dyn_cast<nom::repr::Conv>(nnOp);
29 
30  NOM_REQUIRE_OR_CONT(conv->getLayout() == nom::repr::Conv::NNLayout::NCHW);
31 
32  // NNPACK only supports stride == 1
33  for (auto stride : conv->getStrides()) {
34  if (stride != 1) {
35  validTransformCandidate = false;
36  break;
37  }
38  }
39  NOM_REQUIRE_OR_CONT(validTransformCandidate);
40 
41  // NNPACK only supports 2DConv.
42  const auto& kernelShape = conv->getKernelShape();
43  NOM_REQUIRE_OR_CONT(kernelShape.size() == 2);
44 
45  // Kx1 and 1xK convs are inefficient in NNPACK.
46  if (kernelShape[0] != kernelShape[1]) {
47  NOM_REQUIRE_OR_CONT(kernelShape[0] != 1 && kernelShape[1] != 1);
48  }
49 
50  // We're good to use our engine.
51  auto annotation = conv->getMutableAnnotation();
52  NOM_REQUIRE_OR_CONT(annotation && isa<Caffe2Annotation>(annotation));
53 
54  auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
55  op->set_engine(engine);
56  if (!low_memory) {
57  auto* precompute_argument = op->add_arg();
58  precompute_argument->set_name("convolution_transform_strategy");
59  precompute_argument->set_s("PRECOMPUTE");
60  }
61  }
62 }
63 
64 namespace {
65 
66 inline bool isNNPACKConvReluEfficient(
67  const std::string& algo,
68  const repr::Conv& conv) {
69  if (algo == "AUTO" || algo == "") {
70  for (auto stride : conv.getStrides()) {
71  if (stride > 1) {
72  return false;
73  }
74  }
75  for (auto kernel : conv.getKernelShape()) {
76  if (kernel < 2) {
77  return false;
78  }
79  }
80  } else if (!(algo == "WINOGRAD" || algo == "WINOGRAD_FP16" ||
81  algo == "FT8x8" || algo == "FT16x16")) {
82  return false;
83  }
84  return true;
85 }
86 
87 } // namespace
88 
89 void fuseNNPACKConvRelu(repr::NNModule* nn) {
90  auto should_fuse = [](const repr::Conv& conv) {
91  const auto annotation = conv.getAnnotation();
92  if (!annotation || !isa<Caffe2Annotation>(annotation)) {
93  return false;
94  }
95  const auto& op = dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
96 
97  // We only want to fuse for fast NNPACK convs
98  if (op.engine() != "NNPACK") {
99  return false;
100  }
101  caffe2::string algo = "AUTO";
102  for (const auto arg : op.arg()) {
103  if (arg.name() == "algo") {
104  algo = arg.s();
105  }
106  }
107  if (!isNNPACKConvReluEfficient(algo, conv)) {
108  return false;
109  }
110  return true;
111  };
112 
113  auto postprocess = [](repr::NNGraph::NodeRef conv_node) {
114  auto conv = repr::nn::get<repr::Conv>(conv_node);
115  auto annotation = conv->getMutableAnnotation();
116  if (!annotation || !isa<Caffe2Annotation>(annotation)) {
117  return;
118  }
119  auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
120  auto* arg = op->add_arg();
121  arg->set_name("activation");
122  arg->set_s("Relu");
123  };
124 
125  fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
126 }
127 
// Register the two functions above under the pass names FuseNNPACKConvRelu
// and AddNNPACK. NOTE(review): REGISTER_OPT_PASS_FROM_FUNC is defined outside
// this file (presumably in caffe2/opt/passes.h) - confirm its exact semantics there.
128 REGISTER_OPT_PASS_FROM_FUNC(FuseNNPACKConvRelu, fuseNNPACKConvRelu);
129 REGISTER_OPT_PASS_FROM_FUNC(AddNNPACK, addNNPACK);
130 
131 } // namespace opt
132 } // namespace caffe2
Definition: Dot.h:16
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13