Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.cc
1 
17 #include "caffe2/core/net.h"
18 #include "caffe2/core/net_simple.h"
19 
20 #include <set>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "caffe2/core/operator.h"
25 #include "caffe2/core/timer.h"
26 #include "caffe2/proto/caffe2.pb.h"
27 #include "caffe2/utils/proto_utils.h"
28 
29 namespace caffe2 {
30 
31 CAFFE_DEFINE_REGISTRY(
32  NetRegistry,
33  NetBase,
34  const std::shared_ptr<const NetDef>&,
35  Workspace*);
36 
37 NetBase::NetBase(
38  const std::shared_ptr<const NetDef>& def,
39  Workspace* /* unused */)
40  : external_input_(
41  def->external_input().begin(),
42  def->external_input().end()),
43  external_output_(
44  def->external_output().begin(),
45  def->external_output().end()),
46  name_(def->name()),
47  net_def_(def) {
48  // Check that node_name is empty for all ops
49  for (const OperatorDef& op : def->op()) {
50  if (op.has_device_option()) {
51  CAFFE_ENFORCE(
52  !op.device_option().has_node_name(),
53  "node_name must be empty for all operators at execution time.");
54  }
55  }
56 
57  // Go through the operators and make sure that blobs are correctly made.
58  std::set<string> known_blobs(
59  external_input_.begin(), external_input_.end());
60  std::set<string> remaining_output(
61  external_output_.begin(), external_output_.end());
62  for (const auto& blob : known_blobs) {
63  remaining_output.erase(blob);
64  }
65  for (const OperatorDef& op : def->op()) {
66  for (const string& in : op.input()) {
67  if (!known_blobs.count(in)) {
68  if (external_input_.size()) {
69  CAFFE_THROW(
70  "op ",
71  op.type(),
72  ": Source for input ",
73  in,
74  " is unknown for net ",
75  def->name(),
76  ", operator ",
77  ProtoDebugString(op));
78  } else {
79  // If we are not declaring input and output, we will simply VLOG it
80  // for debugging purposes.
81  VLOG(1) << "op " << op.type() << ": input " << in << " is unknown.";
82  }
83  }
84  }
85  for (const string& out : op.output()) {
86  known_blobs.insert(out);
87  remaining_output.erase(out);
88  }
89  }
90  // Finally, check if all declared outputs are being created.
91  CAFFE_ENFORCE(
92  remaining_output.size() == 0,
93  "Some of the blobs are declared as output but never produced by the "
94  "net ",
95  def->name(),
96  ", the first one is ",
97  *remaining_output.begin());
98 }
99 
100 bool NetBase::RunAsync() {
101  for (auto& op : GetOperators()) {
102  op->ResetEvent();
103  }
104  return DoRunAsync();
105 }
106 
107 static NetObserverCreator GlobalNetObserverCreator = [](NetBase* net) {
108  // A no-op ObserverBase<NetBase> observer
109  return std::unique_ptr<NetObserver>(new NetObserver(net));
110 };
111 
112 void SetGlobalNetObserverCreator(NetObserverCreator creator) {
113  GlobalNetObserverCreator = creator;
114  VLOG(1) << "Have set custom GlobalNetObserverCreator";
115 }
116 
117 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
118  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
119  return CreateNet(tmp_net_def, ws);
120 }
121 
122 unique_ptr<NetBase> CreateNet(
123  const std::shared_ptr<const NetDef>& net_def,
124  Workspace* ws) {
125  // In default, we will return a simple network that just runs all operators
126  // sequentially.
127  unique_ptr<NetBase> net;
128  if (!net_def->has_type()) {
129  net = std::unique_ptr<NetBase>(new SimpleNet(net_def, ws));
130  } else {
131  net = NetRegistry()->Create(net_def->type(), net_def, ws);
132  }
133  VLOG(1) << "Adding a global observer to a net";
134  if (net) {
135  net->AttachObserver(GlobalNetObserverCreator(net.get()));
136  }
137  return net;
138 }
139 
140 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117