1 #include "caffe2/transforms/pattern_net_transform.h" 3 #include "caffe2/core/common.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/net.h" 6 #include "caffe2/proto/caffe2_pb.h" 13 std::vector<int> PatternNetTransform::GetPatternTraversalOrder(
14 const transform::Graph& graph) {
15 std::vector<bool> visited(graph.size(),
false);
16 std::vector<int> ordered_ops;
18 if (graph.size() > 0) {
20 ordered_ops.push_back(0);
26 for (
const auto& edge : graph.node(idx).children) {
30 ordered_ops.push_back(x);
34 for (
const auto& edge : graph.node(idx).parents) {
38 ordered_ops.push_back(x);
44 ordered_ops.size() == graph.size(),
"Pattern graph must be connected.");
49 const OperatorDef& p_op,
50 const OperatorDef& g_op,
54 p_op.has_type(),
"Types must be specified for all pattern operators.");
59 if (p_op.input().size() != g_op.input().size()) {
64 if (p_op.output().size() != g_op.output().size()) {
68 if (p_op.has_device_option()) {
69 if (!g_op.has_device_option() ||
70 p_op.device_option().device_type() !=
71 g_op.device_option().device_type()) {
77 if (p_op.has_engine() && !
MatchStrings(p_op.engine(), g_op.engine())) {
93 const std::vector<int>& subgraph,
95 if (subgraph.size() >= ordered_ops_.size()) {
98 int p_idx = ordered_ops_[subgraph.size()];
100 if (!compare_ops(p_.node(p_idx).op, g.node(g_idx).op, argument_match_)) {
113 for (
const auto& edge : p_.node(p_idx).parents) {
114 int parent = edge.first;
117 if (inverse_ops_[parent] < subgraph.size() &&
118 g.node(g_idx).parents.count(subgraph[inverse_ops_[parent]]) == 0) {
123 for (
const auto& edge : p_.node(p_idx).children) {
124 int child = edge.first;
125 if (inverse_ops_[child] < subgraph.size() &&
126 g.node(g_idx).children.count(subgraph[inverse_ops_[child]]) == 0) {
135 const std::vector<int>& subgraph) {
137 return subgraph.size() == p_.size();
141 const std::vector<int>& match,
150 std::unordered_map<string, string> external_renaming;
153 for (
int i = 0; i < match.size(); i++) {
154 int g_idx = match[i];
155 int p_idx = ordered_ops_[i];
156 for (
int j = 0; j < p_.node(p_idx).op.input().size(); j++) {
157 string p_blob = p_.node(p_idx).op.input(j);
158 string g_blob = g.node(g_idx).op.input(j);
159 if (p_.external_input().count(p_blob)) {
160 external_renaming[p_blob] = g_blob;
163 for (
int j = 0; j < p_.node(p_idx).op.output().size(); j++) {
164 string p_blob = p_.node(p_idx).op.output(j);
165 string g_blob = g.node(g_idx).op.output(j);
166 if (p_.external_output().count(p_blob)) {
167 external_renaming[p_blob] = g_blob;
172 auto input_list = g.GetSubgraphInput(match);
173 auto output_list = g.GetSubgraphOutput(match);
175 g.DeactivateSubgraph(match);
177 int offset = g.size();
179 g.resize_nodes(offset + r_.size());
182 for (
int i = 0; i < r_.size(); i++) {
183 int new_node_idx = offset + i;
185 OperatorDef new_op = r_.node(i).op;
187 new_op.clear_input();
188 new_op.clear_output();
190 for (
const auto& blob : r_.node(i).op.input()) {
191 if (external_renaming.count(blob)) {
192 string new_blob = external_renaming[blob];
193 new_op.add_input(new_blob);
196 auto it = std::lower_bound(
197 input_list.begin(), input_list.end(), std::make_pair(new_blob, -1));
200 for (; it < input_list.end() && it->first == new_blob; it++) {
201 int parent = it->second;
202 g.node(parent).
children[new_node_idx].push_back(new_blob);
203 g.node(new_node_idx).parents[parent].push_back(new_blob);
206 new_op.add_input(TransformBlobWrapper(blob));
210 for (
const auto& blob : r_.node(i).op.output()) {
211 if (external_renaming.count(blob)) {
212 string new_blob = external_renaming[blob];
213 new_op.add_output(new_blob);
216 auto it = std::lower_bound(
219 std::make_pair(new_blob, -1));
222 for (; it < output_list.end() && it->first == new_blob; it++) {
223 int child = it->second;
224 g.node(child).parents[new_node_idx].push_back(new_blob);
225 g.node(new_node_idx).
children[child].push_back(new_blob);
228 new_op.add_output(TransformBlobWrapper(blob));
233 for (
const auto& edge : r_.node(i).parents) {
234 int parent = edge.first;
235 int new_node_parent = offset + parent;
236 const auto& blobs = edge.second;
237 for (
const string& blob : blobs) {
239 .parents[new_node_parent]
240 .push_back(TransformBlobWrapper(blob));
244 for (
const auto& edge : r_.node(i).children) {
245 int child = edge.first;
246 int new_node_child = offset + child;
247 const auto& blobs = edge.second;
248 for (
const string& blob : blobs) {
251 .push_back(TransformBlobWrapper(blob));
255 g.node(new_node_idx).op = new_op;
256 g.node(new_node_idx).active =
true;
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
std::vector< std::shared_ptr< Module > > children() const
Returns the direct submodules of this Module.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...