Caffe2 - C++ API
A deep learning, cross platform ML framework
common_subexpression_elimination.cc
1 
17 #include "caffe2/transforms/common_subexpression_elimination.h"
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/net.h"
21 #include "caffe2/proto/caffe2.pb.h"
22 
23 namespace caffe2 {
24 
25 using transform::Graph;
26 using transform::Node;
27 
28 // Checks if the node at model_idx and the node at candidate_idx are
29 // "common subexpressions". That is, do they have the same function, and
30 // take in the exact same input. If so, then their function is duplicated.
31 bool are_nodes_common(const Graph& g, int model_idx, int candidate_idx) {
32  // We need the candidate operator to match this model_op.
33  const Node& model_node = g.node(model_idx);
34  const Node& candidate_node = g.node(candidate_idx);
35 
36  // Types need to match.
37  if (model_node.op.type() != candidate_node.op.type()) {
38  return false;
39  }
40  // Arguments need to match.
41  if (!MatchArguments(model_node.op, candidate_node.op)) {
42  return false;
43  }
44  // Inputs need to match.
45  if (model_node.op.input_size() != candidate_node.op.input_size()) {
46  return false;
47  }
48  // If any input_blob name is different, this is not okay.
49  for (int i = 0; i < model_node.op.input_size(); i++) {
50  if (candidate_node.op.input(i) != model_node.op.input(i)) {
51  return false;
52  }
53  }
54  // Now, we also need to check that each blob comes from the same parent, or
55  // if they are external (isn't in parents). This is equivalent to a
56  // map equality (since parent edges can only contain up to one blob).
57  if (model_node.parents.size() != candidate_node.parents.size() ||
58  !std::equal(
59  model_node.parents.begin(),
60  model_node.parents.end(),
61  candidate_node.parents.begin())) {
62  return false;
63  }
64 
65  // Output size have to match too.
66  if (model_node.op.output_size() != candidate_node.op.output_size()) {
67  return false;
68  }
69  return true;
70 }
71 
73  const Graph& g,
74  const std::vector<int>& subgraph,
75  int idx) {
76  if (subgraph.size() == 0) {
77  if (IsWhitelisted(g.node(idx).op.type()))
78  return true;
79  return false;
80  }
81  return are_nodes_common(g, subgraph.at(0), idx);
82 }
83 
84 // As long as we have matched more than 2 ops, it is worth eliminating.
86  const Graph& g,
87  const std::vector<int>& subgraph) {
88  if (subgraph.size() >= 2) {
89  return true;
90  }
91  return false;
92 }
93 
95  const std::vector<int>& subgraph,
96  Graph* g_ptr) {
97  CHECK(g_ptr);
98  auto& g = *g_ptr;
99 
100  // We're gonna make a new node, with the same input as all of the ones in
101  // subgraph, but with their combined children.
102  int new_idx = g.size();
103  OperatorDef new_op = g.node(subgraph[0]).op;
104  // We will need to rename the output blobs.
105  new_op.clear_output();
106  for (const auto& blob : g.node(subgraph[0]).op.output()) {
107  new_op.add_output("transform/" + blob);
108  }
109 
110  // Need to set up the parents.
111  const auto& new_op_parents = g.node(subgraph[0]).parents;
112 
113  for (auto& parent : new_op_parents) {
114  int parent_idx = parent.first;
115 
116  // Make the parents acknowledge us as its new child.
117  g.node(parent_idx).children[new_idx] = new_op_parents.at(parent_idx);
118 
119  // Make the parents disown all our outdated siblings.
120  for (int i = 0; i < subgraph.size(); i++) {
121  g.node(parent_idx).children.erase(subgraph[i]);
122  }
123  }
124 
125  // Add the node now.
126  g.push_node(
127  Node(new_op, true, new_op_parents, std::map<int, std::vector<string>>()));
128 
129  // Now, we need to populate the child edges.
130  for (const int x : subgraph) {
131  // Figure out what the subgraph's node's blobs correspond to in new_op
132  // This is easy, since their indices match.
133  std::map<string, string> output_renamings;
134  for (int i = 0; i < new_op.output_size(); i++) {
135  output_renamings[g.node(x).op.output(i)] = g.node(new_idx).op.output(i);
136  }
137 
138  // Now, time to add the old node's children to new_op
139  for (auto& child : g.node(x).children) {
140  int child_idx = child.first;
141  std::vector<string> blobs = child.second;
142 
143  // rename the old blobs, and use them for our new edge.
144  for (string& blob : blobs) {
145  blob = output_renamings.at(blob);
146  }
147 
148  // create this new edge
149  g.node(new_idx).children[child_idx] = blobs;
150  g.node(child_idx).parents[new_idx] = blobs;
151 
152  // delete the old edge
153  g.node(child_idx).parents.erase(x);
154 
155  // need to rename the inputs of the children too.
156  for (int i = 0; i < g.node(child_idx).op.input_size(); i++) {
157  string blob = g.node(child_idx).op.input(i);
158  if (output_renamings.count(blob) > 0) {
159  g.node(child_idx).op.set_input(i, output_renamings.at(blob));
160  }
161  }
162  }
163  }
164 
165  g.DeactivateSubgraph(subgraph);
166 
167  return true;
168 }
169 
170 REGISTER_TRANSFORM(
171  CommonSubexpressionElimination,
173 
174 } // namespace caffe2
bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int idx) override
The PatternRule essentially answers: Given the current subgraph (ordered), should we append the new node at idx?
bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph) override
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr) override
The ReplaceRule actually mutates the graph, and applies the transformation upon the subgraph.
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op and is equal in value. The check is one-directional: args that exist only in g_op are not examined.
Definition: graph.cc:244
Copyright (c) 2016-present, Facebook, Inc.