Caffe2 - C++ API
A deep-learning, cross-platform ML framework
fusion.h
1 
17 #ifndef CAFFE2_OPT_FUSION_H_
18 #define CAFFE2_OPT_FUSION_H_
19 
20 #include "caffe2/core/workspace.h"
21 #include "nomnigraph/Representations/NeuralNet.h"
22 
23 namespace caffe2 {
24 namespace opt {
25 
26 using namespace nom;
27 
28 CAFFE2_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws);
29 
30 // Generic activation fusion helper.
31 //
32 // \tparam OperationT The operator to be fused.
33 // \tparam ActivationT The activation to be fused.
34 // \param nn Neural network module to be modified in place
35 // \param should_fuse Given a conv op, check whether we want to fuse it with
36 // subsequent relu or not
37 // \param postprocess Functor to postprocess the conv node,
38 // attaching additional attributes if necessary
39 template <typename OperationT, typename ActivationT>
40 C10_EXPORT void fuseActivation(
41  repr::NNModule* nn,
42  std::function<bool(const OperationT& conv)> should_fuse,
43  std::function<void(repr::NNGraph::NodeRef conv_node)> postprocess) {
44  for (auto node_pair : repr::nn::dataIterator<OperationT>(nn->dataFlow)) {
45  repr::NNGraph::NodeRef conv_node;
46  OperationT* conv;
47  std::tie(conv, conv_node) = node_pair;
48 
49  // Check topological feasibility
50  auto conv_outputs = repr::nn::getOutputs(conv_node);
51  if (conv_outputs.size() != 1) {
52  continue;
53  }
54  auto conv_output = conv_outputs.front();
55 
56  auto consumers = repr::nn::getConsumers(conv_output);
57  if (consumers.size() != 1) {
58  continue;
59  }
60  if (!repr::nn::is<ActivationT>(consumers.front())) {
61  continue;
62  }
63  auto relu_node = consumers.front();
64 
65  auto relu_outputs = repr::nn::getOutputs(relu_node);
66  if (relu_outputs.size() != 1) {
67  continue;
68  }
69 
70  // Check feasibility with application specific logic
71  if (!should_fuse(*conv)) {
72  continue;
73  }
74 
75  // Ready to fuse
76  auto relu_output = relu_outputs.front();
77  auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
78  auto output_node = relu_output;
79  auto input_tensor =
80  repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());
81 
82  // Conv cannot be in-place
83  if (output_tensor->getName() != input_tensor->getName()) {
84  nn->dataFlow.replaceNode(conv_output, relu_output);
85  nn->dataFlow.deleteNode(relu_node);
86  nn->dataFlow.deleteNode(conv_output);
87  } else {
88  nn->dataFlow.replaceNode(relu_output, conv_output);
89  output_tensor = repr::nn::get<repr::Tensor>(conv_output);
90  output_node = conv_output;
91  nn->dataFlow.deleteNode(relu_node);
92  nn->dataFlow.deleteNode(relu_output);
93  }
94 
95  // We may have accidentally made the next op in-place
96  // In future iterations of transformations this won't be an issue,
97  // but current caffe2 predictor usage requires things like
98  // external_input and output to be unchanged.
99  bool rectify_inplace = false;
100  for (auto& consumer : repr::nn::getConsumers(output_node)) {
101  for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
102  auto co_name = repr::nn::get<repr::Tensor>(consumer_output)->getName();
103  if (co_name == output_tensor->getName()) {
104  rectify_inplace = true;
105  }
106  }
107  }
108  if (rectify_inplace) {
109  auto new_output = nn->dataFlow.createNode(
110  make_unique<repr::Tensor>(output_tensor->getName() + "_fusion_fix"));
111  nn->dataFlow.replaceNode(output_node, new_output);
112  }
113 
114  // Application specific logic for postprocessing the conv node
115  postprocess(conv_node);
116  }
117 }
118 
119 } // namespace opt
120 } // namespace caffe2
121 
122 #endif // CAFFE2_OPT_FUSION_H_
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
Definition: Graph.h:240
Definition: Dot.h:16
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
void deleteNode(NodeRef n)
Deletes a node from the graph.
Definition: Graph.h:460
void replaceNode(const NodeRef &oldNode, const NodeRef &newNode)
Replace a node in the graph with another node.
Definition: Graph.h:384