Caffe2 - C++ API
A deep learning, cross-platform ML framework
workspace.h
1 
17 #ifndef CAFFE2_CORE_WORKSPACE_H_
18 #define CAFFE2_CORE_WORKSPACE_H_
19 
20 #include "caffe2/core/common.h"
21 #include "caffe2/core/observer.h"
22 
23 #include <climits>
24 #include <cstddef>
25 #include <mutex>
26 #include <typeinfo>
27 #include <unordered_set>
28 #include <vector>
29 
30 #include "caffe2/core/blob.h"
31 #include "caffe2/core/registry.h"
32 #include "caffe2/core/net.h"
33 #include "caffe2/proto/caffe2.pb.h"
34 #include "caffe2/utils/signal_handler.h"
35 #include "caffe2/utils/threadpool/ThreadPool.h"
36 
37 CAFFE2_DECLARE_bool(caffe2_print_blob_sizes_at_exit);
38 
39 namespace caffe2 {
40 
41 class NetBase;
42 
43 struct StopOnSignal {
44  StopOnSignal()
45  : handler_(std::make_shared<SignalHandler>(
46  SignalHandler::Action::STOP,
47  SignalHandler::Action::STOP)) {}
48 
49  StopOnSignal(const StopOnSignal& other) : handler_(other.handler_) {}
50 
51  bool operator()(int /*iter*/) {
52  return handler_->CheckForSignals() != SignalHandler::Action::STOP;
53  }
54 
55  std::shared_ptr<SignalHandler> handler_;
56 };
57 
63 class Workspace {
64  public:
65  typedef std::function<bool(int)> ShouldContinue;
66  typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
67  typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
71  Workspace() : root_folder_("."), shared_(nullptr) {}
72 
80  explicit Workspace(const string& root_folder)
81  : root_folder_(root_folder), shared_(nullptr) {}
82 
92  explicit Workspace(const Workspace* shared)
93  : root_folder_("."), shared_(shared) {}
94 
101  const Workspace* shared,
102  const std::unordered_map<string, string>& forwarded_blobs)
103  : root_folder_("."), shared_(nullptr) {
104  CAFFE_ENFORCE(shared, "Parent workspace must be specified");
105  for (const auto& forwarded : forwarded_blobs) {
106  CAFFE_ENFORCE(
107  shared->HasBlob(forwarded.second), "Invalid parent workspace blob");
108  forwarded_blobs_[forwarded.first] =
109  std::make_pair(shared, forwarded.second);
110  }
111  }
112 
116  Workspace(const string& root_folder, Workspace* shared)
117  : root_folder_(root_folder), shared_(shared) {}
118 
119  ~Workspace() {
120  if (FLAGS_caffe2_print_blob_sizes_at_exit) {
121  PrintBlobSizes();
122  }
123  }
124 
135  void AddBlobMapping(
136  const Workspace* parent,
137  const std::unordered_map<string, string>& forwarded_blobs,
138  bool skip_defined_blobs = false);
139 
144  template <class Context>
145  void CopyForwardedTensors(const std::unordered_set<std::string>& blobs) {
146  for (const auto& blob : blobs) {
147  if (!forwarded_blobs_.count(blob)) {
148  continue;
149  }
150  const auto& ws_blob = forwarded_blobs_[blob];
151  const auto* parent_ws = ws_blob.first;
152  auto* from_blob = parent_ws->GetBlob(ws_blob.second);
153  CAFFE_ENFORCE(from_blob);
154  CAFFE_ENFORCE(
155  from_blob->template IsType<Tensor<Context>>(),
156  "Expected blob with tensor value",
157  ws_blob.second);
158  forwarded_blobs_.erase(blob);
159  auto* to_blob = CreateBlob(blob);
160  CAFFE_ENFORCE(to_blob);
161  const auto& from_tensor = from_blob->template Get<Tensor<Context>>();
162  auto* to_tensor = to_blob->template GetMutable<Tensor<Context>>();
163  to_tensor->CopyFrom(from_tensor);
164  }
165  }
166 
171  vector<string> LocalBlobs() const;
172 
178  vector<string> Blobs() const;
179 
183  const string& RootFolder() { return root_folder_; }
187  inline bool HasBlob(const string& name) const {
188  // First, check the local workspace,
189  // Then, check the forwarding map, then the parent workspace
190  if (blob_map_.count(name)) {
191  return true;
192  } else if (forwarded_blobs_.count(name)) {
193  const auto parent_ws = forwarded_blobs_.at(name).first;
194  const auto& parent_name = forwarded_blobs_.at(name).second;
195  return parent_ws->HasBlob(parent_name);
196  } else if (shared_) {
197  return shared_->HasBlob(name);
198  }
199  return false;
200  }
201 
202  void PrintBlobSizes();
203 
209  Blob* CreateBlob(const string& name);
217  Blob* CreateLocalBlob(const string& name);
223  bool RemoveBlob(const string& name);
228  const Blob* GetBlob(const string& name) const;
233  Blob* GetBlob(const string& name);
234 
240  Blob* RenameBlob(const string& old_name, const string& new_name);
241 
251  NetBase* CreateNet(const NetDef& net_def, bool overwrite = false);
253  const std::shared_ptr<const NetDef>& net_def,
254  bool overwrite = false);
259  NetBase* GetNet(const string& net_name);
263  void DeleteNet(const string& net_name);
269  bool RunNet(const string& net_name);
270 
274  vector<string> Nets() const {
275  vector<string> names;
276  for (auto& entry : net_map_) {
277  names.push_back(entry.first);
278  }
279  return names;
280  }
281 
285  bool RunPlan(const PlanDef& plan_def,
286  ShouldContinue should_continue = StopOnSignal{});
287 
288  /*
289  * Returns a CPU threadpool instace for parallel execution of
290  * work. The threadpool is created lazily; if no operators use it,
291  * then no threadpool will be created.
292  */
293  ThreadPool* GetThreadPool();
294 
295  // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference
296  // between RunNet and RunNetOnce lies in the fact that RunNet allows you to
297  // have a persistent net object, while RunNetOnce creates a net and discards
298  // it on the fly - this may make things like database read and random number
299  // generators repeat the same thing over multiple calls.
300  bool RunOperatorOnce(const OperatorDef& op_def);
301  bool RunNetOnce(const NetDef& net_def);
302 
303  public:
304  std::atomic<int> last_failed_op_net_position;
305 
306  private:
307  BlobMap blob_map_;
308  NetMap net_map_;
309  const string root_folder_;
310  const Workspace* shared_;
311  std::unordered_map<string, std::pair<const Workspace*, string>>
312  forwarded_blobs_;
313  std::unique_ptr<ThreadPool> thread_pool_;
314  std::mutex thread_pool_creation_mutex_;
315 
316  DISABLE_COPY_AND_ASSIGN(Workspace);
317 };
318 
319 } // namespace caffe2
320 
321 #endif // CAFFE2_CORE_WORKSPACE_H_
const string & RootFolder()
Return the root folder of the workspace.
Definition: workspace.h:183
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
Workspace(const Workspace *shared, const std::unordered_map< string, string > &forwarded_blobs)
Initializes a workspace with a parent workspace and a blob name remapping (new name -> parent blob name)...
Definition: workspace.h:100
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace(const Workspace *shared)
Initializes a workspace with a shared workspace.
Definition: workspace.h:92
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
Workspace(const string &root_folder)
Initializes an empty workspace with the given root folder.
Definition: workspace.h:80
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace (including forwarded blobs and blobs of the shared parent workspace).
Definition: workspace.h:187
void CopyForwardedTensors(const std::unordered_set< std::string > &blobs)
Converts previously mapped tensor blobs to local blobs, copies values from parent workspace blobs into...
Definition: workspace.h:145
vector< string > Nets() const
Returns a list of names of the currently instantiated networks.
Definition: workspace.h:274
Workspace()
Initializes an empty workspace.
Definition: workspace.h:71
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117
Workspace(const string &root_folder, Workspace *shared)
Initializes a workspace with a root folder and a shared workspace.
Definition: workspace.h:116