Caffe2 - C++ API
A deep learning, cross-platform ML framework
pattern_net_transform.h
1 #pragma once
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/transform.h"
5 #include "caffe2/proto/caffe2_pb.h"
6 #include "caffe2/utils/proto_utils.h"
7 
8 namespace caffe2 {
9 
18 class CAFFE2_API PatternNetTransform : public Transform {
19  public:
20  PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net)
21  : p_(transform::Graph(pattern_net)), r_(transform::Graph(replace_net)) {
22  // external input and output must match!
23  CAFFE_ENFORCE(
24  p_.external_input() == r_.external_input(),
25  "External inputs do not match!");
26  CAFFE_ENFORCE(
27  p_.external_output() == r_.external_output(),
28  "External outputs do not match!");
29  ordered_ops_ = GetPatternTraversalOrder(p_);
30  inverse_ops_.resize(ordered_ops_.size());
31  for (size_t i = 0; i < ordered_ops_.size(); i++) {
32  inverse_ops_[ordered_ops_[i]] = i;
33  }
34  }
35 
36  void EnableArgumentMatching() {
37  argument_match_ = true;
38  }
39 
40  void DisableArgumentMatching() {
41  argument_match_ = false;
42  }
43 
44  protected:
58  bool PatternRule(
59  const transform::Graph& g,
60  const std::vector<int>& subgraph,
61  int idx) override;
67  bool ValidatorRule(
68  const transform::Graph& g,
69  const std::vector<int>& subgraph) override;
86  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
87  override;
88 
89  private:
107  std::vector<int> GetPatternTraversalOrder(const transform::Graph& g);
108 
109  // Graph of Pattern NetDef
110  transform::Graph p_;
111 
112  // The Traversal Order of the Pattern Net's Operators
113  // This is a permutation of the numbers from {0, ..., p.size()-1}
114  std::vector<int> ordered_ops_;
115 
116  // The Inverse of the Traversal Order of the Pattern Net's Operators
117  // That is, inverse_ops[ordered_ops[i]] == i is always true.
118  std::vector<int> inverse_ops_;
119 
120  // Graph of Replace NetDef
121  transform::Graph r_;
122 
123  // This flag determines if the transform will match operator arguments.
124  bool argument_match_ = false;
125 
126  const string TransformBlobWrapper(const string& blob_name) {
127  return "transform/" + blob_name + "_" + c10::to_string(ssa_id_);
128  }
129 
130  int ssa_id_ = 0;
131 };
132 
133 } // namespace caffe2
Graph representation of a NetDef.
Definition: graph.h:48
The Transform Base Object.
Definition: transform.h:34
PatternNetTransform allows you to create transforms using a simple interface.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13