1 #include "caffe2/core/transform.h" 3 #include "caffe2/core/common.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/net.h" 6 #include "caffe2/core/timer.h" 7 #include "caffe2/proto/caffe2_pb.h" 11 using transform::Graph;
// Defines the registry named TransformRegistry for Transform objects.
// CreateTransform() in this file obtains instances through
// TransformRegistry()->Create(key).
13 C10_DEFINE_REGISTRY(TransformRegistry, Transform);
17 std::vector<bool> matched(graph.size(),
false);
20 std::vector<std::vector<int>> matches;
23 for (
int idx = 0; idx < (int)graph.size(); ++idx) {
26 std::vector<int> subgraph;
30 std::vector<int> best_subgraph;
33 if (!matched.at(idx) &&
PatternRule(graph, subgraph, idx)) {
34 subgraph.push_back(idx);
35 PatternMatchHelper(graph, matched, &subgraph, &best_subgraph);
38 if (best_subgraph.size() > 0) {
39 matches.push_back(best_subgraph);
40 for (
const auto& x : best_subgraph) {
// Tries to grow the current working subgraph through the entries of
// `neighbors`. A candidate node `j` is taken into the subgraph only when it is
// not already part of it, is not flagged in `matched`, and is accepted by
// PatternRule; the search then recurses through PatternMatchHelper.
//
// NOTE(review): this extract is missing lines of the original function (the
// first parameter, the statement that defines `j`, and the closing lines).
// `j` is presumably derived from `edge`, and `graph` is presumably the missing
// first parameter - confirm against the full source.
48 void Transform::TryNeighbors(
// Map keyed by int, iterated below one `edge` at a time.
50 const std::map<
int, std::vector<string>>& neighbors,
// Per-node flags; a node with a true flag is never added to the subgraph here.
51 const std::vector<bool>& matched,
// In/out: the working subgraph that candidates are appended to.
52 std::vector<int>* subgraph_ptr,
// Forwarded unchanged to PatternMatchHelper.
53 std::vector<int>* best_subgraph_ptr) {
54 auto& subgraph = *subgraph_ptr;
55 for (
const auto& edge : neighbors) {
// Skip candidates that are already members of the working subgraph.
57 if (std::find(subgraph.begin(), subgraph.end(), j) == subgraph.end()) {
// Only unmatched nodes accepted by PatternRule may extend the subgraph.
58 if (!matched.at(j) &&
PatternRule(graph, subgraph, j)) {
59 subgraph.push_back(j);
// Continue the search with the extended subgraph.
60 PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
// Recursive step of the pattern search. It may record the working subgraph as
// the best one found so far, then tries to extend the working subgraph in a
// way that depends on pattern_match_type_ (CONNECTED_SUBGRAPH,
// SORTED_WRT_EXECUTION_ORDER or GENERAL).
//
// NOTE(review): this extract is missing many lines of the original function
// (the first parameter, the start of the "best subgraph" condition, the
// definition of `x`, the calls that receive the children/parents arguments,
// and most closing braces). Comments below describe only what is visible.
67 void Transform::PatternMatchHelper(
// Per-node flags; flagged nodes are never added to the subgraph below.
69 const std::vector<bool>& matched,
// In/out: the working subgraph that is extended during the search.
70 std::vector<int>* subgraph_ptr,
// Out: overwritten with a copy of the working subgraph when it becomes the best.
71 std::vector<int>* best_subgraph_ptr) {
73 auto& subgraph = *subgraph_ptr;
// best_subgraph_ptr is checked before it is dereferenced.
74 CHECK(best_subgraph_ptr);
75 auto& best_subgraph = *best_subgraph_ptr;
// NOTE(review): the opening of this condition is not visible; the visible part
// requires the working subgraph to be strictly larger than the current best.
80 subgraph.size() > best_subgraph.size()) {
81 best_subgraph = subgraph;
// Snapshot of the size, compared against subgraph.size() in the checks below.
84 size_t size_before = subgraph.size();
86 if (pattern_match_type_ == CONNECTED_SUBGRAPH) {
// Visit every node currently in the working subgraph.
92 for (
size_t i = 0; i < subgraph.size(); i++) {
// The children of node `x` are passed as an argument to a call that is not
// visible in this extract (presumably TryNeighbors - confirm).
96 graph.node(x).children,
101 size_before == subgraph.size(),
102 "Subgraph size should not change after returning from recursive call.");
// Same as above, using the parents of node `x`.
105 graph.node(x).parents,
110 size_before == subgraph.size(),
111 "Subgraph size should not change after returning from recursive call.");
113 }
else if (pattern_match_type_ == SORTED_WRT_EXECUTION_ORDER) {
// Candidates start right after the last node of the working subgraph, or at
// index 0 when the subgraph is empty.
122 size_t start_idx = 0;
123 if (subgraph.size() > 0) {
124 start_idx = subgraph.back() + 1;
126 for (
size_t i = start_idx; i < graph.size(); i++) {
// Only unmatched nodes accepted by PatternRule may extend the subgraph.
127 if (!matched.at(i) &&
PatternRule(graph, subgraph, i)) {
128 subgraph.push_back(i);
129 PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
133 }
else if (pattern_match_type_ == GENERAL) {
// Every node index of the graph is considered as a candidate.
139 for (
size_t i = 0; i < graph.size(); i++) {
// Skip nodes that are already members of the working subgraph.
140 if (std::find(subgraph.begin(), subgraph.end(), i) == subgraph.end()) {
142 if (!matched.at(i) &&
PatternRule(graph, subgraph, i)) {
143 subgraph.push_back(i);
144 PatternMatchHelper(graph, matched, subgraph_ptr, best_subgraph_ptr);
// NOTE(review): presumably the fallback branch for any other
// pattern_match_type_ value; the enclosing `else` is not visible here.
150 CAFFE_NOT_IMPLEMENTED;
155 const std::vector<vector<int>>& matches,
157 for (
const auto& match : matches) {
159 bool is_match_active =
true;
160 for (
int idx : match) {
161 if (!graph->is_node_active(idx)) {
162 is_match_active =
false;
167 if (is_match_active && !
ReplaceRule(match, graph)) {
168 CAFFE_THROW(
"Replace failed!");
179 return g.GetNetDef();
183 unique_ptr<Transform> CreateTransform(
string key) {
184 auto t = TransformRegistry()->Create(key);
185 CAFFE_ENFORCE(t !=
nullptr,
"Transform not found in registry: ", key);
191 NetDef ApplyTransform(
const string& key,
const NetDef& netdef) {
192 auto t = CreateTransform(key);
193 return t->ApplyTo(netdef);
// Creates a net from `netdef` in the workspace `ws` and runs it `warmup_runs`
// times followed by `main_runs` times; every run is enforced to succeed. When
// `init_netdef` contains at least one op, a net created from it is run first
// in the same workspace.
//
// NOTE(review): this extract is missing lines of the original function (the
// declaration of `ws`, the body of the external-input loop, the conditions of
// the run-count checks, and the timing/return statements). The returned double
// is presumably a time measured around the main runs - confirm against the
// full source, including whether it is a total or a per-run average.
196 double average_net_run_duration(
197 const NetDef& netdef,
198 const NetDef& init_netdef,
199 const int warmup_runs,
200 const int main_runs) {
// Run the initialization net only when one with ops was provided.
202 if (init_netdef.op_size() > 0) {
203 std::unique_ptr<NetBase> init_net(
CreateNet(init_netdef, &ws));
205 CAFFE_ENFORCE(init_net->Run(),
"Init run has failed!");
// Iterates over the external inputs of `netdef`.
// NOTE(review): the loop body is not visible in this extract.
208 for (
auto inp : netdef.external_input()) {
// The net under measurement, created in the same workspace.
212 std::unique_ptr<NetBase> net(
CreateNet(netdef, &ws));
// Message of a check on the warm-up run count; the macro call and its
// condition are not visible in this extract.
216 "Number of warm up runs should be non negative, provided ",
// Warm-up runs: each must succeed.
220 for (
int i = 0; i < warmup_runs; i++) {
221 CAFFE_ENFORCE(net->Run(),
"Warmup run ", i,
" has failed.");
// Message of a check on the main run count; the macro call and its condition
// are not visible in this extract.
226 "Number of main runs should be positive, provided ",
// Main runs: each must succeed.
230 for (
int i = 0; i < main_runs; i++) {
231 CAFFE_ENFORCE(net->Run(),
"Main run ", i,
" has failed.");
241 NetDef ApplyTransformIfFaster(
243 const NetDef& netdef,
244 const NetDef& init_netdef,
245 const int warmup_runs,
247 const double improvement_threshold) {
248 NetDef transformed_netdef = ApplyTransform(key, netdef);
249 double original_net_time =
250 average_net_run_duration(netdef, init_netdef, warmup_runs, main_runs);
251 double new_net_time = average_net_run_duration(
252 transformed_netdef, init_netdef, warmup_runs, main_runs);
253 if (original_net_time > improvement_threshold * new_net_time) {
254 return transformed_netdef;
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float MilliSeconds()
Returns the elapsed time in milliseconds.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
A simple graph implementation.
A simple timer object for measuring time.