Caffe2 - C++ API
A deep learning, cross platform ML framework
pattern_net_transform.h
1 
17 #pragma once
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/transform.h"
21 #include "caffe2/proto/caffe2.pb.h"
22 #include "caffe2/utils/proto_utils.h"
23 
24 namespace caffe2 {
25 
35  public:
36  PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net)
37  : p_(transform::Graph(pattern_net)), r_(transform::Graph(replace_net)) {
38  // external input and output must match!
39  CAFFE_ENFORCE(
40  p_.external_input() == r_.external_input(),
41  "External inputs do not match!");
42  CAFFE_ENFORCE(
43  p_.external_output() == r_.external_output(),
44  "External outputs do not match!");
45  ordered_ops_ = GetPatternTraversalOrder(p_);
46  inverse_ops_.resize(ordered_ops_.size());
47  for (int i = 0; i < ordered_ops_.size(); i++) {
48  inverse_ops_[ordered_ops_[i]] = i;
49  }
50  }
51 
52  void EnableArgumentMatching() {
53  argument_match_ = true;
54  }
55 
56  void DisableArgumentMatching() {
57  argument_match_ = false;
58  }
59 
60  protected:
74  bool PatternRule(
75  const transform::Graph& g,
76  const std::vector<int>& subgraph,
77  int idx) override;
83  bool ValidatorRule(
84  const transform::Graph& g,
85  const std::vector<int>& subgraph) override;
102  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
103  override;
104 
105  private:
123  std::vector<int> GetPatternTraversalOrder(const transform::Graph& g);
124 
125  // Graph of Pattern NetDef
126  transform::Graph p_;
127 
128  // The Traversal Order of the Pattern Net's Operators
129  // This is a permutation of the numbers from {0, ..., p.size()-1}
130  std::vector<int> ordered_ops_;
131 
132  // The Inverse of the Traversal Order of the Pattern Net's Operators
133  // That is, inverse_ops[ordered_ops[i]] == i is always true.
134  std::vector<int> inverse_ops_;
135 
136  // Graph of Replace NetDef
137  transform::Graph r_;
138 
139  // This flag determines if the transform will match operator arguments.
140  bool argument_match_ = false;
141 
142  const string TransformBlobWrapper(const string& blob_name) {
143  return "transform/" + blob_name + "_" + caffe2::to_string(ssa_id_);
144  }
145 
146  int ssa_id_ = 0;
147 };
148 
149 } // namespace caffe2
bool ValidatorRule(const transform::Graph &g, const std::vector< int > &subgraph) override
ValidatorRule for PatternNetTransform does the following:
Graph representation of a NetDef.
Definition: graph.h:64
The Transform Base Object.
Definition: transform.h:50
bool ReplaceRule(const std::vector< int > &subgraph, transform::Graph *g_ptr) override
ReplaceRule for PatternNetTransform does the following:
PatternNetTransform allows you to create transforms using a simple interface.
bool PatternRule(const transform::Graph &g, const std::vector< int > &subgraph, int idx) override
We want the final result of subgraph to match the PatternNet in the order of ordered_ops, operator by operator.
Copyright (c) 2016-present, Facebook, Inc.