Caffe2 - C++ API
A deep learning, cross-platform ML framework
pattern_net_transform.cc
1 
17 #include "caffe2/transforms/pattern_net_transform.h"
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/net.h"
22 #include "caffe2/proto/caffe2.pb.h"
23 
24 namespace caffe2 {
25 
26 // First, single source traverse through the netdef.
27 // This ensures all newly ordered are reachable from their prefix subset
28 // Outputs a permutation of the operators.
29 std::vector<int> PatternNetTransform::GetPatternTraversalOrder(
30  const transform::Graph& graph) {
31  std::vector<bool> visited(graph.size(), false);
32  std::vector<int> ordered_ops;
33  std::queue<int> q;
34  if (graph.size() > 0) {
35  q.push(0);
36  ordered_ops.push_back(0);
37  visited[0] = true;
38  }
39  while (!q.empty()) {
40  int idx = q.front();
41  q.pop();
42  for (const auto& edge : graph.node(idx).children) {
43  int x = edge.first;
44  if (!visited[x]) {
45  q.push(x);
46  ordered_ops.push_back(x);
47  visited[x] = true;
48  }
49  }
50  for (const auto& edge : graph.node(idx).parents) {
51  int x = edge.first;
52  if (!visited[x]) {
53  q.push(x);
54  ordered_ops.push_back(x);
55  visited[x] = true;
56  }
57  }
58  }
59  CAFFE_ENFORCE(
60  ordered_ops.size() == graph.size(), "Pattern graph must be connected.");
61  return ordered_ops;
62 }
63 
64 bool compare_ops(
65  const OperatorDef& p_op,
66  const OperatorDef& g_op,
67  bool arg_match) {
68  // must specify a type for pattern operators
69  CAFFE_ENFORCE(
70  p_op.has_type(), "Types must be specified for all pattern operators.");
71  if (!MatchStrings(p_op.type(), g_op.type())) {
72  return false;
73  }
74  // ensure number of inputs are the same
75  if (p_op.input().size() != g_op.input().size()) {
76  return false;
77  }
78 
79  // ensure number of outputs are the same
80  if (p_op.output().size() != g_op.output().size()) {
81  return false;
82  }
83 
84  if (p_op.has_device_option()) {
85  if (!g_op.has_device_option() ||
86  p_op.device_option().device_type() !=
87  g_op.device_option().device_type()) {
88  return false;
89  }
90  }
91 
92  // make sure engine is the same (if specified in pattern)
93  if (p_op.has_engine() && !MatchStrings(p_op.engine(), g_op.engine())) {
94  return false;
95  }
96  // If argument_match is specified, make sure those are the same.
97  if (arg_match) {
98  if (!MatchArguments(p_op, g_op)) {
99  return false;
100  }
101  }
102  return true;
103 }
104 
105 // g.node(subgraph[i]) should match p_.node(ordered_ops_[i])
106 // g.node(g_idx) should match p_.node(p_idx)
108  const transform::Graph& g,
109  const std::vector<int>& subgraph,
110  int g_idx) {
111  if (subgraph.size() >= ordered_ops_.size()) {
112  return false;
113  }
114  int p_idx = ordered_ops_[subgraph.size()];
115 
116  if (!compare_ops(p_.node(p_idx).op, g.node(g_idx).op, argument_match_)) {
117  return false;
118  }
119 
120  // Let's say ordered_ops_ is [0, 2, 1], with 0 -> 2 being an edge
121  // When we try to match onto the second element, let's say our
122  // subgraph so far is [4], with it trying to become [4, 5].
123  // Then, we need to show that since 0 -> 2 is an edge is ordered_ops_,
124  // 4 must be a direct parent of 5 in the subgraph
125  // (the indices must match).
126  // Similarly, assume there is an edge from 1 -> 2 in p_.
127  // When trying to match [4, 5] to [4, 5, 7], we must verify that
128  // there exists an edge from 7 -> 5 in G.
129  for (const auto& edge : p_.node(p_idx).parents) {
130  int parent = edge.first;
131  // g_idx doesn't have parent in subgraph that p_[p_idx] has
132  // inverse_ops_ gets the index of a p_idx inside of ordered_ops_.
133  if (inverse_ops_[parent] < subgraph.size() &&
134  g.node(g_idx).parents.count(subgraph[inverse_ops_[parent]]) == 0) {
135  return false;
136  }
137  }
138 
139  for (const auto& edge : p_.node(p_idx).children) {
140  int child = edge.first;
141  if (inverse_ops_[child] < subgraph.size() &&
142  g.node(g_idx).children.count(subgraph[inverse_ops_[child]]) == 0) {
143  return false;
144  }
145  }
146  return true;
147 }
148 
150  const transform::Graph& g,
151  const std::vector<int>& subgraph) {
152  // Due to strict PatternRule, it suffices to simply check for size
153  return subgraph.size() == p_.size();
154 }
155 
157  const std::vector<int>& match,
158  transform::Graph* g_ptr) {
159  CHECK(g_ptr);
160  auto& g = *g_ptr;
161 
162  ssa_id_++;
163 
164  // Map of PatternNet blob name to Matched blob name.
165  // Figures out how to rename the pattern_net to make the replacement fit.
166  std::unordered_map<string, string> external_renaming;
167 
168  // Figure out blob renamings
169  for (int i = 0; i < match.size(); i++) {
170  int g_idx = match[i];
171  int p_idx = ordered_ops_[i];
172  for (int j = 0; j < p_.node(p_idx).op.input().size(); j++) {
173  string p_blob = p_.node(p_idx).op.input(j);
174  string g_blob = g.node(g_idx).op.input(j);
175  if (p_.external_input().count(p_blob)) {
176  external_renaming[p_blob] = g_blob;
177  }
178  }
179  for (int j = 0; j < p_.node(p_idx).op.output().size(); j++) {
180  string p_blob = p_.node(p_idx).op.output(j);
181  string g_blob = g.node(g_idx).op.output(j);
182  if (p_.external_output().count(p_blob)) {
183  external_renaming[p_blob] = g_blob;
184  }
185  }
186  }
187 
188  auto input_list = g.GetSubgraphInput(match);
189  auto output_list = g.GetSubgraphOutput(match);
190 
191  g.DeactivateSubgraph(match);
192 
193  int offset = g.size();
194 
195  g.resize_nodes(offset + r_.size());
196 
197  // Append all the new operators.
198  for (int i = 0; i < r_.size(); i++) {
199  int new_node_idx = offset + i;
200 
201  OperatorDef new_op = r_.node(i).op;
202 
203  new_op.clear_input();
204  new_op.clear_output();
205  // Stitch Input from external graph into replaced subgraph
206  for (const auto& blob : r_.node(i).op.input()) {
207  if (external_renaming.count(blob)) {
208  string new_blob = external_renaming[blob];
209  new_op.add_input(new_blob);
210 
211  // binary searches for new_blob amongst input list.
212  auto it = std::lower_bound(
213  input_list.begin(), input_list.end(), std::make_pair(new_blob, -1));
214 
215  // if the input came from the graph (instead of G's external input)
216  for (; it < input_list.end() && it->first == new_blob; it++) {
217  int parent = it->second;
218  g.node(parent).children[new_node_idx].push_back(new_blob);
219  g.node(new_node_idx).parents[parent].push_back(new_blob);
220  }
221  } else {
222  new_op.add_input(TransformBlobWrapper(blob));
223  }
224  }
225  // Stitch Output from replaced subgraph to external graph.
226  for (const auto& blob : r_.node(i).op.output()) {
227  if (external_renaming.count(blob)) {
228  string new_blob = external_renaming[blob];
229  new_op.add_output(new_blob);
230 
231  // binary searches for new_blob amongst input list.
232  auto it = std::lower_bound(
233  output_list.begin(),
234  output_list.end(),
235  std::make_pair(new_blob, -1));
236 
237  // if the output goes to the graph (instead of G's external output)
238  for (; it < output_list.end() && it->first == new_blob; it++) {
239  int child = it->second;
240  g.node(child).parents[new_node_idx].push_back(new_blob);
241  g.node(new_node_idx).children[child].push_back(new_blob);
242  }
243  } else {
244  new_op.add_output(TransformBlobWrapper(blob));
245  }
246  }
247 
248  // Connect all internal edges within replace graph
249  for (const auto& edge : r_.node(i).parents) {
250  int parent = edge.first;
251  int new_node_parent = offset + parent;
252  const auto& blobs = edge.second;
253  for (const string& blob : blobs) {
254  g.node(new_node_idx)
255  .parents[new_node_parent]
256  .push_back(TransformBlobWrapper(blob));
257  }
258  }
259 
260  for (const auto& edge : r_.node(i).children) {
261  int child = edge.first;
262  int new_node_child = offset + child;
263  const auto& blobs = edge.second;
264  for (const string& blob : blobs) {
265  g.node(offset + i)
266  .children[new_node_child]
267  .push_back(TransformBlobWrapper(blob));
268  }
269  }
270 
271  g.node(new_node_idx).op = new_op;
272  g.node(new_node_idx).active = true;
273  }
274  return true;
275 }
276 
277 } // namespace Caffe2
bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph) override
ValidatorRule for PatternNetTransform does the following:
Graph representation of a Netdef.
Definition: graph.h:64
bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr) override
ReplaceRule for PatternNetTransform does the following:
bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int idx) override
We want the final result of subgraph to match the PatternNet in the order of ordered_ops, operator by operator.
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:244
Copyright (c) 2016-present, Facebook, Inc.
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:230