Caffe2 - C++ API
A deep learning, cross platform ML framework
common_subexpression_elimination.cc
1 #include "caffe2/transforms/common_subexpression_elimination.h"
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/net.h"
5 #include "caffe2/proto/caffe2_pb.h"
6 
7 namespace caffe2 {
8 
9 using transform::Graph;
10 using transform::Node;
11 
12 // Checks if the node at model_idx and the node at candidate_idx are
13 // "common subexpressions". That is, do they have the same function, and
14 // take in the exact same input. If so, then their function is duplicated.
15 bool are_nodes_common(const Graph& g, int model_idx, int candidate_idx) {
16  // We need the candidate operator to match this model_op.
17  const Node& model_node = g.node(model_idx);
18  const Node& candidate_node = g.node(candidate_idx);
19 
20  // Types need to match.
21  if (model_node.op.type() != candidate_node.op.type()) {
22  return false;
23  }
24  // Arguments need to match.
25  if (!MatchArguments(model_node.op, candidate_node.op)) {
26  return false;
27  }
28  // Inputs need to match.
29  if (model_node.op.input_size() != candidate_node.op.input_size()) {
30  return false;
31  }
32  // If any input_blob name is different, this is not okay.
33  for (int i = 0; i < model_node.op.input_size(); i++) {
34  if (candidate_node.op.input(i) != model_node.op.input(i)) {
35  return false;
36  }
37  }
38  // Now, we also need to check that each blob comes from the same parent, or
39  // if they are external (isn't in parents). This is equivalent to a
40  // map equality (since parent edges can only contain up to one blob).
41  if (model_node.parents.size() != candidate_node.parents.size() ||
42  !std::equal(
43  model_node.parents.begin(),
44  model_node.parents.end(),
45  candidate_node.parents.begin())) {
46  return false;
47  }
48 
49  // Output size have to match too.
50  if (model_node.op.output_size() != candidate_node.op.output_size()) {
51  return false;
52  }
53  return true;
54 }
55 
57  const Graph& g,
58  const std::vector<int>& subgraph,
59  int idx) {
60  if (subgraph.size() == 0) {
61  if (IsWhitelisted(g.node(idx).op.type()))
62  return true;
63  return false;
64  }
65  return are_nodes_common(g, subgraph.at(0), idx);
66 }
67 
68 // As long as we have matched more than 2 ops, it is worth eliminating.
70  const Graph& /*g*/,
71  const std::vector<int>& subgraph) {
72  if (subgraph.size() >= 2) {
73  return true;
74  }
75  return false;
76 }
77 
79  const std::vector<int>& subgraph,
80  Graph* g_ptr) {
81  CHECK(g_ptr);
82  auto& g = *g_ptr;
83 
84  // We're gonna make a new node, with the same input as all of the ones in
85  // subgraph, but with their combined children.
86  int new_idx = g.size();
87  OperatorDef new_op = g.node(subgraph[0]).op;
88  // We will need to rename the output blobs.
89  new_op.clear_output();
90  for (const auto& blob : g.node(subgraph[0]).op.output()) {
91  new_op.add_output("transform/" + blob);
92  }
93 
94  // Need to set up the parents.
95  const auto& new_op_parents = g.node(subgraph[0]).parents;
96 
97  for (auto& parent : new_op_parents) {
98  int parent_idx = parent.first;
99 
100  // Make the parents acknowledge us as its new child.
101  g.node(parent_idx).children[new_idx] = new_op_parents.at(parent_idx);
102 
103  // Make the parents disown all our outdated siblings.
104  for (int i = 0; i < subgraph.size(); i++) {
105  g.node(parent_idx).children.erase(subgraph[i]);
106  }
107  }
108 
109  // Add the node now.
110  g.push_node(
111  Node(new_op, true, new_op_parents, std::map<int, std::vector<string>>()));
112 
113  // Now, we need to populate the child edges.
114  for (const int x : subgraph) {
115  // Figure out what the subgraph's node's blobs correspond to in new_op
116  // This is easy, since their indices match.
117  std::map<string, string> output_renamings;
118  for (int i = 0; i < new_op.output_size(); i++) {
119  output_renamings[g.node(x).op.output(i)] = g.node(new_idx).op.output(i);
120  }
121 
122  // Now, time to add the old node's children to new_op
123  for (auto& child : g.node(x).children) {
124  int child_idx = child.first;
125  std::vector<string> blobs = child.second;
126 
127  // rename the old blobs, and use them for our new edge.
128  for (string& blob : blobs) {
129  blob = output_renamings.at(blob);
130  }
131 
132  // create this new edge
133  g.node(new_idx).children[child_idx] = blobs;
134  g.node(child_idx).parents[new_idx] = blobs;
135 
136  // delete the old edge
137  g.node(child_idx).parents.erase(x);
138 
139  // need to rename the inputs of the children too.
140  for (int i = 0; i < g.node(child_idx).op.input_size(); i++) {
141  string blob = g.node(child_idx).op.input(i);
142  if (output_renamings.count(blob) > 0) {
143  g.node(child_idx).op.set_input(i, output_renamings.at(blob));
144  }
145  }
146  }
147  }
148 
149  g.DeactivateSubgraph(subgraph);
150 
151  return true;
152 }
153 
154 REGISTER_TRANSFORM(
155  CommonSubexpressionElimination,
157 
158 } // namespace caffe2
bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int idx) override
The PatternRule essentially answers: Given the current subgraph (ordered), should we append the new node at idx?
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern also exists in g_op and is equal in value.
Definition: graph.cc:228
bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph) override
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr) override
The ReplaceRule actually mutates the graph, and applies the transformation upon the subgraph.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
A simple graph implementation.
Definition: Graph.h:29