17 #ifndef CAFFE2_OPT_FUSION_H_ 18 #define CAFFE2_OPT_FUSION_H_ 20 #include "caffe2/core/workspace.h" 21 #include "nomnigraph/Representations/NeuralNet.h" 39 template <
typename OperationT,
typename ActivationT>
// fuseActivation: walks every OperationT node in nn->dataFlow and, when the
// topological and application-specific checks below all pass, fuses the single
// ActivationT node that consumes the op's output into the op. Afterwards it
// calls postprocess(conv_node) on the op node.
//
// Template parameters (roles inferred from how they are used below):
//   OperationT  - node type enumerated via repr::nn::dataIterator<OperationT>.
//                 Locals are named `conv`, so presumably a convolution-like op.
//   ActivationT - node type that must be the sole consumer of the op's output
//                 (checked with repr::nn::is<ActivationT>).
//
// NOTE(review): this listing is missing lines. The `nn` and `postprocess`
// parameters, the bodies of the guard branches, and the graph-rewrite calls
// are not visible here. Verify against the full header before relying on the
// details below.
40 C10_EXPORT
void fuseActivation(
// should_fuse: application-specific predicate evaluated on each candidate op
// after the topological checks pass.
42 std::function<
bool(
const OperationT& conv)> should_fuse,
// Visit every (OperationT*, node) pair in the data-flow graph.
44 for (
auto node_pair : repr::nn::dataIterator<OperationT>(nn->dataFlow)) {
47 std::tie(conv, conv_node) = node_pair;
// Topological feasibility checks. The branch bodies are not visible here;
// presumably each one skips the current candidate (TODO confirm).
// (1) The op must produce exactly one output tensor...
50 auto conv_outputs = repr::nn::getOutputs(conv_node);
51 if (conv_outputs.size() != 1) {
54 auto conv_output = conv_outputs.front();
// (2) ...that output must have exactly one consumer, and that consumer must
// be an ActivationT node...
56 auto consumers = repr::nn::getConsumers(conv_output);
57 if (consumers.size() != 1) {
60 if (!repr::nn::is<ActivationT>(consumers.front())) {
63 auto relu_node = consumers.front();
// (3) ...and the activation node must itself produce exactly one output.
65 auto relu_outputs = repr::nn::getOutputs(relu_node);
66 if (relu_outputs.size() != 1) {
// Application-specific veto: the caller's should_fuse decides whether this op
// may be fused.
71 if (!should_fuse(*conv)) {
// Choose which tensor/node represents the fused result. The default is the
// activation's output.
76 auto relu_output = relu_outputs.front();
77 auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
78 auto output_node = relu_output;
80 repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());
// Compare the activation's output name with the op's first input name.
// NOTE(review): the branch structure around the reassignment below is only
// partially visible. It switches the fused output to the op's own output
// tensor/node, presumably for the case where the names are equal (i.e. an
// in-place activation). Confirm against the full source.
83 if (output_tensor->getName() != input_tensor->getName()) {
89 output_tensor = repr::nn::get<repr::Tensor>(conv_output);
90 output_node = conv_output;
// After the rewrite, check whether any consumer of the fused output writes a
// tensor with the same name as output_tensor, which would make that consumer
// run in place.
99 bool rectify_inplace =
false;
100 for (
auto& consumer : repr::nn::getConsumers(output_node)) {
101 for (
auto& consumer_output : repr::nn::getOutputs(consumer)) {
102 auto co_name = repr::nn::get<repr::Tensor>(consumer_output)->getName();
103 if (co_name == output_tensor->getName()) {
104 rectify_inplace =
true;
// If so, a new tensor named "<output name>_fusion_fix" is created. How it is
// attached to the graph is not visible in this listing.
108 if (rectify_inplace) {
110 make_unique<repr::Tensor>(output_tensor->getName() +
"_fusion_fix"));
// Hand the op node to the postprocess hook. Its declaration is not visible
// here; presumably it is a caller-supplied callback parameter.
115 postprocess(conv_node);
122 #endif // CAFFE2_OPT_FUSION_H_
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void deleteNode(NodeRef n)
Deletes a node from the graph.
void replaceNode(const NodeRef &oldNode, const NodeRef &newNode)
Replace a node in the graph with another node.