1 #include "caffe2/transforms/common_subexpression_elimination.h" 3 #include "caffe2/core/common.h" 4 #include "caffe2/core/net.h" 5 #include "caffe2/proto/caffe2_pb.h" 9 using transform::Graph;
10 using transform::Node;
15 bool are_nodes_common(
const Graph& g,
int model_idx,
int candidate_idx) {
17 const Node& model_node = g.node(model_idx);
18 const Node& candidate_node = g.node(candidate_idx);
21 if (model_node.op.type() != candidate_node.op.type()) {
29 if (model_node.op.input_size() != candidate_node.op.input_size()) {
33 for (
int i = 0; i < model_node.op.input_size(); i++) {
34 if (candidate_node.op.input(i) != model_node.op.input(i)) {
41 if (model_node.parents.size() != candidate_node.parents.size() ||
43 model_node.parents.begin(),
44 model_node.parents.end(),
45 candidate_node.parents.begin())) {
50 if (model_node.op.output_size() != candidate_node.op.output_size()) {
58 const std::vector<int>& subgraph,
60 if (subgraph.size() == 0) {
61 if (IsWhitelisted(g.node(idx).op.type()))
65 return are_nodes_common(g, subgraph.at(0), idx);
71 const std::vector<int>& subgraph) {
72 if (subgraph.size() >= 2) {
79 const std::vector<int>& subgraph,
86 int new_idx = g.size();
87 OperatorDef new_op = g.node(subgraph[0]).op;
89 new_op.clear_output();
90 for (
const auto& blob : g.node(subgraph[0]).op.output()) {
91 new_op.add_output(
"transform/" + blob);
95 const auto& new_op_parents = g.node(subgraph[0]).parents;
97 for (
auto& parent : new_op_parents) {
98 int parent_idx = parent.first;
101 g.node(parent_idx).children[new_idx] = new_op_parents.at(parent_idx);
104 for (
int i = 0; i < subgraph.size(); i++) {
105 g.node(parent_idx).children.erase(subgraph[i]);
111 Node(new_op,
true, new_op_parents, std::map<
int, std::vector<string>>()));
114 for (
const int x : subgraph) {
117 std::map<string, string> output_renamings;
118 for (
int i = 0; i < new_op.output_size(); i++) {
119 output_renamings[g.node(x).op.output(i)] = g.node(new_idx).op.output(i);
123 for (
auto& child : g.node(x).children) {
124 int child_idx = child.first;
125 std::vector<string> blobs = child.second;
128 for (
string& blob : blobs) {
129 blob = output_renamings.at(blob);
133 g.node(new_idx).children[child_idx] = blobs;
134 g.node(child_idx).parents[new_idx] = blobs;
137 g.node(child_idx).parents.erase(x);
140 for (
int i = 0; i < g.node(child_idx).op.input_size(); i++) {
141 string blob = g.node(child_idx).op.input(i);
142 if (output_renamings.count(blob) > 0) {
143 g.node(child_idx).op.set_input(i, output_renamings.at(blob));
149 g.DeactivateSubgraph(subgraph);
155 CommonSubexpressionElimination,
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
A simple graph implementation.