1 #include "caffe2/core/graph.h" 3 #include "caffe2/core/common.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/net.h" 6 #include "caffe2/proto/caffe2_pb.h" 14 nodes_.resize(net.op_size());
17 for (
int x = 0; x < net.op_size(); x++) {
18 node(x).op = net.op(x);
23 std::unordered_map<string, int> edge_parent;
25 for (
int i = 0; i < (int)nodes_.size(); i++) {
26 for (
const string& blob : node(i).op.input()) {
27 auto it = edge_parent.find(blob);
28 if (it != edge_parent.end()) {
30 node(i).parents[j].push_back(blob);
31 node(j).children[i].push_back(blob);
33 external_input_.insert(blob);
36 for (
const string& blob : node(i).op.output()) {
37 edge_parent[blob] = i;
44 std::unordered_map<string, int> edge_child;
46 for (
int i = (
int)nodes_.size() - 1; i >= 0; i--) {
47 for (
const string& blob : node(i).op.output()) {
48 auto it = edge_child.find(blob);
49 if (it == edge_child.end()) {
50 external_output_.insert(blob);
53 for (
const string& blob : node(i).op.input()) {
60 const std::vector<int>& match) {
61 return GetSubgraphPerimeterHelper(
true, match);
65 const std::vector<int>& match) {
66 return GetSubgraphPerimeterHelper(
false, match);
74 const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper(
76 const std::vector<int>& match) {
77 std::vector<std::pair<string, int>> edge_list;
78 std::unordered_set<int> match_set(match.begin(), match.end());
79 for (
int x = 0; x < (int)nodes_.size(); x++) {
80 if (!is_node_active(x)) {
83 if (!match_set.count(x)) {
84 const auto& list = from_children ? node(x).children : node(x).parents;
85 for (
const auto& edge : list) {
86 int parent = edge.first;
87 const auto& blobs = edge.second;
88 if (match_set.count(parent)) {
89 for (
const string& blob : blobs) {
90 edge_list.push_back({blob, x});
97 std::sort(edge_list.begin(), edge_list.end());
102 std::vector<bool> visited(nodes_.size(),
false);
105 NetDef netdef = netdef_;
111 std::vector<int> unchecked_parent_count;
121 std::priority_queue<int, std::vector<int>, std::greater<int>> q;
145 for (
int i = 0; i < (int)nodes_.size(); i++) {
146 unchecked_parent_count.push_back(node(i).parents.size());
147 if (node(i).parents.size() == 0 && is_node_active(i)) {
156 if (!is_node_active(idx)) {
160 auto& op = *(netdef.add_op());
163 for (
const auto& edge : node(idx).children) {
164 int child = edge.first;
165 if (!visited[child] && is_node_active(child)) {
166 unchecked_parent_count[child]--;
167 if (unchecked_parent_count[child] == 0) {
169 visited[child] =
true;
178 for (
int idx : subgraph) {
180 for (
const auto& edge : node(idx).parents) {
181 int parent = edge.first;
182 node(parent).children.erase(idx);
184 for (
const auto& edge : node(idx).children) {
185 int child = edge.first;
186 node(child).parents.erase(idx);
189 node(idx).active =
false;
198 std::vector<string> inputs,
199 std::vector<string> outputs) {
201 auto& netdef = *netdef_ptr;
202 auto op_ptr = netdef.add_op();
204 op.set_type(op_type);
205 for (
const string& inp : inputs) {
208 for (
const string& outp : outputs) {
219 vector<string> choices = split(
'|', p);
220 for (
const string& candidate : choices) {
221 if (candidate == s) {
229 for (
const auto& p_arg : p_op.arg()) {
230 if (!p_arg.has_name()) {
234 for (
const auto& g_arg : g_op.arg()) {
235 if (p_arg.name() == g_arg.name()) {
238 if (!g_arg.has_f() || p_arg.f() != g_arg.f()) {
243 if (!g_arg.has_i() || p_arg.i() != g_arg.i()) {
248 if (!g_arg.has_s() || !
MatchStrings(p_arg.s(), g_arg.s())) {
252 if (p_arg.floats_size() != g_arg.floats_size()) {
255 for (
int i = 0; i < p_arg.floats_size(); i++) {
256 if (p_arg.floats(i) != g_arg.floats(i)) {
260 if (p_arg.ints_size() != g_arg.ints_size()) {
263 for (
int i = 0; i < p_arg.ints_size(); i++) {
264 if (p_arg.ints(i) != g_arg.ints(i)) {
268 if (p_arg.strings_size() != g_arg.strings_size()) {
271 for (
int i = 0; i < p_arg.strings_size(); i++) {
272 if (!
MatchStrings(p_arg.strings(i), g_arg.strings(i))) {
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...