Caffe2 - C++ API
A deep learning, cross platform ML framework
pattern_net_transform.cc
1 #include "caffe2/transforms/pattern_net_transform.h"
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/net.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 
8 namespace caffe2 {
9 
10 // First, single source traverse through the netdef.
11 // This ensures all newly ordered are reachable from their prefix subset
12 // Outputs a permutation of the operators.
13 std::vector<int> PatternNetTransform::GetPatternTraversalOrder(
14  const transform::Graph& graph) {
15  std::vector<bool> visited(graph.size(), false);
16  std::vector<int> ordered_ops;
17  std::queue<int> q;
18  if (graph.size() > 0) {
19  q.push(0);
20  ordered_ops.push_back(0);
21  visited[0] = true;
22  }
23  while (!q.empty()) {
24  int idx = q.front();
25  q.pop();
26  for (const auto& edge : graph.node(idx).children) {
27  int x = edge.first;
28  if (!visited[x]) {
29  q.push(x);
30  ordered_ops.push_back(x);
31  visited[x] = true;
32  }
33  }
34  for (const auto& edge : graph.node(idx).parents) {
35  int x = edge.first;
36  if (!visited[x]) {
37  q.push(x);
38  ordered_ops.push_back(x);
39  visited[x] = true;
40  }
41  }
42  }
43  CAFFE_ENFORCE(
44  ordered_ops.size() == graph.size(), "Pattern graph must be connected.");
45  return ordered_ops;
46 }
47 
48 bool compare_ops(
49  const OperatorDef& p_op,
50  const OperatorDef& g_op,
51  bool arg_match) {
52  // must specify a type for pattern operators
53  CAFFE_ENFORCE(
54  p_op.has_type(), "Types must be specified for all pattern operators.");
55  if (!MatchStrings(p_op.type(), g_op.type())) {
56  return false;
57  }
58  // ensure number of inputs are the same
59  if (p_op.input().size() != g_op.input().size()) {
60  return false;
61  }
62 
63  // ensure number of outputs are the same
64  if (p_op.output().size() != g_op.output().size()) {
65  return false;
66  }
67 
68  if (p_op.has_device_option()) {
69  if (!g_op.has_device_option() ||
70  p_op.device_option().device_type() !=
71  g_op.device_option().device_type()) {
72  return false;
73  }
74  }
75 
76  // make sure engine is the same (if specified in pattern)
77  if (p_op.has_engine() && !MatchStrings(p_op.engine(), g_op.engine())) {
78  return false;
79  }
80  // If argument_match is specified, make sure those are the same.
81  if (arg_match) {
82  if (!MatchArguments(p_op, g_op)) {
83  return false;
84  }
85  }
86  return true;
87 }
88 
// PatternRule: decides whether graph node g_idx may extend the partial match
// `subgraph` by one operator. subgraph[i] is the graph node already matched to
// pattern node ordered_ops_[i], so g_idx is compared against pattern node
// p_idx = ordered_ops_[subgraph.size()].
// Returns false when the pattern is already fully matched, when the two
// operators differ (compare_ops, honoring argument_match_), or when an edge
// between p_idx and an already-matched pattern node has no same-direction
// edge in g between g_idx and that node's matched graph node.
// Only the presence of an edge is checked, not the blob names it carries.
89 // g.node(subgraph[i]) should match p_.node(ordered_ops_[i])
90 // g.node(g_idx) should match p_.node(p_idx)
92  const transform::Graph& g,
93  const std::vector<int>& subgraph,
94  int g_idx) {
  // Every pattern operator is already matched; nothing more can be appended.
95  if (subgraph.size() >= ordered_ops_.size()) {
96  return false;
97  }
98  int p_idx = ordered_ops_[subgraph.size()];
99 
100  if (!compare_ops(p_.node(p_idx).op, g.node(g_idx).op, argument_match_)) {
101  return false;
102  }
103 
104  // Let's say ordered_ops_ is [0, 2, 1], with 0 -> 2 being an edge
105  // When we try to match onto the second element, let's say our
106  // subgraph so far is [4], with it trying to become [4, 5].
107  // Then, we need to show that since 0 -> 2 is an edge in p_,
108  // 4 must be a direct parent of 5 in the subgraph
109  // (the indices must match).
110  // Similarly, assume there is an edge from 1 -> 2 in p_.
111  // When trying to match [4, 5] to [4, 5, 7], we must verify that
112  // there exists an edge from 7 -> 5 in G.
113  for (const auto& edge : p_.node(p_idx).parents) {
114  int parent = edge.first;
115  // Fail if this pattern parent is already matched but its graph node is not a parent of g_idx.
116  // inverse_ops_ gets the index of a p_idx inside of ordered_ops_.
  // Positions below subgraph.size() are the ones matched so far.
117  if (inverse_ops_[parent] < subgraph.size() &&
118  g.node(g_idx).parents.count(subgraph[inverse_ops_[parent]]) == 0) {
119  return false;
120  }
121  }
122 
  // Symmetric check for edges from p_idx to already-matched pattern children.
123  for (const auto& edge : p_.node(p_idx).children) {
124  int child = edge.first;
125  if (inverse_ops_[child] < subgraph.size() &&
126  g.node(g_idx).children.count(subgraph[inverse_ops_[child]]) == 0) {
127  return false;
128  }
129  }
130  return true;
131 }
132 
134  const transform::Graph& /*g*/,
135  const std::vector<int>& subgraph) {
136  // Due to strict PatternRule, it suffices to simply check for size
  // PatternRule only ever appends the graph node matching the next operator
  // of ordered_ops_ (after checking its edges to all already-matched nodes),
  // so a subgraph built through PatternRule is a complete match exactly when
  // it holds as many nodes as the pattern graph p_. The graph is unused.
137  return subgraph.size() == p_.size();
138 }
139 
141  const std::vector<int>& match,
142  transform::Graph* g_ptr) {
  // ReplaceRule: replaces the matched subgraph `match` (match[i] is the graph
  // node matched to pattern node ordered_ops_[i]) with a copy of the
  // replacement graph r_ appended at the end of g. Always returns true.
  //   1. Build external_renaming: each pattern blob listed in p_'s external
  //      inputs/outputs maps to the graph blob at the same input/output
  //      position of the matched graph operator.
  //   2. Capture the subgraph's perimeter edges, then call DeactivateSubgraph
  //      on the matched nodes.
  //   3. Append one node per r_ operator at indices [offset, offset + r_.size()),
  //      renaming its blobs and wiring perimeter and internal edges.
  // Blobs not in external_renaming are passed through TransformBlobWrapper.
143  CHECK(g_ptr);
144  auto& g = *g_ptr;
145 
  // NOTE(review): presumably consumed by TransformBlobWrapper so each
  // replacement gets distinct internal blob names — confirm in the header.
146  ssa_id_++;
147 
148  // Map of PatternNet blob name to Matched blob name.
149  // Figures out how to rename the pattern_net to make the replacement fit.
150  std::unordered_map<string, string> external_renaming;
151 
152  // Figure out blob renamings
153  for (int i = 0; i < match.size(); i++) {
154  int g_idx = match[i];
155  int p_idx = ordered_ops_[i];
  // compare_ops guaranteed equal input/output counts, so index j is valid
  // on both operators.
156  for (int j = 0; j < p_.node(p_idx).op.input().size(); j++) {
157  string p_blob = p_.node(p_idx).op.input(j);
158  string g_blob = g.node(g_idx).op.input(j);
159  if (p_.external_input().count(p_blob)) {
160  external_renaming[p_blob] = g_blob;
161  }
162  }
163  for (int j = 0; j < p_.node(p_idx).op.output().size(); j++) {
164  string p_blob = p_.node(p_idx).op.output(j);
165  string g_blob = g.node(g_idx).op.output(j);
166  if (p_.external_output().count(p_blob)) {
167  external_renaming[p_blob] = g_blob;
168  }
169  }
170  }
171 
  // Lists of (blob name, graph node index) on the subgraph's perimeter.
  // NOTE(review): the std::lower_bound lookups below assume these lists are
  // sorted by (name, index) — confirm in GetSubgraphInput/GetSubgraphOutput.
172  auto input_list = g.GetSubgraphInput(match);
173  auto output_list = g.GetSubgraphOutput(match);
174 
175  g.DeactivateSubgraph(match);
176 
177  int offset = g.size();
178 
179  g.resize_nodes(offset + r_.size());
180 
181  // Append all the new operators.
182  for (int i = 0; i < r_.size(); i++) {
183  int new_node_idx = offset + i;
184 
185  OperatorDef new_op = r_.node(i).op;
186 
187  new_op.clear_input();
188  new_op.clear_output();
189  // Stitch Input from external graph into replaced subgraph
190  for (const auto& blob : r_.node(i).op.input()) {
191  if (external_renaming.count(blob)) {
192  string new_blob = external_renaming[blob];
193  new_op.add_input(new_blob);
194 
195  // binary searches for new_blob amongst input list.
  // (-1 sorts before every real node index, so this finds the first entry.)
196  auto it = std::lower_bound(
197  input_list.begin(), input_list.end(), std::make_pair(new_blob, -1));
198 
199  // if the input came from the graph (instead of G's external input)
200  for (; it < input_list.end() && it->first == new_blob; it++) {
201  int parent = it->second;
202  g.node(parent).children[new_node_idx].push_back(new_blob);
203  g.node(new_node_idx).parents[parent].push_back(new_blob);
204  }
205  } else {
206  new_op.add_input(TransformBlobWrapper(blob));
207  }
208  }
209  // Stitch Output from replaced subgraph to external graph.
210  for (const auto& blob : r_.node(i).op.output()) {
211  if (external_renaming.count(blob)) {
212  string new_blob = external_renaming[blob];
213  new_op.add_output(new_blob);
214 
215  // binary searches for new_blob amongst output list.
216  auto it = std::lower_bound(
217  output_list.begin(),
218  output_list.end(),
219  std::make_pair(new_blob, -1));
220 
221  // if the output goes to the graph (instead of G's external output)
222  for (; it < output_list.end() && it->first == new_blob; it++) {
223  int child = it->second;
224  g.node(child).parents[new_node_idx].push_back(new_blob);
225  g.node(new_node_idx).children[child].push_back(new_blob);
226  }
227  } else {
228  new_op.add_output(TransformBlobWrapper(blob));
229  }
230  }
231 
232  // Connect all internal edges within replace graph
  // NOTE(review): edge labels always use TransformBlobWrapper(blob), even when
  // the op's blob was renamed through external_renaming above; if such a blob
  // also connects two r_ nodes the label may not match the op — verify.
233  for (const auto& edge : r_.node(i).parents) {
234  int parent = edge.first;
235  int new_node_parent = offset + parent;
236  const auto& blobs = edge.second;
237  for (const string& blob : blobs) {
238  g.node(new_node_idx)
239  .parents[new_node_parent]
240  .push_back(TransformBlobWrapper(blob));
241  }
242  }
243 
244  for (const auto& edge : r_.node(i).children) {
245  int child = edge.first;
246  int new_node_child = offset + child;
247  const auto& blobs = edge.second;
248  for (const string& blob : blobs) {
  // offset + i is the same node as new_node_idx.
249  g.node(offset + i)
250  .children[new_node_child]
251  .push_back(TransformBlobWrapper(blob));
252  }
253  }
254 
255  g.node(new_node_idx).op = new_op;
256  g.node(new_node_idx).active = true;
257  }
258  return true;
259 }
260 
261 } // namespace caffe2
bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph) override
ValidatorRule for PatternNetTransform does the following:
Graph representation of a Netdef.
Definition: graph.h:48
bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr) override
ReplaceRule for PatternNet Transform does the following:
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:214
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:228
std::vector< std::shared_ptr< Module > > children() const
Returns the direct submodules of this Module.
Definition: module.cpp:224
bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int idx) override
We want the final result of subgraph to match the PatternNet in the order of ordered_ops, operator by operator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13