// Workspace implementation header region: the first line fuses the #include
// directives with the start of a flag definition; the second line holds the
// flag's description string.
// NOTE(review): the macro invocation that defines caffe2_print_blob_sizes_at_exit
// is not visible in this listing.
1 #include "caffe2/core/workspace.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/net.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/plan_executor.h" 11 #include "caffe2/core/tensor.h" 12 #include "caffe2/proto/caffe2_pb.h" 15 caffe2_print_blob_sizes_at_exit,
17 "If true, workspace destructor will print all blob shapes");
// Workspace::PrintBlobSizes
// Logs a semicolon-separated table of blobs: a header row
// ("name;current shape;capacity bytes;percentage"), one row per listed blob,
// and a "Total" footer. Rows are ordered by capacity, largest first.
// NOTE(review): several lines of this function (including the declarations of
// `blobs`, `b` in the first loop, `cumtotal`, `capacity`, `_device` and `ss`)
// are not visible in this listing.
21 void Workspace::PrintBlobSizes() {
26 vector<std::pair<size_t, std::string>> blob_sizes;
// First pass: collect (capacity, name) pairs for the names in `blobs`.
27 for (
const auto& s : blobs) {
29 TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
33 auto shape = shape_fun(b->GetRaw(), &capacity, &_device);
37 blob_sizes.push_back(make_pair(capacity, s));
// Sort comparator: the entry with the larger capacity comes first.
43 [](
const std::pair<size_t, std::string>& a,
44 const std::pair<size_t, std::string>& b) {
45 return b.first < a.first;
49 LOG(INFO) <<
"---- Workspace blobs: ---- ";
50 LOG(INFO) <<
"name;current shape;capacity bytes;percentage";
// Second pass: the shape is queried again for each listed blob; CHECK fails if
// the blob's type has no registered TensorInfoCall.
51 for (
const auto& sb : blob_sizes) {
52 Blob* b = this->
GetBlob(sb.second);
53 TensorInfoCall shape_fun = GetTensorInfoFunction(b->meta().id());
54 CHECK(shape_fun !=
nullptr);
58 auto shape = shape_fun(b->GetRaw(), &capacity, &_device);
60 ss << sb.second <<
";";
61 for (
const auto d : shape) {
// Row: the name/shape prefix accumulated in ss, then the capacity and its
// percentage of cumtotal (0.0 when cumtotal is 0, avoiding a division by zero).
64 LOG(INFO) << ss.str() <<
";" << sb.first <<
";" << std::setprecision(3)
65 << (cumtotal > 0 ? 100.0 * double(sb.first) / cumtotal : 0.0)
// Footer row with the accumulated total.
68 LOG(INFO) <<
"Total;;" << cumtotal <<
";100%";
// Presumably Workspace::LocalBlobs() (signature not visible in this listing):
// collects the names of blobs held directly in this workspace's blob_map_.
73 names.reserve(blob_map_.size());
74 for (
auto& entry : blob_map_) {
75 names.push_back(entry.first);
// Presumably Workspace::Blobs() (signature not visible in this listing):
// collects local blob names, then names forwarded from parent workspaces, then
// the names reported by shared_.
82 names.reserve(blob_map_.size());
83 for (
auto& entry : blob_map_) {
84 names.push_back(entry.first);
// A forwarded name is listed only while the parent workspace still has the
// corresponding parent-side blob.
86 for (
const auto& forwarded : forwarded_blobs_) {
87 const auto parent_ws = forwarded.second.first;
88 const auto& parent_name = forwarded.second.second;
89 if (parent_ws->HasBlob(parent_name)) {
90 names.push_back(forwarded.first);
// NOTE(review): any null check on shared_ guarding this call is not visible here.
94 const auto& shared_blobs = shared_->
Blobs();
95 names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
// Presumably Workspace::CreateBlob(name) (head not visible in this listing).
// The visible branches skip creation when the blob already exists (condition
// line not visible) or when the name is forwarded from a parent workspace.
// Otherwise a new Blob is stored in blob_map_ under `name`.
102 VLOG(1) <<
"Blob " << name <<
" already exists. Skipping.";
103 }
else if (forwarded_blobs_.count(name)) {
105 VLOG(1) <<
"Blob " << name <<
" is already forwarded from parent workspace " 106 <<
"(blob " << forwarded_blobs_[name].second <<
"). Skipping.";
108 VLOG(1) <<
"Creating blob " << name;
109 blob_map_[name] = unique_ptr<Blob>(
new Blob());
// Presumably Workspace::CreateLocalBlob(name) (head not visible in this listing).
// In the visible lines only the local blob_map_ is consulted, unlike
// CreateBlob's forwarded-name check. A same-named blob reachable through a
// parent workspace therefore does not prevent creating a local one.
115 if (blob_map_.count(name)) {
116 VLOG(1) <<
"Blob " << name <<
" already exists. Skipping.";
118 VLOG(1) <<
"Creating blob " << name;
119 blob_map_[name] = unique_ptr<Blob>(
new Blob());
// Workspace::RenameBlob (fragment: the signature, the enforce-macro heads and
// some trailing lines are not visible in this listing).
// Only blobs held in the local blob_map_ can be renamed, and new_name must not
// already be visible in this workspace. The owning unique_ptr is moved under
// the new key, so the Blob object (and raw_ptr) keeps the same address.
126 auto it = blob_map_.find(old_name);
128 it != blob_map_.end(),
131 " is not in the local blob list");
// Leading space so the message reads "Blob <new_name> is already in the
// workspace" instead of gluing the name to "is".
136 !
HasBlob(new_name),
"Blob ", new_name,
" is already in the workspace");
139 auto value = std::move(it->second);
142 auto* raw_ptr = value.get();
143 blob_map_[new_name] = std::move(value);
// Presumably Workspace::RemoveBlob(name) (head not visible in this listing).
// Only the local blob_map_ is searched; forwarded and shared names are not
// considered. A found blob is logged and removed (the removal lines are not
// visible); otherwise the call logs and skips.
148 auto it = blob_map_.find(name);
149 if (it != blob_map_.end()) {
150 VLOG(1) <<
"Removing blob " << name <<
" from this workspace.";
156 VLOG(1) <<
"Blob " << name <<
" not exists. Skipping.";
// Workspace::GetBlob lookup order:
//   1. the local blob_map_;
//   2. blobs forwarded from a parent workspace, resolved by calling the
//      parent's GetBlob with the parent-side name;
//   3. shared_, when it is set and has the name.
// Presumably a warning is logged when none of these has the name.
// NOTE(review): the body of the shared_ branch and the final return are not
// visible in this listing.
161 if (blob_map_.count(name)) {
162 return blob_map_.at(name).get();
163 }
else if (forwarded_blobs_.count(name)) {
164 const auto parent_ws = forwarded_blobs_.at(name).first;
165 const auto& parent_name = forwarded_blobs_.at(name).second;
166 return parent_ws->GetBlob(parent_name);
167 }
else if (shared_ && shared_->
HasBlob(name)) {
170 LOG(WARNING) <<
"Blob " << name <<
" not in the workspace.";
// Tail of Workspace::AddBlobMapping's parameter list; the function name and
// the `parent` parameter are on lines not visible in this listing.
// For each (local name -> parent name) pair:
//   - the parent must have the parent-side blob;
//   - an existing forward for the same local name is checked against `parent`
//     (and, presumably, the parent-side name);
//   - with skip_defined_blobs, names already visible here take a branch whose
//     body is not visible; otherwise the local name must not already exist.
// The (parent, parent name) pair is then recorded in forwarded_blobs_.
181 const std::unordered_map<string, string>& forwarded_blobs,
182 bool skip_defined_blobs) {
183 CAFFE_ENFORCE(parent,
"Parent workspace must be specified");
184 for (
const auto& forwarded : forwarded_blobs) {
186 parent->
HasBlob(forwarded.second),
187 "Invalid parent workspace blob " + forwarded.second);
188 if (forwarded_blobs_.count(forwarded.first)) {
189 const auto& ws_blob = forwarded_blobs_[forwarded.first];
191 ws_blob.first, parent,
"Redefinition of blob " + forwarded.first);
195 "Redefinition of blob " + forwarded.first);
197 if (skip_defined_blobs &&
HasBlob(forwarded.first)) {
201 !
HasBlob(forwarded.first),
"Redefinition of blob " + forwarded.first);
204 forwarded_blobs_[forwarded.first] =
205 std::make_pair(parent, forwarded.second);
// Workspace::CreateNet(const NetDef&, bool overwrite) (head not visible in
// this listing): copies net_def into a shared_ptr and delegates to the
// shared_ptr overload with the same overwrite flag.
215 std::shared_ptr<NetDef> tmp_net_def(
new NetDef(net_def));
216 return CreateNet(tmp_net_def, overwrite);
// Workspace::CreateNet(const std::shared_ptr<const NetDef>&, bool overwrite)
// (fragment: the function head and several lines are not visible here).
// Requires net_def to have a name. If a net with that name already exists,
// the overwrite check applies (its enforce head is not visible) and the old
// net is erased before the new one is created. A null creation result is
// logged and its map entry is erased.
220 const std::shared_ptr<const NetDef>& net_def,
222 CAFFE_ENFORCE(net_def->has_name(),
"Net definition should have a name.");
223 if (net_map_.count(net_def->name()) > 0) {
226 "I respectfully refuse to overwrite an existing net of the same " 229 "\", unless you explicitly specify overwrite=true.");
231 VLOG(1) <<
"Deleting existing network of the same name.";
236 net_map_.erase(net_def->name());
239 VLOG(1) <<
"Initializing network " << net_def->name();
240 net_map_[net_def->name()] =
242 if (net_map_[net_def->name()].get() ==
nullptr) {
// Trailing space keeps the two streamed fragments from running together
// ("network.Maybe").
243 LOG(ERROR) <<
"Error when creating the network. " 244 <<
"Maybe net type: [" << net_def->type() <<
"] does not exist";
245 net_map_.erase(net_def->name());
248 return net_map_[net_def->name()].get();
// Presumably Workspace::GetNet(name) (head not visible in this listing).
// The branch for a missing name is not visible. Otherwise the registered net
// is returned as a non-owning raw pointer.
252 if (!net_map_.count(name)) {
255 return net_map_[name].get();
// Presumably Workspace::DeleteNet(name): erases the net registered under
// `name` if one exists; otherwise does nothing in the visible lines.
260 if (net_map_.count(name)) {
261 net_map_.erase(name);
// Presumably Workspace::RunNet(name).
// Logs an error when no net named `name` exists; the return in that branch is
// not visible. Otherwise it returns the result of the net's Run().
266 if (!net_map_.count(name)) {
267 LOG(ERROR) <<
"Network " << name <<
" does not exist yet.";
270 return net_map_[name]->Run();
// Workspace::RunOperatorOnce: builds an operator from op_def bound to this
// workspace, held by a unique_ptr local to this call, and presumably runs it
// once.
// An error is logged when creation returns nullptr, and another when running
// fails. NOTE(review): the run call, its condition and the return statements
// are not visible in this listing.
273 bool Workspace::RunOperatorOnce(
const OperatorDef& op_def) {
274 std::unique_ptr<OperatorBase> op(CreateOperator(op_def,
this));
275 if (op.get() ==
nullptr) {
276 LOG(ERROR) <<
"Cannot create operator of type " << op_def.type();
280 LOG(ERROR) <<
"Error when running operator " << op_def.type();
// Workspace::RunNetOnce: presumably creates a net from net_def and runs it once.
// Creation failure (net == nullptr) produces a "Could not create net" message;
// a failed run is logged with LOG(ERROR).
// NOTE(review): the net creation, the run call and the returns are not visible
// in this listing.
285 bool Workspace::RunNetOnce(
const NetDef& net_def) {
287 if (net ==
nullptr) {
289 "Could not create net: " + net_def.name() +
" of type " +
293 LOG(ERROR) <<
"Error when running network " << net_def.name();
// Presumably Workspace::RunPlan: delegates execution of `plan` to
// RunPlanOnWorkspace with this workspace and the shouldContinue callback.
300 return RunPlanOnWorkspace(
this, plan, shouldContinue);
// Presumably Workspace::GetThreadPool().
// While holding thread_pool_creation_mutex_, it assigns
// ThreadPool::defaultThreadPool() to thread_pool_; the condition guarding that
// assignment is not visible. It then returns a non-owning raw pointer to the
// pool.
304 std::lock_guard<std::mutex> guard(thread_pool_creation_mutex_);
306 thread_pool_ = ThreadPool::defaultThreadPool();
308 return thread_pool_.get();
// Workspace::bookkeeper: the Bookkeeper lives in a function-local static
// shared_ptr. It is initialized on the first call, and every later call
// observes the same instance.
// NOTE(review): the return statement is not visible in this listing.
311 std::shared_ptr<Workspace::Bookkeeper> Workspace::bookkeeper() {
312 static auto shared = std::make_shared<Workspace::Bookkeeper>();
Blob is a general container that hosts a typed pointer.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
void DeleteNet(const string &net_name)
Deletes the instantiated network with the given name.
bool RunPlan(const PlanDef &plan_def, ShouldContinue should_continue=StopOnSignal{})
Runs a plan that has multiple nets and execution steps.
Blob * CreateLocalBlob(const string &name)
Similar to CreateBlob(), but it creates a blob in the local workspace even if another blob with the same name already exists in a parent workspace.
bool RemoveBlob(const string &name)
Remove the blob of the given name.
vector< string > LocalBlobs() const
Return the list of blobs owned by this Workspace, not including blobs shared from a parent workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
void AddBlobMapping(const Workspace *parent, const std::unordered_map< string, string > &forwarded_blobs, bool skip_defined_blobs=false)
Adds blob mappings from workspace to the blobs from parent workspace.
Blob * RenameBlob(const string &old_name, const string &new_name)
Renames a local workspace blob.
bool RunNet(const string &net_name)
Finds and runs the instantiated network with the given name.
NetBase * GetNet(const string &net_name)
Gets the pointer to a created net.
vector< string > Blobs() const
Return a list of blob names, including blobs forwarded from parent workspaces and blobs from the shared workspace (if any).
NetBase * CreateNet(const NetDef &net_def, bool overwrite=false)
Creates a network with the given NetDef, and returns the pointer to the network.