Caffe2 - C++ API
A deep learning, cross-platform ML framework
workspace.cc
1 #include "caffe2/core/workspace.h"
2 
3 #include <algorithm>
4 #include <ctime>
5 #include <mutex>
6 
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/net.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/plan_executor.h"
11 #include "caffe2/core/tensor.h"
12 #include "caffe2/proto/caffe2_pb.h"
13 
14 C10_DEFINE_bool(
15  caffe2_print_blob_sizes_at_exit,
16  false,
17  "If true, workspace destructor will print all blob shapes");
18 
19 namespace caffe2 {
20 
21 void Workspace::PrintBlobSizes() {
22  vector<string> blobs = LocalBlobs();
23  size_t cumtotal = 0;
24 
25  // First get total sizes and sort
26  vector<std::pair<size_t, std::string>> blob_sizes;
27  for (const auto& s : blobs) {
28  Blob* b = this->GetBlob(s);
29  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
30  if (shape_fun) {
31  size_t capacity;
32  DeviceOption _device;
33  auto shape = shape_fun(b->GetRaw(), &capacity, &_device);
34  // NB: currently it overcounts capacity of shared storages
35  // TODO: fix it after the storage sharing is merged
36  cumtotal += capacity;
37  blob_sizes.push_back(make_pair(capacity, s));
38  }
39  }
40  std::sort(
41  blob_sizes.begin(),
42  blob_sizes.end(),
43  [](const std::pair<size_t, std::string>& a,
44  const std::pair<size_t, std::string>& b) {
45  return b.first < a.first;
46  });
47 
48  // Then print in descending order
49  LOG(INFO) << "---- Workspace blobs: ---- ";
50  LOG(INFO) << "name;current shape;capacity bytes;percentage";
51  for (const auto& sb : blob_sizes) {
52  Blob* b = this->GetBlob(sb.second);
53  TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
54  CHECK(shape_fun != nullptr);
55  size_t capacity;
56  DeviceOption _device;
57 
58  auto shape = shape_fun(b->GetRaw(), &capacity, &_device);
59  std::stringstream ss;
60  ss << sb.second << ";";
61  for (const auto d : shape) {
62  ss << d << ",";
63  }
64  LOG(INFO) << ss.str() << ";" << sb.first << ";" << std::setprecision(3)
65  << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
66  << "%";
67  }
68  LOG(INFO) << "Total;;" << cumtotal << ";100%";
69 }
70 
71 vector<string> Workspace::LocalBlobs() const {
72  vector<string> names;
73  names.reserve(blob_map_.size());
74  for (auto& entry : blob_map_) {
75  names.push_back(entry.first);
76  }
77  return names;
78 }
79 
80 vector<string> Workspace::Blobs() const {
81  vector<string> names;
82  names.reserve(blob_map_.size());
83  for (auto& entry : blob_map_) {
84  names.push_back(entry.first);
85  }
86  for (const auto& forwarded : forwarded_blobs_) {
87  const auto parent_ws = forwarded.second.first;
88  const auto& parent_name = forwarded.second.second;
89  if (parent_ws->HasBlob(parent_name)) {
90  names.push_back(forwarded.first);
91  }
92  }
93  if (shared_) {
94  const auto& shared_blobs = shared_->Blobs();
95  names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
96  }
97  return names;
98 }
99 
100 Blob* Workspace::CreateBlob(const string& name) {
101  if (HasBlob(name)) {
102  VLOG(1) << "Blob " << name << " already exists. Skipping.";
103  } else if (forwarded_blobs_.count(name)) {
104  // possible if parent workspace deletes forwarded blob
105  VLOG(1) << "Blob " << name << " is already forwarded from parent workspace "
106  << "(blob " << forwarded_blobs_[name].second << "). Skipping.";
107  } else {
108  VLOG(1) << "Creating blob " << name;
109  blob_map_[name] = unique_ptr<Blob>(new Blob());
110  }
111  return GetBlob(name);
112 }
113 
114 Blob* Workspace::CreateLocalBlob(const string& name) {
115  if (blob_map_.count(name)) {
116  VLOG(1) << "Blob " << name << " already exists. Skipping.";
117  } else {
118  VLOG(1) << "Creating blob " << name;
119  blob_map_[name] = unique_ptr<Blob>(new Blob());
120  }
121  return GetBlob(name);
122 }
123 
124 Blob* Workspace::RenameBlob(const string& old_name, const string& new_name) {
125  // We allow renaming only local blobs for API clarity purpose
126  auto it = blob_map_.find(old_name);
127  CAFFE_ENFORCE(
128  it != blob_map_.end(),
129  "Blob ",
130  old_name,
131  " is not in the local blob list");
132 
133  // New blob can't be in any parent either, otherwise it will hide a parent
134  // blob
135  CAFFE_ENFORCE(
136  !HasBlob(new_name), "Blob ", new_name, "is already in the workspace");
137 
138  // First delete the old record
139  auto value = std::move(it->second);
140  blob_map_.erase(it);
141 
142  auto* raw_ptr = value.get();
143  blob_map_[new_name] = std::move(value);
144  return raw_ptr;
145 }
146 
147 bool Workspace::RemoveBlob(const string& name) {
148  auto it = blob_map_.find(name);
149  if (it != blob_map_.end()) {
150  VLOG(1) << "Removing blob " << name << " from this workspace.";
151  blob_map_.erase(it);
152  return true;
153  }
154 
155  // won't go into shared_ here
156  VLOG(1) << "Blob " << name << " not exists. Skipping.";
157  return false;
158 }
159 
160 const Blob* Workspace::GetBlob(const string& name) const {
161  if (blob_map_.count(name)) {
162  return blob_map_.at(name).get();
163  } else if (forwarded_blobs_.count(name)) {
164  const auto parent_ws = forwarded_blobs_.at(name).first;
165  const auto& parent_name = forwarded_blobs_.at(name).second;
166  return parent_ws->GetBlob(parent_name);
167  } else if (shared_ && shared_->HasBlob(name)) {
168  return shared_->GetBlob(name);
169  }
170  LOG(WARNING) << "Blob " << name << " not in the workspace.";
171  // TODO(Yangqing): do we want to always print out the list of blobs here?
172  // LOG(WARNING) << "Current blobs:";
173  // for (const auto& entry : blob_map_) {
174  // LOG(WARNING) << entry.first;
175  // }
176  return nullptr;
177 }
178 
180  const Workspace* parent,
181  const std::unordered_map<string, string>& forwarded_blobs,
182  bool skip_defined_blobs) {
183  CAFFE_ENFORCE(parent, "Parent workspace must be specified");
184  for (const auto& forwarded : forwarded_blobs) {
185  CAFFE_ENFORCE(
186  parent->HasBlob(forwarded.second),
187  "Invalid parent workspace blob " + forwarded.second);
188  if (forwarded_blobs_.count(forwarded.first)) {
189  const auto& ws_blob = forwarded_blobs_[forwarded.first];
190  CAFFE_ENFORCE_EQ(
191  ws_blob.first, parent, "Redefinition of blob " + forwarded.first);
192  CAFFE_ENFORCE_EQ(
193  ws_blob.second,
194  forwarded.second,
195  "Redefinition of blob " + forwarded.first);
196  } else {
197  if (skip_defined_blobs && HasBlob(forwarded.first)) {
198  continue;
199  }
200  CAFFE_ENFORCE(
201  !HasBlob(forwarded.first), "Redefinition of blob " + forwarded.first);
202  // Lazy blob resolution - store the parent workspace and
203  // blob name, blob value might change in the parent workspace
204  forwarded_blobs_[forwarded.first] =
205  std::make_pair(parent, forwarded.second);
206  }
207  }
208 }
209 
210 Blob* Workspace::GetBlob(const string& name) {
211  return const_cast<Blob*>(static_cast<const Workspace*>(this)->GetBlob(name));
212 }
213 
214 NetBase* Workspace::CreateNet(const NetDef& net_def, bool overwrite) {
215  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
216  return CreateNet(tmp_net_def, overwrite);
217 }
218 
220  const std::shared_ptr<const NetDef>& net_def,
221  bool overwrite) {
222  CAFFE_ENFORCE(net_def->has_name(), "Net definition should have a name.");
223  if (net_map_.count(net_def->name()) > 0) {
224  if (!overwrite) {
225  CAFFE_THROW(
226  "I respectfully refuse to overwrite an existing net of the same "
227  "name \"",
228  net_def->name(),
229  "\", unless you explicitly specify overwrite=true.");
230  }
231  VLOG(1) << "Deleting existing network of the same name.";
232  // Note(Yangqing): Why do we explicitly erase it here? Some components of
233  // the old network, such as an opened LevelDB, may prevent us from creating
234  // a new network before the old one is deleted. Thus we will need to first
235  // erase the old one before the new one can be constructed.
236  net_map_.erase(net_def->name());
237  }
238  // Create a new net with its name.
239  VLOG(1) << "Initializing network " << net_def->name();
240  net_map_[net_def->name()] =
241  unique_ptr<NetBase>(caffe2::CreateNet(net_def, this));
242  if (net_map_[net_def->name()].get() == nullptr) {
243  LOG(ERROR) << "Error when creating the network."
244  << "Maybe net type: [" << net_def->type() << "] does not exist";
245  net_map_.erase(net_def->name());
246  return nullptr;
247  }
248  return net_map_[net_def->name()].get();
249 }
250 
251 NetBase* Workspace::GetNet(const string& name) {
252  if (!net_map_.count(name)) {
253  return nullptr;
254  } else {
255  return net_map_[name].get();
256  }
257 }
258 
259 void Workspace::DeleteNet(const string& name) {
260  if (net_map_.count(name)) {
261  net_map_.erase(name);
262  }
263 }
264 
265 bool Workspace::RunNet(const string& name) {
266  if (!net_map_.count(name)) {
267  LOG(ERROR) << "Network " << name << " does not exist yet.";
268  return false;
269  }
270  return net_map_[name]->Run();
271 }
272 
273 bool Workspace::RunOperatorOnce(const OperatorDef& op_def) {
274  std::unique_ptr<OperatorBase> op(CreateOperator(op_def, this));
275  if (op.get() == nullptr) {
276  LOG(ERROR) << "Cannot create operator of type " << op_def.type();
277  return false;
278  }
279  if (!op->Run()) {
280  LOG(ERROR) << "Error when running operator " << op_def.type();
281  return false;
282  }
283  return true;
284 }
285 bool Workspace::RunNetOnce(const NetDef& net_def) {
286  std::unique_ptr<NetBase> net(caffe2::CreateNet(net_def, this));
287  if (net == nullptr) {
288  CAFFE_THROW(
289  "Could not create net: " + net_def.name() + " of type " +
290  net_def.type());
291  }
292  if (!net->Run()) {
293  LOG(ERROR) << "Error when running network " << net_def.name();
294  return false;
295  }
296  return true;
297 }
298 
299 bool Workspace::RunPlan(const PlanDef& plan, ShouldContinue shouldContinue) {
300  return RunPlanOnWorkspace(this, plan, shouldContinue);
301 }
302 
303 ThreadPool* Workspace::GetThreadPool() {
304  std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
305  if (!thread_pool_) {
306  thread_pool_ = ThreadPool::defaultThreadPool();
307  }
308  return thread_pool_.get();
309 }
310 
311 std::shared_ptr<Workspace::Bookkeeper> Workspace::bookkeeper() {
312  static auto shared = std::make_shared<Workspace::Bookkeeper>();
313  return shared;
314 }
315 
316 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
Definition: workspace.cc:259
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
Definition: workspace.cc:299
Blob * CreateLocalBlob(const string &name)
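Similar to CreateBlob(), but it creates a blob in the local workspace even if another blob with the same name already exists in a parent workspace; the new local blob then hides the parent's blob. If a local blob with that name already exists, it is returned unchanged.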
Definition: workspace.cc:114
bool RemoveBlob(const string &name)
Remove the blob of the given name.
Definition: workspace.cc:147
vector< string > LocalBlobs() const
Returns the list of blobs owned by this Workspace, not including blobs shared from a parent workspace.
Definition: workspace.cc:71
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:179
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151
void AddBlobMapping(const Workspace *parent, const std::unordered_map< string, string > &forwarded_blobs, bool skip_defined_blobs=false)
Adds mappings from blob names in this workspace to blobs in the parent workspace.
Definition: workspace.cc:179
Blob * RenameBlob(const string &old_name, const string &new_name)
Renames a local workspace blob.
Definition: workspace.cc:124
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
Definition: workspace.cc:265
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
Definition: workspace.cc:251
vector< string > Blobs() const
Returns a list of blob names, including blobs forwarded from a parent workspace and blobs from the shared workspace.
Definition: workspace.cc:80
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.
Definition: workspace.cc:214