Caffe2 - C++ API
A deep learning, cross-platform ML framework
graph.h
1 
17 #pragma once
18 
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
27 
28 namespace caffe2 {
29 
30 namespace transform {
31 
35 struct Node {
36  public:
37  // Empty constructor for resize
38  Node() {}
39 
40  // Alternate constructor
41  Node(
42  const OperatorDef& op,
43  bool active,
44  std::map<int, std::vector<string>> parents,
45  std::map<int, std::vector<string>> children)
46  : op(op), active(active), parents(parents), children(children) {}
47 
48  // The OperatorDef which this node represents.
49  OperatorDef op;
50 
51  // Keeps track of if an operator has been deleted through a transformation.
52  bool active = true;
53 
54  // Stores a pair (idx, blob_list),
55  // idx = index of the child
56  // blob_list = a list of strings, containing the blobs that connect the nodes
57  std::map<int, std::vector<string>> parents;
58  std::map<int, std::vector<string>> children;
59 };
60 
64 struct Graph {
65  public:
73  const std::vector<std::pair<string, int>> GetSubgraphInput(
74  const std::vector<int>& subgraph);
75 
83  const std::vector<std::pair<string, int>> GetSubgraphOutput(
84  const std::vector<int>& subgraph);
85 
97  explicit Graph(const NetDef& net_def);
98 
114  NetDef GetNetDef();
115 
119  void DeactivateSubgraph(std::vector<int> subgraph);
120 
121  const size_t size() const {
122  return nodes_.size();
123  }
124 
125  void push_node(const Node& new_node) {
126  return nodes_.push_back(new_node);
127  }
128 
129  void resize_nodes(size_t new_size) {
130  nodes_.resize(new_size);
131  }
132 
133  // Index safe, less verbose way to access nodes
134  inline const Node& node(size_t idx) const {
135  return nodes_.at(idx);
136  }
137 
138  inline Node& node(size_t idx) {
139  return nodes_.at(idx);
140  }
141 
142  inline bool is_node_active(size_t idx) {
143  return node(idx).active;
144  }
145 
146  inline const std::set<string>& external_input() const {
147  return external_input_;
148  }
149 
150  inline const std::set<string>& external_output() const {
151  return external_output_;
152  }
153 
154  private:
155  const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
156  bool from_children,
157  const std::vector<int>& match);
158 
159  // Stores the netdef representation. Is updated upon calls to GetNetDef.
160  NetDef netdef_;
161 
162  // Stores which blobs the graph reads from, and writes to.
163  std::set<string> external_input_;
164  std::set<string> external_output_;
165 
166  // Keeps track of all the Operators currently within graph, even if inactive.
167  std::vector<Node> nodes_;
168 };
169 
170 } // namespace transform
171 
172 // Adds an operator def to a netdef.
173 // Returns the ptr, if you want to add anything extra (such as device_option)
174 OperatorDef* AddOp(
175  NetDef* netdef_ptr,
176  string op_type,
177  std::vector<string> inputs,
178  std::vector<string> outputs);
179 
187 bool MatchStrings(string p, string s);
188 
193 bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);
194 
195 } // namespace caffe2
Graph representation of a NetDef.
Definition: graph.h:64
Graph representation of an operator.
Definition: graph.h:35
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:244
Copyright (c) 2016-present, Facebook, Inc.
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:230