Caffe2 - C++ API
A deep learning, cross-platform ML framework
transform.h
1 #pragma once
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/graph.h"
5 #include "caffe2/core/workspace.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 #include "caffe2/utils/proto_utils.h"
8 
9 namespace caffe2 {
10 
34 class CAFFE2_API Transform {
35  public:
36  Transform() {}
37 
42  NetDef ApplyTo(const NetDef& orig_net_def);
43 
44  virtual ~Transform() {}
45 
71  CONNECTED_SUBGRAPH,
72  SORTED_WRT_EXECUTION_ORDER,
73  GENERAL
74  };
75 
82  std::vector<std::vector<int>> PatternMatch(const transform::Graph& graph);
83 
87  void ReplacePattern(
88  const std::vector<std::vector<int>>& matches,
89  transform::Graph* graph);
90 
91  protected:
96  virtual bool PatternRule(
97  const transform::Graph& g,
98  const std::vector<int>& subgraph,
99  int /*idx*/) {
100  CAFFE_NOT_IMPLEMENTED;
101  }
102 
107  virtual bool ValidatorRule(
108  const transform::Graph& g,
109  const std::vector<int>& subgraph) {
110  CAFFE_NOT_IMPLEMENTED;
111  }
112 
117  virtual bool ReplaceRule(
118  const std::vector<int>& subgraph,
119  transform::Graph* g_ptr) {
120  CAFFE_NOT_IMPLEMENTED;
121  }
122 
123  void SetPatternMatchType(PatternMatchType type) {
124  pattern_match_type_ = type;
125  }
126 
127  private:
132  void PatternMatchHelper(
133  const transform::Graph& graph,
134  const std::vector<bool>& matched,
135  std::vector<int>* subgraph_ptr,
136  std::vector<int>* best_subgraph_ptr);
140  void TryNeighbors(
141  const transform::Graph& graph,
142  const std::map<int, std::vector<string>>& neighbors,
143  const std::vector<bool>& matched,
144  std::vector<int>* subgraph_ptr,
145  std::vector<int>* best_subgraph_ptr);
146 
147  PatternMatchType pattern_match_type_ = CONNECTED_SUBGRAPH;
148 };
149 
150 // Creates a Transform based on a key, which should be defined in registry.
151 CAFFE2_API unique_ptr<Transform> CreateTransform(string key);
152 
153 C10_DECLARE_REGISTRY(TransformRegistry, Transform);
154 #define REGISTER_TRANSFORM(name, ...) \
155  C10_REGISTER_CLASS(TransformRegistry, name, __VA_ARGS__)
156 
157 // Create a Transform object from registry,
158 // and immediately apply it to a Netdef.
159 CAFFE2_API NetDef ApplyTransform(const string& key, const NetDef& netdef);
160 
161 // Create a Transform object from registry, apply it to a NetDef.
162 // Will only return the transformed net if it is faster than the old net.
163 // This will run the init net first, will run the two nets warmup_runs times.
164 // Then, we will take the average time of main_runs runs, and only keep the
165 // transformed net if it is faster by a factor of improvement_threshold.
166 CAFFE2_API NetDef ApplyTransformIfFaster(
167  const string& key,
168  const NetDef& netdef,
169  const NetDef& init_netdef,
170  const int warmup_runs,
171  const int main_runs,
172  const double improvement_threshold);
173 
174 } // namespace
Graph representation of a NetDef.
Definition: graph.h:48
The Transform Base Object.
Definition: transform.h:34
virtual bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int)
The PatternRule essentially answers: Given the current subgraph (ordered), should we append the new node at idx?
Definition: transform.h:96
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
virtual bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr)
The ReplaceRule actually mutates the graph, and applies the transformation upon the subgraph...
Definition: transform.h:117
PatternMatchType
Determines the type of subgraphs that PatternMatch will find.
Definition: transform.h:70
virtual bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph)
The ValidatorRule essentially answers: Given a subgraph, can we accept it?
Definition: transform.h:107