1 #include "caffe2/opt/mobile.h" 2 #include "caffe2/core/logging.h" 3 #include "caffe2/opt/converter.h" 4 #include "caffe2/opt/fusion.h" 5 #include "caffe2/opt/passes.h" 13 for (
auto node : nn->dataFlow.getMutableNodes()) {
15 NOM_REQUIRE_OR_CONT(repr::nn::is<repr::NeuralNetOperator>(node));
18 auto nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
19 NOM_REQUIRE_OR_CONT(isa<nom::repr::Conv>(nnOp));
22 NOM_REQUIRE_OR_CONT(node->getInEdges().size() >= 3);
24 std::string engine =
"NNPACK";
27 bool validTransformCandidate =
true;
28 auto conv = dyn_cast<nom::repr::Conv>(nnOp);
30 NOM_REQUIRE_OR_CONT(conv->getLayout() == nom::repr::Conv::NNLayout::NCHW);
33 for (
auto stride : conv->getStrides()) {
35 validTransformCandidate =
false;
39 NOM_REQUIRE_OR_CONT(validTransformCandidate);
42 const auto& kernelShape = conv->getKernelShape();
43 NOM_REQUIRE_OR_CONT(kernelShape.size() == 2);
46 if (kernelShape[0] != kernelShape[1]) {
47 NOM_REQUIRE_OR_CONT(kernelShape[0] != 1 && kernelShape[1] != 1);
51 auto annotation = conv->getMutableAnnotation();
52 NOM_REQUIRE_OR_CONT(annotation && isa<Caffe2Annotation>(annotation));
54 auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
55 op->set_engine(engine);
57 auto* precompute_argument = op->add_arg();
58 precompute_argument->set_name(
"convolution_transform_strategy");
59 precompute_argument->set_s(
"PRECOMPUTE");
66 inline bool isNNPACKConvReluEfficient(
67 const std::string& algo,
68 const repr::Conv& conv) {
69 if (algo ==
"AUTO" || algo ==
"") {
70 for (
auto stride : conv.getStrides()) {
75 for (
auto kernel : conv.getKernelShape()) {
80 }
else if (!(algo ==
"WINOGRAD" || algo ==
"WINOGRAD_FP16" ||
81 algo ==
"FT8x8" || algo ==
"FT16x16")) {
90 auto should_fuse = [](
const repr::Conv& conv) {
91 const auto annotation = conv.getAnnotation();
92 if (!annotation || !isa<Caffe2Annotation>(annotation)) {
95 const auto& op = dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
98 if (op.engine() !=
"NNPACK") {
101 caffe2::string algo =
"AUTO";
102 for (
const auto arg : op.arg()) {
103 if (arg.name() ==
"algo") {
107 if (!isNNPACKConvReluEfficient(algo, conv)) {
114 auto conv = repr::nn::get<repr::Conv>(conv_node);
115 auto annotation = conv->getMutableAnnotation();
116 if (!annotation || !isa<Caffe2Annotation>(annotation)) {
119 auto* op = dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
120 auto* arg = op->add_arg();
121 arg->set_name(
"activation");
125 fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
128 REGISTER_OPT_PASS_FROM_FUNC(FuseNNPACKConvRelu, fuseNNPACKConvRelu);
129 REGISTER_OPT_PASS_FROM_FUNC(AddNNPACK, addNNPACK);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...