1 #ifndef CAFFE2_CORE_WORKSPACE_H_ 2 #define CAFFE2_CORE_WORKSPACE_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/observer.h" 11 #include <unordered_set> 14 #include "c10/util/Registry.h" 15 #include "caffe2/core/blob.h" 16 #include "caffe2/core/net.h" 17 #include "caffe2/proto/caffe2_pb.h" 18 #include "caffe2/utils/signal_handler.h" 19 #include "caffe2/utils/threadpool/ThreadPool.h" 21 C10_DECLARE_bool(caffe2_print_blob_sizes_at_exit);
29 : handler_(std::make_shared<SignalHandler>(
30 SignalHandler::Action::STOP,
31 SignalHandler::Action::STOP)) {}
33 StopOnSignal(
const StopOnSignal& other) : handler_(other.handler_) {}
35 bool operator()(
int ) {
36 return handler_->CheckForSignals() != SignalHandler::Action::STOP;
39 std::shared_ptr<SignalHandler> handler_;
49 typedef std::function<bool(int)> ShouldContinue;
50 typedef CaffeMap<string, unique_ptr<Blob> > BlobMap;
51 typedef CaffeMap<string, unique_ptr<NetBase> > NetMap;
85 const std::unordered_map<string, string>& forwarded_blobs)
87 CAFFE_ENFORCE(shared,
"Parent workspace must be specified");
88 for (
const auto& forwarded : forwarded_blobs) {
90 shared->
HasBlob(forwarded.second),
91 "Invalid parent workspace blob: ",
93 forwarded_blobs_[forwarded.first] =
94 std::make_pair(shared, forwarded.second);
102 : root_folder_(root_folder), shared_(shared), bookkeeper_(bookkeeper()) {
103 std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
104 bookkeeper_->workspaces.insert(
this);
108 if (FLAGS_caffe2_print_blob_sizes_at_exit) {
113 std::lock_guard<std::mutex> guard(bookkeeper_->wsmutex);
114 bookkeeper_->workspaces.erase(
this);
129 const std::unordered_map<string, string>& forwarded_blobs,
130 bool skip_defined_blobs =
false);
136 template <
class Context>
138 for (
const auto& blob : blobs) {
139 if (!forwarded_blobs_.count(blob)) {
142 const auto& ws_blob = forwarded_blobs_[blob];
143 const auto* parent_ws = ws_blob.first;
144 auto* from_blob = parent_ws->GetBlob(ws_blob.second);
145 CAFFE_ENFORCE(from_blob);
147 from_blob->template IsType<Tensor>(),
148 "Expected blob with tensor value",
150 forwarded_blobs_.erase(blob);
151 auto* to_blob = CreateBlob(blob);
152 CAFFE_ENFORCE(to_blob);
153 const auto& from_tensor = from_blob->template Get<Tensor>();
154 auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType());
155 to_tensor->CopyFrom(from_tensor);
163 vector<string> LocalBlobs()
const;
170 vector<string> Blobs()
const;
179 inline bool HasBlob(
const string& name)
const {
182 if (blob_map_.count(name)) {
184 }
else if (forwarded_blobs_.count(name)) {
185 const auto parent_ws = forwarded_blobs_.at(name).first;
186 const auto& parent_name = forwarded_blobs_.at(name).second;
187 return parent_ws->HasBlob(parent_name);
188 }
else if (shared_) {
189 return shared_->HasBlob(name);
194 void PrintBlobSizes();
201 Blob* CreateBlob(
const string& name);
209 Blob* CreateLocalBlob(
const string& name);
215 bool RemoveBlob(
const string& name);
220 const Blob* GetBlob(
const string& name)
const;
225 Blob* GetBlob(
const string& name);
232 Blob* RenameBlob(
const string& old_name,
const string& new_name);
245 const std::shared_ptr<const NetDef>& net_def,
246 bool overwrite =
false);
251 NetBase* GetNet(
const string& net_name);
255 void DeleteNet(
const string& net_name);
261 bool RunNet(
const string& net_name);
267 vector<string> names;
268 for (
auto& entry : net_map_) {
269 names.push_back(entry.first);
277 bool RunPlan(
const PlanDef& plan_def,
292 bool RunOperatorOnce(
const OperatorDef& op_def);
293 bool RunNetOnce(
const NetDef& net_def);
301 template <
typename F>
303 auto bk = bookkeeper();
304 std::lock_guard<std::mutex> guard(bk->wsmutex);
// std::atomic, so concurrent loads/stores need no external lock; starts at 0
// (the original `{}` value-initialized it to the same value).
// NOTE(review): presumably records the position of the last failing operator
// within its net — confirm where it is written.
std::atomic<int> last_failed_op_net_position{0};
316 std::unordered_set<Workspace*> workspaces;
319 static std::shared_ptr<Bookkeeper> bookkeeper();
323 const string root_folder_;
325 std::unordered_map<string, std::pair<const Workspace*, string>>
327 std::unique_ptr<ThreadPool> thread_pool_;
328 std::mutex thread_pool_creation_mutex_;
329 std::shared_ptr<Bookkeeper> bookkeeper_;
336 #endif // CAFFE2_CORE_WORKSPACE_H_ const string & RootFolder()
Return the root folder of the workspace.
Blob is a general container that hosts a typed pointer.
Workspace(const Workspace *shared, const std::unordered_map< string, string > &forwarded_blobs)
Initializes workspace with parent workspace, blob name remapping (new name -> parent blob name)...
Workspace(const Workspace *shared)
Initializes a workspace with a shared workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
static void ForEach(F f)
Applies a function f on each workspace that currently exists.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Workspace(const string &root_folder)
Initializes an empty workspace with the given root folder.
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
void CopyForwardedTensors(const std::unordered_set< std::string > &blobs)
Converts previously mapped tensor blobs to local blobs, copies values from parent workspace blobs int...
vector< string > Nets() const
Returns a list of names of the currently instantiated networks.
Workspace()
Initializes an empty workspace.
Workspace(const string &root_folder, const Workspace *shared)
Initializes a workspace with a root folder and a shared workspace.