Caffe2 - C++ API
A deep learning, cross platform ML framework
workspace.cc
1 
17 #include "caffe2/core/workspace.h"
18 
19 #include <algorithm>
20 #include <ctime>
21 #include <mutex>
22 
23 #include "caffe2/core/logging.h"
24 #include "caffe2/core/net.h"
25 #include "caffe2/core/operator.h"
26 #include "caffe2/core/plan_executor.h"
27 #include "caffe2/core/tensor.h"
28 #include "caffe2/proto/caffe2.pb.h"
29 
30 CAFFE2_DEFINE_bool(
31  caffe2_print_blob_sizes_at_exit,
32  false,
33  "If true, workspace destructor will print all blob shapes");
34 
35 namespace caffe2 {
36 
37 void Workspace::PrintBlobSizes() {
38  vector<string> blobs = LocalBlobs();
39  size_t cumtotal = 0;
40 
41  // First get total sizes and sort
42  vector<std::pair<size_t, std::string>> blob_sizes;
43  for (const auto& s : blobs) {
44  Blob* b = this->GetBlob(s);
45  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
46  if (shape_fun) {
47  bool shares_data = false;
48  size_t capacity;
49  DeviceOption _device;
50  auto shape = shape_fun(b->GetRaw(), &shares_data, &capacity, &_device);
51  if (shares_data) {
52  // Blobs sharing data do not actually take any memory
53  capacity = 0;
54  }
55  cumtotal += capacity;
56  blob_sizes.push_back(make_pair(capacity, s));
57  }
58  }
59  std::sort(
60  blob_sizes.begin(),
61  blob_sizes.end(),
62  [](const std::pair<size_t, std::string>& a,
63  const std::pair<size_t, std::string>& b) {
64  return b.first < a.first;
65  });
66 
67  // Then print in descending order
68  LOG(INFO) << "---- Workspace blobs: ---- ";
69  LOG(INFO) << "name;current shape;capacity bytes;percentage";
70  for (const auto& sb : blob_sizes) {
71  Blob* b = this->GetBlob(sb.second);
72  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
73  CHECK(shape_fun != nullptr);
74  bool _shares_data = false;
75  size_t capacity;
76  DeviceOption _device;
77 
78  auto shape = shape_fun(b->GetRaw(), &_shares_data, &capacity, &_device);
79  std::stringstream ss;
80  ss << sb.second << ";";
81  for (const auto d : shape) {
82  ss << d << ",";
83  }
84  LOG(INFO) << ss.str() << ";" << sb.first << ";" << std::setprecision(3)
85  << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
86  << "%";
87  }
88  LOG(INFO) << "Total;;" << cumtotal << ";100%";
89 }
90 
91 vector<string> Workspace::LocalBlobs() const {
92  vector<string> names;
93  names.reserve(blob_map_.size());
94  for (auto& entry : blob_map_) {
95  names.push_back(entry.first);
96  }
97  return names;
98 }
99 
100 vector<string> Workspace::Blobs() const {
101  vector<string> names;
102  names.reserve(blob_map_.size());
103  for (auto& entry : blob_map_) {
104  names.push_back(entry.first);
105  }
106  for (const auto& forwarded : forwarded_blobs_) {
107  const auto parent_ws = forwarded.second.first;
108  const auto& parent_name = forwarded.second.second;
109  if (parent_ws->HasBlob(parent_name)) {
110  names.push_back(forwarded.first);
111  }
112  }
113  if (shared_) {
114  const auto& shared_blobs = shared_->Blobs();
115  names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
116  }
117  return names;
118 }
119 
120 Blob* Workspace::CreateBlob(const string& name) {
121  if (HasBlob(name)) {
122  VLOG(1) << "Blob " << name << " already exists. Skipping.";
123  } else if (forwarded_blobs_.count(name)) {
124  // possible if parent workspace deletes forwarded blob
125  VLOG(1) << "Blob " << name << " is already forwarded from parent workspace "
126  << "(blob " << forwarded_blobs_[name].second << "). Skipping.";
127  } else {
128  VLOG(1) << "Creating blob " << name;
129  blob_map_[name] = unique_ptr<Blob>(new Blob());
130  }
131  return GetBlob(name);
132 }
133 
134 Blob* Workspace::CreateLocalBlob(const string& name) {
135  if (blob_map_.count(name)) {
136  VLOG(1) << "Blob " << name << " already exists. Skipping.";
137  } else {
138  VLOG(1) << "Creating blob " << name;
139  blob_map_[name] = unique_ptr<Blob>(new Blob());
140  }
141  return GetBlob(name);
142 }
143 
144 Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
145  // We allow renaming only local blobs for API clarity purpose
146  auto it = blob_map_.find(old_name);
147  CAFFE_ENFORCE(
148  it != blob_map_.end(),
149  "Blob ",
150  old_name,
151  " is not in the local blob list");
152 
153  // New blob can't be in any parent either, otherwise it will hide a parent
154  // blob
155  CAFFE_ENFORCE(
156  !HasBlob(new_name), "Blob ", new_name, "is already in the workspace");
157 
158  // First delete the old record
159  auto value = std::move(it->second);
160  blob_map_.erase(it);
161 
162  auto* raw_ptr = value.get();
163  blob_map_[new_name] = std::move(value);
164  return raw_ptr;
165 }
166 
167 bool Workspace::RemoveBlob(const string& name) {
168  auto it = blob_map_.find(name);
169  if (it != blob_map_.end()) {
170  VLOG(1) << "Removing blob " << name << " from this workspace.";
171  blob_map_.erase(it);
172  return true;
173  }
174 
175  // won't go into shared_ here
176  VLOG(1) << "Blob " << name << " not exists. Skipping.";
177  return false;
178 }
179 
180 const Blob* Workspace::GetBlob(const string& name) const {
181  if (blob_map_.count(name)) {
182  return blob_map_.at(name).get();
183  } else if (forwarded_blobs_.count(name)) {
184  const auto parent_ws = forwarded_blobs_.at(name).first;
185  const auto& parent_name = forwarded_blobs_.at(name).second;
186  return parent_ws->GetBlob(parent_name);
187  } else if (shared_ && shared_->HasBlob(name)) {
188  return shared_->GetBlob(name);
189  }
190  LOG(WARNING) << "Blob " << name << " not in the workspace.";
191  // TODO(Yangqing): do we want to always print out the list of blobs here?
192  // LOG(WARNING) << "Current blobs:";
193  // for (const auto& entry : blob_map_) {
194  // LOG(WARNING) << entry.first;
195  // }
196  return nullptr;
197 }
198 
200  const Workspace* parent,
201  const std::unordered_map<string, string>& forwarded_blobs,
202  bool skip_defined_blobs) {
203  CAFFE_ENFORCE(parent, "Parent workspace must be specified");
204  for (const auto& forwarded : forwarded_blobs) {
205  CAFFE_ENFORCE(
206  parent->HasBlob(forwarded.second),
207  "Invalid parent workspace blob " + forwarded.second);
208  if (forwarded_blobs_.count(forwarded.first)) {
209  const auto& ws_blob = forwarded_blobs_[forwarded.first];
210  CAFFE_ENFORCE_EQ(
211  ws_blob.first, parent, "Redefinition of blob " + forwarded.first);
212  CAFFE_ENFORCE_EQ(
213  ws_blob.second,
214  forwarded.second,
215  "Redefinition of blob " + forwarded.first);
216  } else {
217  if (skip_defined_blobs && HasBlob(forwarded.first)) {
218  continue;
219  }
220  CAFFE_ENFORCE(
221  !HasBlob(forwarded.first), "Redefinition of blob " + forwarded.first);
222  // Lazy blob resolution - store the parent workspace and
223  // blob name, blob value might change in the parent workspace
224  forwarded_blobs_[forwarded.first] =
225  std::make_pair(parent, forwarded.second);
226  }
227  }
228 }
229 
230 Blob* Workspace::GetBlob(const string& name) {
231  return const_cast<Blob*>(static_cast<const Workspace*>(this)->GetBlob(name));
232 }
233 
234 NetBase* Workspace::CreateNet(const NetDef& net_def, bool overwrite) {
235  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
236  return CreateNet(tmp_net_def, overwrite);
237 }
238 
240  const std::shared_ptr<const NetDef>& net_def,
241  bool overwrite) {
242  CAFFE_ENFORCE(net_def->has_name(), "Net definition should have a name.");
243  if (net_map_.count(net_def->name()) > 0) {
244  if (!overwrite) {
245  CAFFE_THROW(
246  "I respectfully refuse to overwrite an existing net of the same "
247  "name \"",
248  net_def->name(),
249  "\", unless you explicitly specify overwrite=true.");
250  }
251  VLOG(1) << "Deleting existing network of the same name.";
252  // Note(Yangqing): Why do we explicitly erase it here? Some components of
253  // the old network, such as an opened LevelDB, may prevent us from creating
254  // a new network before the old one is deleted. Thus we will need to first
255  // erase the old one before the new one can be constructed.
256  net_map_.erase(net_def->name());
257  }
258  // Create a new net with its name.
259  VLOG(1) << "Initializing network " << net_def->name();
260  net_map_[net_def->name()] =
261  unique_ptr<NetBase>(caffe2::CreateNet(net_def, this));
262  if (net_map_[net_def->name()].get() == nullptr) {
263  LOG(ERROR) << "Error when creating the network."
264  << "Maybe net type: [" << net_def->type() << "] does not exist";
265  net_map_.erase(net_def->name());
266  return nullptr;
267  }
268  return net_map_[net_def->name()].get();
269 }
270 
271 NetBase* Workspace::GetNet(const string& name) {
272  if (!net_map_.count(name)) {
273  return nullptr;
274  } else {
275  return net_map_[name].get();
276  }
277 }
278 
279 void Workspace::DeleteNet(const string& name) {
280  if (net_map_.count(name)) {
281  net_map_.erase(name);
282  }
283 }
284 
285 bool Workspace::RunNet(const string& name) {
286  if (!net_map_.count(name)) {
287  LOG(ERROR) << "Network " << name << " does not exist yet.";
288  return false;
289  }
290  return net_map_[name]->Run();
291 }
292 
293 bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
294  std::unique_ptr<OperatorBase> op(CreateOperator(op_def, this));
295  if (op.get() == nullptr) {
296  LOG(ERROR) << "Cannot create operator of type " << op_def.type();
297  return false;
298  }
299  if (!op->Run()) {
300  LOG(ERROR) << "Error when running operator " << op_def.type();
301  return false;
302  }
303  return true;
304 }
305 bool Workspace::RunNetOnce(const NetDef& net_def) {
306  std::unique_ptr<NetBase> net(caffe2::CreateNet(net_def, this));
307  if (net == nullptr) {
308  CAFFE_THROW(
309  "Could not create net: " + net_def.name() + " of type " +
310  net_def.type());
311  }
312  if (!net->Run()) {
313  LOG(ERROR) << "Error when running network " << net_def.name();
314  return false;
315  }
316  return true;
317 }
318 
319 bool Workspace::RunPlan(const PlanDef& plan, ShouldContinue shouldContinue) {
320  return RunPlanOnWorkspace(this, plan, shouldContinue);
321 }
322 
323 ThreadPool* Workspace::GetThreadPool() {
324  std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
325  if (!thread_pool_) {
326  thread_pool_ = ThreadPool::defaultThreadPool();
327  }
328  return thread_pool_.get();
329 }
330 
331 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:120
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
Definition: workspace.cc:279
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
Definition: workspace.cc:319
Blob * CreateLocalBlob(const string &name)
Similar to CreateBlob(), but it creates a blob in the local workspace even if another blob with the s...
Definition: workspace.cc:134
bool RemoveBlob(const string &name)
Remove the blob of the given name.
Definition: workspace.cc:167
vector< string > LocalBlobs() const
Return list of blobs owned by this Workspace, not including blobs shared from parent workspace...
Definition: workspace.cc:91
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:180
Copyright (c) 2016-present, Facebook, Inc.
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:187
void AddBlobMapping(const Workspace *parent, const std::unordered_map< string, string > &forwarded_blobs, bool skip_defined_blobs=false)
Adds blob mappings from workspace to the blobs from parent workspace.
Definition: workspace.cc:199
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117
Blob * RenameBlob(const string &old_name, const string &new_name)
Renames a local workspace blob.
Definition: workspace.cc:144
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
Definition: workspace.cc:285
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
Definition: workspace.cc:271
vector< string > Blobs() const
Return a list of blob names.
Definition: workspace.cc:100
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.
Definition: workspace.cc:234