Caffe2 - C++ API
A deep learning, cross platform ML framework
transform.cc
1 
17 #include "caffe2/core/transform.h"
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/net.h"
22 #include "caffe2/core/timer.h"
23 #include "caffe2/proto/caffe2.pb.h"
24 
25 namespace caffe2 {
26 
27 using transform::Graph;
28 
29 CAFFE_DEFINE_REGISTRY(TransformRegistry, Transform);
30 
31 std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) {
32  // checks if the node at index i is matched already or not
33  std::vector<bool> matched(graph.size(), false);
34 
35  // stores matches, which are ordered subgraphs of G
36  std::vector<std::vector<int>> matches;
37 
38  // Consider every possible node as the starting point.
39  for (int idx = 0; idx < graph.size(); ++idx) {
40  // The current working subgraph. We will try to add new nodes to this,
41  // when invoking the PatternRule.
42  std::vector<int> subgraph;
43 
44  // The largest "validated" subgraph found so far.
45  // This will be mutated by PatternMatchHelper.
46  std::vector<int> best_subgraph;
47 
48  // Only begin to match if the start node is accepted.
49  if (!matched.at(idx) && PatternRule(graph, subgraph, idx)) {
50  subgraph.push_back(idx);
51  PatternMatchHelper(graph, matched, &subgraph, &best_subgraph);
52  subgraph.pop_back();
53  }
54  if (best_subgraph.size() > 0) { // match found
55  matches.push_back(best_subgraph);
56  for (const auto& x : best_subgraph) {
57  matched[x] = true;
58  }
59  }
60  }
61  return matches;
62 }
63 
64 void Transform::TryNeighbors(
65  const Graph& graph,
66  const std::map<int, std::vector<string>>& neighbors,
67  const std::vector<bool>& matched,
68  std::vector<int>* subgraph_ptr,
69  std::vector<int>* best_subgraph_ptr) {
70  auto& subgraph = *subgraph_ptr;
71  for (const auto& edge : neighbors) {
72  int j = edge.first;
73  if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) {
74  if (!matched.at(j) && PatternRule(graph, subgraph, j)) {
75  subgraph.push_back(j);
76  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
77  subgraph.pop_back();
78  }
79  }
80  }
81 }
82 
83 void Transform::PatternMatchHelper(
84  const Graph& graph,
85  const std::vector<bool>& matched,
86  std::vector<int>* subgraph_ptr,
87  std::vector<int>* best_subgraph_ptr) {
88  CHECK(subgraph_ptr);
89  auto& subgraph = *subgraph_ptr;
90  CHECK(best_subgraph_ptr);
91  auto& best_subgraph = *best_subgraph_ptr;
92 
93  // If the current subgraph is valid, and the largest we've seen so far,
94  // make it the best_subgraph.
95  if (ValidatorRule(graph, subgraph) &&
96  subgraph.size() > best_subgraph.size()) {
97  best_subgraph = subgraph;
98  }
99 
100  int size_before = subgraph.size();
101 
102  if (pattern_match_type_ == CONNECTED_SUBGRAPH) {
103  // Connected Component Order Pattern Matching
104  // We want to match subgraphs which are connected ConnectedComponents
105 
106  // Try adding each parent and child of every node in the subgraph,
107  // and see if we can accept it.
108  for (int i = 0; i < subgraph.size(); i++) {
109  int x = subgraph[i];
110  TryNeighbors(
111  graph,
112  graph.node(x).children,
113  matched,
114  subgraph_ptr,
115  best_subgraph_ptr);
116  CAFFE_ENFORCE(
117  size_before == subgraph.size(),
118  "Subgraph size should not change after returning from recursive call.");
119  TryNeighbors(
120  graph,
121  graph.node(x).parents,
122  matched,
123  subgraph_ptr,
124  best_subgraph_ptr);
125  CAFFE_ENFORCE(
126  size_before == subgraph.size(),
127  "Subgraph size should not change after returning from recursive call.");
128  }
129  } else if (pattern_match_type_ == SORTED_WRT_EXECUTION_ORDER) {
130  // Sorted Execution Order Pattern matching
131  // We want to be able to match subgraphs in sorted execution order
132 
133  // We can safely assume our subgraph is already sorted.
134  // This means, we only need to consider nodes that come after the LAST
135  // node in our current subgraph.
136  // Thus, we simply iterate over the nodes that come AFTER the last node of
137  // our current subgraph.
138  int start_idx = 0;
139  if (subgraph.size() > 0) {
140  start_idx = subgraph.back() + 1;
141  }
142  for (int i = start_idx; i < graph.size(); i++) {
143  if (!matched.at(i) && PatternRule(graph, subgraph, i)) {
144  subgraph.push_back(i);
145  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
146  subgraph.pop_back();
147  }
148  }
149  } else if (pattern_match_type_ == GENERAL) {
150  // General Pattern matching
151  // We want to be able to match any ordered subgraph
152 
153  // For every current subgraph, we consider all nodes to be
154  // the next candidate node, as long as it isn't already matched.
155  for (int i = 0; i < graph.size(); i++) {
156  if (std::find(subgraph.begin(), subgraph.end(), i) == subgraph.end()) {
157  // Then we try appending it to the subgraph.
158  if (!matched.at(i) && PatternRule(graph, subgraph, i)) {
159  subgraph.push_back(i);
160  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
161  subgraph.pop_back();
162  }
163  }
164  }
165  } else {
166  CAFFE_NOT_IMPLEMENTED;
167  }
168 }
169 
171  const std::vector<vector<int>>& matches,
172  Graph* graph) {
173  for (const auto& match : matches) {
174  // Make sure each matched node is still active (not overwritten)
175  bool is_match_active = true;
176  for (int idx : match) {
177  if (!graph->is_node_active(idx)) {
178  is_match_active = false;
179  }
180  }
181 
182  // Simply try to apply the replace rule upon every match.
183  if (is_match_active && !ReplaceRule(match, graph)) {
184  CAFFE_THROW("Replace failed!");
185  }
186  }
187 }
188 
189 // The simple interface - performs the transformation upon a NetDef, and returns
190 // the result.
191 NetDef Transform::ApplyTo(const NetDef& orig_net) {
192  Graph g(orig_net);
193  const auto matches = PatternMatch(g);
194  ReplacePattern(matches, &g);
195  return g.GetNetDef();
196 }
197 
198 // Create a Transform object
199 unique_ptr<Transform> CreateTransform(string key) {
200  auto t = TransformRegistry()->Create(key);
201  CAFFE_ENFORCE(t != nullptr, "Transform not found in registry: ", key);
202  return t;
203 }
204 
205 // Create a Transform object from registry,
206 // and immediately apply it to a Netdef.
207 NetDef ApplyTransform(const string& key, const NetDef& netdef) {
208  auto t = CreateTransform(key);
209  return t->ApplyTo(netdef);
210 }
211 
212 double average_net_run_duration(
213  const NetDef& netdef,
214  const NetDef& init_netdef,
215  const int warmup_runs,
216  const int main_runs) {
217  Workspace ws;
218  if (init_netdef.op_size() > 0) {
219  std::unique_ptr<NetBase> init_net(CreateNet(init_netdef, &ws));
220  CHECK(init_net);
221  CAFFE_ENFORCE(init_net->Run(), "Init run has failed!");
222  } else {
223  // If a proper init_net is not provided, then this is the best we can do.
224  for (auto inp : netdef.external_input()) {
225  ws.CreateBlob(inp);
226  }
227  }
228  std::unique_ptr<NetBase> net(CreateNet(netdef, &ws));
229  CHECK(net);
230  CAFFE_ENFORCE(
231  warmup_runs >= 0,
232  "Number of warm up runs should be non negative, provided ",
233  warmup_runs,
234  ".");
235 
236  for (int i = 0; i < warmup_runs; i++) {
237  CAFFE_ENFORCE(net->Run(), "Warmup run ", i, " has failed.");
238  }
239 
240  CAFFE_ENFORCE(
241  main_runs > 0,
242  "Number of main runs should be positive, provided ",
243  main_runs,
244  ".");
245  Timer timer;
246  for (int i = 0; i < main_runs; i++) {
247  CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
248  }
249  return timer.MilliSeconds();
250 }
251 
252 // Create a Transform object from registry, apply it to a NetDef.
253 // Will only return the transformed net if it is faster than the old net.
254 // This will run the init net first, will run the two nets warmup_runs times.
255 // Then, we will take the average time of main_runs runs, and only keep the
256 // transformed net if it is faster by a factor of improvement_threshold.
257 NetDef ApplyTransformIfFaster(
258  const string& key,
259  const NetDef& netdef,
260  const NetDef& init_netdef,
261  const int warmup_runs,
262  const int main_runs,
263  const double improvement_threshold) {
264  NetDef transformed_netdef = ApplyTransform(key, netdef);
265  double original_net_time =
266  average_net_run_duration(netdef, init_netdef, warmup_runs, main_runs);
267  double new_net_time = average_net_run_duration(
268  transformed_netdef, init_netdef, warmup_runs, main_runs);
269  if (original_net_time > improvement_threshold * new_net_time) {
270  return transformed_netdef;
271  }
272  return netdef;
273 }
274 
275 } // namespace Caffe2
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:120
void ReplacePattern(const std::vector< std::vector< int >> &matches, transform::Graph *graph)
Applies the replace rule onto each of the matches found.
Definition: transform.cc:170
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
std::vector< std::vector< int > > PatternMatch(const transform::Graph &graph)
Generates all matches (stored as ordered subgraphs) and returns them.
Definition: transform.cc:31
virtual bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int)
The PatternRule essentially answers: given the current (ordered) subgraph, should we append the new node to it?
Definition: transform.h:112
Copyright (c) 2016-present, Facebook, Inc.
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:48
NetDef ApplyTo(const NetDef &orig_net_def)
Apply a Transform onto a NetDef.
Definition: transform.cc:191
virtual bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr)
The ReplaceRule actually mutates the graph, and applies the transformation upon the subgraph...
Definition: transform.h:133
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117
A simple timer object for measuring time.
Definition: timer.h:32
virtual bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph)
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
Definition: transform.h:123