Caffe2 - C++ API
A deep learning, cross-platform ML framework
transform.h
1 
17 #pragma once
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/graph.h"
21 #include "caffe2/core/workspace.h"
22 #include "caffe2/proto/caffe2.pb.h"
23 #include "caffe2/utils/proto_utils.h"
24 
25 namespace caffe2 {
26 
50 class Transform {
51  public:
52  Transform() {}
53 
58  NetDef ApplyTo(const NetDef& orig_net_def);
59 
60  virtual ~Transform() {}
61 
87  CONNECTED_SUBGRAPH,
88  SORTED_WRT_EXECUTION_ORDER,
89  GENERAL
90  };
91 
98  std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);
99 
103  void ReplacePattern(
104  const std::vector<std::vector<int>>& matches,
105  transform::Graph* graph);
106 
107  protected:
112  virtual bool PatternRule(
113  const transform::Graph& g,
114  const std::vector<int>& subgraph,
115  int /*idx*/) {
116  CAFFE_NOT_IMPLEMENTED;
117  }
118 
123  virtual bool ValidatorRule(
124  const transform::Graph& g,
125  const std::vector<int>& subgraph) {
126  CAFFE_NOT_IMPLEMENTED;
127  }
128 
133  virtual bool ReplaceRule(
134  const std::vector<int>& subgraph,
135  transform::Graph* g_ptr) {
136  CAFFE_NOT_IMPLEMENTED;
137  }
138 
139  void SetPatternMatchType(PatternMatchType type) {
140  pattern_match_type_ = type;
141  }
142 
143  private:
148  void PatternMatchHelper(
149  const transform::Graph& graph,
150  const std::vector<bool>& matched,
151  std::vector<int>* subgraph_ptr,
152  std::vector<int>* best_subgraph_ptr);
156  void TryNeighbors(
157  const transform::Graph& graph,
158  const std::map<int, std::vector<string>>& neighbors,
159  const std::vector<bool>& matched,
160  std::vector<int>* subgraph_ptr,
161  std::vector<int>* best_subgraph_ptr);
162 
163  PatternMatchType pattern_match_type_ = CONNECTED_SUBGRAPH;
164 };
165 
166 // Creates a Transform based on a key, which should be defined in registry.
167 unique_ptr<Transform> CreateTransform(string key);
168 
169 CAFFE_DECLARE_REGISTRY(TransformRegistry, Transform);
170 #define REGISTER_TRANSFORM(name, ...) \
171  CAFFE_REGISTER_CLASS(TransformRegistry, name, __VA_ARGS__)
172 
173 // Create a Transform object from registry,
174 // and immediately apply it to a Netdef.
175 NetDef ApplyTransform(const string& key, const NetDef& netdef);
176 
177 // Create a Transform object from registry, apply it to a NetDef.
178 // Will only return the transformed net if it is faster than the old net.
179 // This will run the init net first, will run the two nets warmup_runs times.
180 // Then, we will take the average time of main_runs runs, and only keep the
181 // transformed net if it is faster by a factor of improvement_threshold.
182 NetDef ApplyTransformIfFaster(
183  const string& key,
184  const NetDef& netdef,
185  const NetDef& init_netdef,
186  const int warmup_runs,
187  const int main_runs,
188  const double improvement_threshold);
189 
190 } // namespace
void ReplacePattern(const std::vector< std::vector< int >> &matches, transform::Graph *graph)
Applies the replace rule onto each of the matches found.
Definition: transform.cc:170
Graph representation of a Netdef.
Definition: graph.h:64
The Transform Base Object.
Definition: transform.h:50
std::vector< std::vector< int > > PatternMatch(const transform::Graph &graph)
Generates all matches (stored as ordered subgraphs) and returns them.
Definition: transform.cc:31
virtual bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int)
The PatternRule essentially answers: Given the current subgraph (ordered), should we append the new node at idx?
Definition: transform.h:112
Copyright (c) 2016-present, Facebook, Inc.
NetDef ApplyTo(const NetDef &orig_net_def)
Apply a Transform onto a NetDef.
Definition: transform.cc:191
virtual bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr)
The ReplaceRule actually mutates the graph and applies the transformation to the subgraph.
Definition: transform.h:133
PatternMatchType
Determines the type of subgraphs that PatternMatch will find.
Definition: transform.h:86
virtual bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph)
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
Definition: transform.h:123