Caffe2 - C++ API
A deep learning, cross-platform ML framework
transform.cc
1 #include "caffe2/core/transform.h"
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/net.h"
6 #include "caffe2/core/timer.h"
7 #include "caffe2/proto/caffe2_pb.h"
8 
9 namespace caffe2 {
10 
11 using transform::Graph;
12 
// Defines the registry that maps string keys to Transform implementations;
// CreateTransform() in this file queries it via TransformRegistry()->Create(key).
13 C10_DEFINE_REGISTRY(TransformRegistry, Transform);
14 
15 std::vector<std::vector<int>> Transform::PatternMatch(const Graph& graph) {
16  // checks if the node at index i is matched already or not
17  std::vector<bool> matched(graph.size(), false);
18 
19  // stores matches, which are ordered subgraphs of G
20  std::vector<std::vector<int>> matches;
21 
22  // Consider every possible node as the starting point.
23  for (int idx = 0; idx < (int)graph.size(); ++idx) {
24  // The current working subgraph. We will try to add new nodes to this,
25  // when invoking the PatternRule.
26  std::vector<int> subgraph;
27 
28  // The largest "validated" subgraph found so far.
29  // This will be mutated by PatternMatchHelper.
30  std::vector<int> best_subgraph;
31 
32  // Only begin to match if the start node is accepted.
33  if (!matched.at(idx) && PatternRule(graph, subgraph, idx)) {
34  subgraph.push_back(idx);
35  PatternMatchHelper(graph, matched, &subgraph, &best_subgraph);
36  subgraph.pop_back();
37  }
38  if (best_subgraph.size() > 0) { // match found
39  matches.push_back(best_subgraph);
40  for (const auto& x : best_subgraph) {
41  matched[x] = true;
42  }
43  }
44  }
45  return matches;
46 }
47 
48 void Transform::TryNeighbors(
49  const Graph& graph,
50  const std::map<int, std::vector<string>>& neighbors,
51  const std::vector<bool>& matched,
52  std::vector<int>* subgraph_ptr,
53  std::vector<int>* best_subgraph_ptr) {
54  auto& subgraph = *subgraph_ptr;
55  for (const auto& edge : neighbors) {
56  int j = edge.first;
57  if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) {
58  if (!matched.at(j) && PatternRule(graph, subgraph, j)) {
59  subgraph.push_back(j);
60  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
61  subgraph.pop_back();
62  }
63  }
64  }
65 }
66 
67 void Transform::PatternMatchHelper(
68  const Graph& graph,
69  const std::vector<bool>& matched,
70  std::vector<int>* subgraph_ptr,
71  std::vector<int>* best_subgraph_ptr) {
72  CHECK(subgraph_ptr);
73  auto& subgraph = *subgraph_ptr;
74  CHECK(best_subgraph_ptr);
75  auto& best_subgraph = *best_subgraph_ptr;
76 
77  // If the current subgraph is valid, and the largest we've seen so far,
78  // make it the best_subgraph.
79  if (ValidatorRule(graph, subgraph) &&
80  subgraph.size() > best_subgraph.size()) {
81  best_subgraph = subgraph;
82  }
83 
84  size_t size_before = subgraph.size();
85 
86  if (pattern_match_type_ == CONNECTED_SUBGRAPH) {
87  // Connected Component Order Pattern Matching
88  // We want to match subgraphs which are connected ConnectedComponents
89 
90  // Try adding each parent and child of every node in the subgraph,
91  // and see if we can accept it.
92  for (size_t i = 0; i < subgraph.size(); i++) {
93  int x = subgraph[i];
94  TryNeighbors(
95  graph,
96  graph.node(x).children,
97  matched,
98  subgraph_ptr,
99  best_subgraph_ptr);
100  CAFFE_ENFORCE(
101  size_before == subgraph.size(),
102  "Subgraph size should not change after returning from recursive call.");
103  TryNeighbors(
104  graph,
105  graph.node(x).parents,
106  matched,
107  subgraph_ptr,
108  best_subgraph_ptr);
109  CAFFE_ENFORCE(
110  size_before == subgraph.size(),
111  "Subgraph size should not change after returning from recursive call.");
112  }
113  } else if (pattern_match_type_ == SORTED_WRT_EXECUTION_ORDER) {
114  // Sorted Execution Order Pattern matching
115  // We want to be able to match subgraphs in sorted execution order
116 
117  // We can safely assume our subgraph is already sorted.
118  // This means, we only need to consider nodes that come after the LAST
119  // node in our current subgraph.
120  // Thus, we simply iterate over the nodes that come AFTER the last node of
121  // our current subgraph.
122  size_t start_idx = 0;
123  if (subgraph.size() > 0) {
124  start_idx = subgraph.back() + 1;
125  }
126  for (size_t i = start_idx; i < graph.size(); i++) {
127  if (!matched.at(i) && PatternRule(graph, subgraph, i)) {
128  subgraph.push_back(i);
129  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
130  subgraph.pop_back();
131  }
132  }
133  } else if (pattern_match_type_ == GENERAL) {
134  // General Pattern matching
135  // We want to be able to match any ordered subgraph
136 
137  // For every current subgraph, we consider all nodes to be
138  // the next candidate node, as long as it isn't already matched.
139  for (size_t i = 0; i < graph.size(); i++) {
140  if (std::find(subgraph.begin(), subgraph.end(), i) == subgraph.end()) {
141  // Then we try appending it to the subgraph.
142  if (!matched.at(i) && PatternRule(graph, subgraph, i)) {
143  subgraph.push_back(i);
144  PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
145  subgraph.pop_back();
146  }
147  }
148  }
149  } else {
150  CAFFE_NOT_IMPLEMENTED;
151  }
152 }
153 
155  const std::vector<vector<int>>& matches,
156  Graph* graph) {
157  for (const auto& match : matches) {
158  // Make sure each matched node is still active (not overwritten)
159  bool is_match_active = true;
160  for (int idx : match) {
161  if (!graph->is_node_active(idx)) {
162  is_match_active = false;
163  }
164  }
165 
166  // Simply try to apply the replace rule upon every match.
167  if (is_match_active && !ReplaceRule(match, graph)) {
168  CAFFE_THROW("Replace failed!");
169  }
170  }
171 }
172 
173 // The simple interface - performs the transformation upon a NetDef, and returns
174 // the result.
175 NetDef Transform::ApplyTo(const NetDef& orig_net) {
176  Graph g(orig_net);
177  const auto matches = PatternMatch(g);
178  ReplacePattern(matches, &g);
179  return g.GetNetDef();
180 }
181 
182 // Create a Transform object
183 unique_ptr<Transform> CreateTransform(string key) {
184  auto t = TransformRegistry()->Create(key);
185  CAFFE_ENFORCE(t != nullptr, "Transform not found in registry: ", key);
186  return t;
187 }
188 
189 // Create a Transform object from registry,
190 // and immediately apply it to a Netdef.
191 NetDef ApplyTransform(const string& key, const NetDef& netdef) {
192  auto t = CreateTransform(key);
193  return t->ApplyTo(netdef);
194 }
195 
196 double average_net_run_duration(
197  const NetDef& netdef,
198  const NetDef& init_netdef,
199  const int warmup_runs,
200  const int main_runs) {
201  Workspace ws;
202  if (init_netdef.op_size() > 0) {
203  std::unique_ptr<NetBase> init_net(CreateNet(init_netdef, &ws));
204  CHECK(init_net);
205  CAFFE_ENFORCE(init_net->Run(), "Init run has failed!");
206  } else {
207  // If a proper init_net is not provided, then this is the best we can do.
208  for (auto inp : netdef.external_input()) {
209  ws.CreateBlob(inp);
210  }
211  }
212  std::unique_ptr<NetBase> net(CreateNet(netdef, &ws));
213  CHECK(net);
214  CAFFE_ENFORCE(
215  warmup_runs >= 0,
216  "Number of warm up runs should be non negative, provided ",
217  warmup_runs,
218  ".");
219 
220  for (int i = 0; i < warmup_runs; i++) {
221  CAFFE_ENFORCE(net->Run(), "Warmup run ", i, " has failed.");
222  }
223 
224  CAFFE_ENFORCE(
225  main_runs > 0,
226  "Number of main runs should be positive, provided ",
227  main_runs,
228  ".");
229  Timer timer;
230  for (int i = 0; i < main_runs; i++) {
231  CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
232  }
233  return timer.MilliSeconds();
234 }
235 
236 // Create a Transform object from registry, apply it to a NetDef.
237 // Will only return the transformed net if it is faster than the old net.
238 // This will run the init net first, will run the two nets warmup_runs times.
239 // Then, we will take the average time of main_runs runs, and only keep the
240 // transformed net if it is faster by a factor of improvement_threshold.
241 NetDef ApplyTransformIfFaster(
242  const string& key,
243  const NetDef& netdef,
244  const NetDef& init_netdef,
245  const int warmup_runs,
246  const int main_runs,
247  const double improvement_threshold) {
248  NetDef transformed_netdef = ApplyTransform(key, netdef);
249  double original_net_time =
250  average_net_run_duration(netdef, init_netdef, warmup_runs, main_runs);
251  double new_net_time = average_net_run_duration(
252  transformed_netdef, init_netdef, warmup_runs, main_runs);
253  if (original_net_time > improvement_threshold * new_net_time) {
254  return transformed_netdef;
255  }
256  return netdef;
257 }
258 
259 } // namespace caffe2
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
void ReplacePattern(const std::vector< std::vector< int >> &matches, transform::Graph *graph)
Applies the replace rule onto each of the matches found.
Definition: transform.cc:154
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
std::vector< std::vector< int > > PatternMatch(const transform::Graph &graph)
Generates all matches (stored as ordered subgraphs) and returns them.
Definition: transform.cc:15
virtual bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int)
The PatternRule essentially answers: Given the current subgraph (ordered), should we append the new n...
Definition: transform.h:96
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:32
NetDef ApplyTo(const NetDef &orig_net_def)
Apply a Transform onto a NetDef.
Definition: transform.cc:175
virtual bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr)
The ReplaceRule actually mutates the graph, and applies the transformation upon the subgraph...
Definition: transform.h:117
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151
A simple graph implementation.
Definition: Graph.h:29
A simple timer object for measuring time.
Definition: timer.h:16
virtual bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph)
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
Definition: transform.h:107