Caffe2 - C++ API
A cross-platform deep learning framework.
graph.h
1 #pragma once
2 
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
11 
12 namespace caffe2 {
13 
14 namespace transform {
15 
19 struct CAFFE2_API Node {
20  public:
21  // Empty constructor for resize
22  Node() {}
23 
24  // Alternate constructor
25  Node(
26  const OperatorDef& op,
27  bool active,
28  std::map<int, std::vector<string>> parents,
29  std::map<int, std::vector<string>> children)
30  : op(op), active(active), parents(parents), children(children) {}
31 
32  // The OperatorDef which this node represents.
33  OperatorDef op;
34 
35  // Keeps track of if an operator has been deleted through a transformation.
36  bool active = true;
37 
38  // Stores a pair (idx, blob_list),
39  // idx = index of the child
40  // blob_list = a list of strings, containing the blobs that connect the nodes
41  std::map<int, std::vector<string>> parents;
42  std::map<int, std::vector<string>> children;
43 };
44 
48 struct CAFFE2_API Graph {
49  public:
57  const std::vector<std::pair<string, int>> GetSubgraphInput(
58  const std::vector<int>& subgraph);
59 
67  const std::vector<std::pair<string, int>> GetSubgraphOutput(
68  const std::vector<int>& subgraph);
69 
81  explicit Graph(const NetDef& net_def);
82 
98  NetDef GetNetDef();
99 
103  void DeactivateSubgraph(std::vector<int> subgraph);
104 
105  size_t size() const {
106  return nodes_.size();
107  }
108 
109  void push_node(const Node& new_node) {
110  return nodes_.push_back(new_node);
111  }
112 
113  void resize_nodes(size_t new_size) {
114  nodes_.resize(new_size);
115  }
116 
117  // Index safe, less verbose way to access nodes
118  inline const Node& node(size_t idx) const {
119  return nodes_.at(idx);
120  }
121 
122  inline Node& node(size_t idx) {
123  return nodes_.at(idx);
124  }
125 
126  inline bool is_node_active(size_t idx) {
127  return node(idx).active;
128  }
129 
130  inline const std::set<string>& external_input() const {
131  return external_input_;
132  }
133 
134  inline const std::set<string>& external_output() const {
135  return external_output_;
136  }
137 
138  private:
139  const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
140  bool from_children,
141  const std::vector<int>& match);
142 
143  // Stores the netdef representation. Is updated upon calls to GetNetDef.
144  NetDef netdef_;
145 
146  // Stores which blobs the graph reads from, and writes to.
147  std::set<string> external_input_;
148  std::set<string> external_output_;
149 
150  // Keeps track of all the Operators currently within graph, even if inactive.
151  std::vector<Node> nodes_;
152 };
153 
154 } // namespace transform
155 
156 // Adds an operator def to a netdef.
157 // Returns the ptr, if you want to add anything extra (such as device_option)
158 CAFFE2_API OperatorDef* AddOp(
159  NetDef* netdef_ptr,
160  string op_type,
161  std::vector<string> inputs,
162  std::vector<string> outputs);
163 
171 CAFFE2_API bool MatchStrings(string p, string s);
172 
177 CAFFE2_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);
178 
179 } // namespace caffe2
Graph representation of a NetDef.
Definition: graph.h:48
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:214
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:228
Graph representation of an operator.
Definition: graph.h:19
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13