1 #include "caffe2/core/plan_executor.h" 3 #include <condition_variable> 7 #include <unordered_map> 10 #include "caffe2/core/timer.h" 11 #include "caffe2/core/workspace.h" 12 #include "caffe2/proto/caffe2_pb.h" 15 caffe2_handle_executor_threads_exceptions,
17 "If used we will handle exceptions in executor threads. " 18 "This avoids SIGABRT but may cause process to deadlock");
// NOTE(review): only the help text of this flag definition is visible here.
// The flag is read as FLAGS_caffe2_handle_executor_threads_exceptions by the
// concurrent-substep worker in ExecuteStepRecursive, right after that worker
// catches a std::exception thrown by a substep.

// Maps a net name to its NetDefInfo. Judging by its uses in this file,
// NetDefInfo holds a pointer to the NetDef (`netDef`) and a `needsOverride`
// flag. RunPlanOnWorkspace initializes `needsOverride` to whether a net of that
// name already exists in the workspace. createAndGetNet (re)creates the net
// when the flag is set (or the net is missing) and then clears the flag.
32 using NetDefMap = std::unordered_map<std::string, NetDefInfo>;
// Periodic reporter internals: each start() call creates one ReporterInstance,
// and each instance owns one background thread.
// NOTE(review): the enclosing `struct Reporter {` header, the `done` member
// and several body lines are not visible in this chunk. The comments below
// describe only the visible code.
35 struct ReporterInstance {
36 std::mutex report_mutex;
37 std::condition_variable report_cv;
38 std::thread report_thread;
// Starts a thread that waits on report_cv for up to `intervalMillis`
// milliseconds at a time. It wakes early once `*done` becomes true.
// The loop that presumably invokes `f` after each wait is not visible here;
// confirm.
39 ReporterInstance(
int intervalMillis,
bool* done, std::function<
void()> f) {
40 auto interval = std::chrono::milliseconds(intervalMillis);
// `[=]` implicitly captures `this` because report_mutex and report_cv are
// members. The instance therefore must not move while the thread runs.
// start() heap-allocates each instance and keeps it in a unique_ptr, so its
// address is stable.
41 auto reportWorker = [=]() {
42 std::unique_lock<std::mutex> lk(report_mutex);
44 report_cv.wait_for(lk, interval, [&]() {
return *done; });
48 report_thread = std::thread(reportWorker);
// Registers a callback `f` to be driven by a new ReporterInstance every
// `intervalMillis` milliseconds.
// NOTE(review): `intervalMillis` is int64_t here but the ReporterInstance
// constructor takes `int`, so large intervals are narrowed. Confirm intended.
// NOTE(review): every instance shares the same `done` flag, yet each reads it
// under its own report_mutex. Confirm the writer synchronizes with every
// instance's mutex, or make the flag atomic.
52 void start(int64_t intervalMillis, std::function<
void()> f) {
53 instances_.emplace_back(
new ReporterInstance(intervalMillis, &done, f));
// Shutdown loop. Its enclosing function header is not visible; presumably
// this is the Reporter destructor. For each instance it wakes the thread via
// its condition variable and joins it. Non-joinable threads are handled by
// lines not shown here.
58 for (
auto& instance : instances_) {
59 if (!instance->report_thread.joinable()) {
62 instance->report_cv.notify_all();
63 instance->report_thread.join();
// Owns every ReporterInstance created by start().
68 std::vector<std::unique_ptr<ReporterInstance>> instances_;
// Builds the predicate that decides, given a zero-based iteration index,
// whether `step` should run another iteration.
//  - Without should_stop_blob: the predicate allows num_iter iterations
//    (default 1). only_once is rejected with "not supported".
//  - With should_stop_blob: the visible error message says num_iter must not
//    also be set. If only_once is true the predicate presumably allows only
//    iteration 0 (the selecting `if` is not visible). Otherwise it always
//    returns true, and the run is ended by the stop blob.
// NOTE(review): the first parameter (a Workspace*, per the call in
// CompiledExecutionStep) and several guard/brace lines are not visible here.
74 std::function<bool(int64_t)> getContinuationTest(
76 const ExecutionStep& step) {
77 if (step.has_should_stop_blob()) {
80 "Must not specify num_iter if should_stop_blob is set");
83 if (!step.has_should_stop_blob()) {
84 CAFFE_ENFORCE(!step.has_only_once(),
"not supported");
85 int64_t iterations = step.has_num_iter() ? step.num_iter() : 1;
86 VLOG(1) <<
"Will execute step " << step.name() <<
" for " << iterations
88 return [=](int64_t i) {
return i < iterations; };
90 bool onlyOnce = step.has_only_once() && step.only_once();
// NOTE(review): unlike the message above, this one has no space after "step",
// so the step name is glued to that word in the log output.
91 VLOG(1) <<
"Will execute step" << step.name() << (onlyOnce ?
" once " :
"")
92 <<
" until stopped by blob " << step.should_stop_blob();
94 return [](int64_t i) {
return i == 0; };
96 return [](int64_t ) {
return true; };
// Reads the boolean stop flag stored in blob `b`.
// A null or uninitialized blob is handled before any tensor access. That
// branch's body is not visible here; presumably it returns false. Confirm.
// Otherwise the blob must hold a CPU tensor of type bool with exactly one
// element, and that element's value is returned.
102 inline bool getShouldStop(
const Blob* b) {
103 if (!b || b->meta().id() == TypeIdentifier::uninitialized()) {
107 const auto& t = b->Get<TensorCPU>();
108 CAFFE_ENFORCE(t.IsType<
bool>() && t.numel() == 1,
"expects a scalar boolean");
109 return *(t.template data<bool>());
// Stamps a workspace with a GLOBAL_WORKSPACE_ID blob. The value combines the
// workspace's NODE_ID blob with a per-injector sequence counter.
120 struct WorkspaceIdInjector {
121 static const string NODE_ID;
122 static const string GLOBAL_WORKSPACE_ID;
// If `workspace` has a NODE_ID blob, this:
//  1. reads the first int32 element of that blob;
//  2. computes seq_ + (node_id << 16), using the current seq_ and then
//     advancing it;
//  3. creates GLOBAL_WORKSPACE_ID via CreateLocalBlob and writes the result
//     as its int32 value.
// Workspaces without NODE_ID are left untouched, as far as the visible code
// shows.
124 void InjectWorkspaceId(Workspace* workspace) {
125 if (workspace->HasBlob(NODE_ID)) {
126 Blob* node_id_blob = workspace->GetBlob(NODE_ID);
127 const TensorCPU& node_id_tensor = node_id_blob->template Get<TensorCPU>();
128 int node_id = node_id_tensor.template data<int32_t>()[0];
// NOTE(review): only the message of the overflow check guarding the line
// below is visible. Left-shifting a negative node_id is undefined behavior
// before C++20, so confirm the check also rejects negative node ids.
131 "Integer overflow while calculating GLOBAL_WORKSPACE_ID blob");
132 int32_t global_ws_id = (seq_++) + (static_cast<int32_t>(node_id) << 16);
133 Blob* global_ws_id_blob = workspace->CreateLocalBlob(GLOBAL_WORKSPACE_ID);
134 TensorCPU* global_ws_id_tensor =
135 BlobGetMutableTensor(global_ws_id_blob, CPU);
// Resize() with no dimensions; presumably yields a 0-d, one-element tensor.
136 global_ws_id_tensor->Resize();
137 global_ws_id_tensor->template mutable_data<int32_t>()[0] = global_ws_id;
138 VLOG(1) <<
"Adding " << GLOBAL_WORKSPACE_ID <<
" = " << global_ws_id;
// Atomic because a single injector is shared across a plan run. Concurrent
// substep threads can compile steps with create_workspace and so call
// InjectWorkspaceId at the same time; each call still gets its own number.
143 std::atomic<int> seq_{0};
146 const string WorkspaceIdInjector::NODE_ID =
"NODE_ID";
147 const string WorkspaceIdInjector::GLOBAL_WORKSPACE_ID =
"GLOBAL_WORKSPACE_ID";
149 struct CompiledExecutionStep;
// Wraps an ExecutionStep and hands out its compiled form.
// Steps that do not create their own workspace are compiled once, in the
// constructor, and that compilation is reused on every run. Steps with
// create_workspace are presumably compiled anew on each compiled() call; each
// such compilation builds its own local Workspace.
// NOTE(review): several initializer/member lines are not visible in this
// chunk, including the step_ initializer and the fourth constructor parameter
// (a NetDefMap*, per the call sites).
170 struct ExecutionStepWrapper {
171 ExecutionStepWrapper(
172 const ExecutionStep* step,
173 Workspace* externalWorkspace,
174 ShouldContinue externalShouldContinue,
176 WorkspaceIdInjector* ws_id_injector)
178 externalWorkspace_(externalWorkspace),
179 externalShouldContinue_(externalShouldContinue),
181 ws_id_injector_(ws_id_injector) {
185 if (!step_->create_workspace()) {
186 compiledStep_ = doCompile();
// Handle returned by compiled(). It either owns a freshly compiled step or
// borrows the wrapper's cached one. operator-> (body not visible) presumably
// returns compiledRef_ in both cases.
// NOTE(review): compiledRef_ has no in-class initializer. Confirm it is always
// set via reset() before operator-> is used.
190 class CompiledGuard {
// Take ownership of a newly compiled step.
191 void reset(std::unique_ptr<CompiledExecutionStep>&& compiled) {
192 compiled_ = std::move(compiled);
193 compiledRef_ = compiled_.get();
// Borrow a step owned elsewhere (the wrapper's cached compiledStep_).
195 void reset(CompiledExecutionStep* compiledRef) {
197 compiledRef_ = compiledRef;
201 CompiledExecutionStep* operator->() {
207 std::unique_ptr<CompiledExecutionStep> compiled_;
208 CompiledExecutionStep* compiledRef_;
209 friend struct ExecutionStepWrapper;
212 const ExecutionStep& step() {
// Borrows the cached compilation when one exists. Otherwise it compiles anew,
// and the guard owns that compilation (the selecting condition is not visible
// here).
216 CompiledGuard compiled() {
219 guard.reset(compiledStep_.get());
221 guard.reset(doCompile());
227 std::unique_ptr<CompiledExecutionStep> doCompile();
229 const ExecutionStep* step_;
230 Workspace* externalWorkspace_;
231 ShouldContinue externalShouldContinue_;
// Cached compilation; set only for steps without create_workspace.
233 std::unique_ptr<CompiledExecutionStep> compiledStep_;
234 WorkspaceIdInjector* ws_id_injector_;
// The ready-to-run form of an ExecutionStep. It resolves the workspace,
// creates the step's nets, wraps its substeps, and builds the continuation
// predicate.
237 struct CompiledExecutionStep {
// NOTE(review): this alias takes `int`, but ExecuteStepRecursive calls
// shouldContinue with an int64_t iteration counter. getContinuationTest also
// returns a std::function<bool(int64_t)>. Iteration indices are therefore
// narrowed to int when passed through this type. Confirm this is intended.
238 typedef std::function<bool(int)> ShouldContinue;
240 CompiledExecutionStep(
241 const ExecutionStep* mainStep,
242 Workspace* externalWorkspace,
243 ShouldContinue externalShouldContinue,
245 WorkspaceIdInjector* ws_id_injector)
// With create_workspace, the step runs in its own Workspace. That workspace is
// constructed from the external one, owned by localWorkspace_, and stamped via
// the injector. Otherwise the step runs directly in the external workspace.
247 if (mainStep->create_workspace()) {
248 localWorkspace_.reset(
new Workspace(externalWorkspace));
249 workspace = localWorkspace_.get();
250 ws_id_injector->InjectWorkspaceId(workspace);
252 workspace = externalWorkspace;
// A step may have substeps or networks, but not both.
256 (step->substep_size() == 0 || step->network_size() == 0),
257 "An ExecutionStep should either have substep or networks" 260 auto createAndGetNet = [&](
const std::string& network_name) {
// Looks up the NetDef by name, which must exist. If the net is missing from
// `workspace` or flagged needsOverride, it is (re)created and the flag is
// cleared. The net is then fetched and must be non-null.
261 auto it = netDefs->find(network_name);
263 it != netDefs->end(),
264 "ExecutionStep " + mainStep->name() +
" uses undefined net " +
269 if (it->second.needsOverride || !workspace->GetNet(network_name)) {
270 workspace->CreateNet(*it->second.netDef,
true);
271 it->second.needsOverride =
false;
273 auto* net = workspace->GetNet(network_name);
274 CAFFE_ENFORCE(net !=
nullptr,
"Network ", network_name,
" not found.");
278 if (step->substep_size()) {
279 ShouldContinue substepShouldContinue;
// Sequential substeps (or a single substep) use the caller's predicate
// unchanged. Concurrent substeps also stop once any sibling has set
// gotFailure.
280 if (!step->concurrent_substeps() || step->substep().size() <= 1) {
281 substepShouldContinue = externalShouldContinue;
283 substepShouldContinue = [
this, externalShouldContinue](int64_t it) {
284 return !gotFailure && externalShouldContinue(it);
// Substeps with run_every_ms are driven periodically by the Reporter in
// ExecuteStepRecursive. All others run on every iteration of this step.
288 for (
const auto& ss : step->substep()) {
289 auto compiledSubstep = std::make_shared<ExecutionStepWrapper>(
290 &ss, workspace, substepShouldContinue, netDefs, ws_id_injector);
291 if (ss.has_run_every_ms()) {
292 reportSubsteps.push_back(compiledSubstep);
294 recurringSubsteps.push_back(compiledSubstep);
298 for (
const string& network_name : step->network()) {
299 networks.push_back(createAndGetNet(network_name));
// The stop blob must already exist in `workspace` when the step is compiled.
303 if (step->has_should_stop_blob()) {
304 shouldStop = workspace->GetBlob(step->should_stop_blob());
306 shouldStop,
"blob ", step->should_stop_blob(),
" does not exist");
309 if (step->has_report_net()) {
311 step->has_report_interval(),
312 "A report_interval must be provided if report_net is set.");
313 reportNet = createAndGetNet(step->report_net());
// Keep iterating only while both the caller's predicate and this step's own
// iteration policy allow it.
318 netShouldContinue = getContinuationTest(workspace, *step);
319 shouldContinue = [
this, externalShouldContinue](int64_t iter) {
320 return externalShouldContinue(iter) && this->netShouldContinue(iter);
324 const ExecutionStep* step;
325 Workspace* workspace;
// Substeps executed periodically by the Reporter (run_every_ms set).
326 vector<std::shared_ptr<ExecutionStepWrapper>> reportSubsteps;
// Substeps executed on every iteration of this step.
327 vector<std::shared_ptr<ExecutionStepWrapper>> recurringSubsteps;
329 vector<NetBase*> networks;
// Blob holding the one-element bool stop flag, or nullptr when not configured.
331 Blob* shouldStop{
nullptr};
332 ShouldContinue netShouldContinue;
333 ShouldContinue shouldContinue;
// Set by concurrent substep workers when any substep fails or throws.
334 std::atomic<bool> gotFailure{
false};
// Owns the step-local workspace when create_workspace is set.
337 std::unique_ptr<Workspace> localWorkspace_;
// Builds a new CompiledExecutionStep for this wrapper's step and transfers
// ownership to the caller. Only the externalShouldContinue_ argument is
// visible here. The other constructor arguments presumably come from the
// wrapper's members (step_, externalWorkspace_, the NetDefMap, ws_id_injector_).
340 std::unique_ptr<CompiledExecutionStep> ExecutionStepWrapper::doCompile() {
341 return std::unique_ptr<CompiledExecutionStep>(
new CompiledExecutionStep(
344 externalShouldContinue_,
// CHECK_SHOULD_STOP(step, shouldStop): when the stop blob reads true, logs
// that `step` was stopped by its should_stop_blob. The remainder of the macro
// is not visible; presumably it returns from the enclosing function.
//
// ExecuteStepRecursive runs one ExecutionStep and returns whether it
// succeeded.
//  - Substeps are executed recursively, either sequentially or with one thread
//    per (substep x concurrent instance).
//  - Network steps run their nets in order on each iteration.
//  - Periodic report_net and run_every_ms substeps run on a Reporter for the
//    duration of the call.
349 #define CHECK_SHOULD_STOP(step, shouldStop) \ 350 if (getShouldStop(shouldStop)) { \ 351 VLOG(1) << "Execution step " << step.name() << " stopped by " \ 352 << step.should_stop_blob(); \ 356 bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) {
357 const auto& step = stepWrapper.step();
// The guard keeps the compiled step alive for this whole call, including any
// step-local workspace it owns.
358 auto compiledStep = stepWrapper.compiled();
360 VLOG(1) <<
"Running execution step " << step.name();
// `reporter` is declared after `compiledStep`, so it is destroyed first,
// presumably stopping its threads before the compiled step they use goes away.
362 std::unique_ptr<Reporter> reporter;
363 if (step.has_report_net() || compiledStep->reportSubsteps.size() > 0) {
364 reporter = caffe2::make_unique<Reporter>();
365 auto* reportNet = compiledStep->reportNet;
// report_interval is multiplied by 1000 because Reporter::start takes
// milliseconds, i.e. report_interval is expressed in seconds.
367 VLOG(1) <<
"Starting reporter net";
368 reporter->start(step.report_interval() * 1000, [reportNet]() {
369 if (!reportNet->Run()) {
370 LOG(WARNING) <<
"Error running report_net.";
// Each run_every_ms substep is re-executed periodically. Its failures are only
// logged as warnings, not propagated.
374 for (
auto& substepWrapper : compiledStep->reportSubsteps) {
376 substepWrapper->step().run_every_ms(), [substepWrapper]() {
377 if (!ExecuteStepRecursive(*substepWrapper)) {
378 LOG(WARNING) <<
"Error running report step.";
384 const Blob* shouldStop = compiledStep->shouldStop;
386 if (step.substep_size()) {
// Substeps run sequentially unless concurrency was requested, either via
// concurrent_substeps (with more than one substep) or via
// num_concurrent_instances > 1.
388 (!step.concurrent_substeps() || step.substep().size() <= 1) &&
389 (!step.has_num_concurrent_instances() ||
390 step.num_concurrent_instances() <= 1);
// Iterate while the compiled continuation predicate allows. The stop blob is
// also checked inside the loop.
391 for (int64_t iter = 0; compiledStep->shouldContinue(iter); ++iter) {
393 VLOG(1) <<
"Executing step " << step.name() <<
" iteration " << iter;
394 for (
auto& substepWrapper : compiledStep->recurringSubsteps) {
395 if (!ExecuteStepRecursive(*substepWrapper)) {
398 CHECK_SHOULD_STOP(step, shouldStop);
401 VLOG(1) <<
"Executing step " << step.name() <<
" iteration " << iter
402 <<
" with " << step.substep().size() <<
" concurrent substeps";
// Concurrent mode: each worker thread claims the next index from next_substep,
// taken modulo the number of recurring substeps. Since numThreads below is
// (#recurring substeps) x num_concurrent_instances, every recurring substep is
// run by num_concurrent_instances threads.
404 std::atomic<int> next_substep{0};
405 std::mutex exception_mutex;
406 string first_exception;
407 auto worker = [&]() {
408 auto num_substeps = compiledStep->recurringSubsteps.size();
409 int substep_id = next_substep++ % num_substeps;
// A worker whose siblings have already failed skips its substep. The branch
// body is not visible; presumably it returns.
410 if (compiledStep->gotFailure) {
414 if (!ExecuteStepRecursive(
415 *compiledStep->recurringSubsteps.at(substep_id))) {
416 compiledStep->gotFailure =
true;
418 }
catch (
const std::exception& ex) {
// Only the first exception's message is kept (under exception_mutex), and any
// exception marks the step as failed. What happens next depends on
// FLAGS_caffe2_handle_executor_threads_exceptions; that branch body is not
// visible here.
419 std::lock_guard<std::mutex> guard(exception_mutex);
420 if (!first_exception.size()) {
421 first_exception = c10::GetExceptionString(ex);
422 LOG(ERROR) <<
"Parallel worker exception:\n" << first_exception;
424 compiledStep->gotFailure =
true;
425 if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
435 std::vector<std::thread> threads;
436 auto numThreads = compiledStep->recurringSubsteps.size();
437 if (step.has_num_concurrent_instances()) {
438 numThreads *= step.num_concurrent_instances();
440 for (
size_t i = 0; i < numThreads; ++i) {
441 threads.emplace_back(worker);
// Join every worker, then report any failure. If an exception message was
// captured, it is included in the error raised below (the raising call is only
// partially visible). first_exception is read here only after the join loop.
443 for (
auto& thread : threads) {
446 if (compiledStep->gotFailure) {
447 LOG(ERROR) <<
"One of the workers failed.";
448 if (first_exception.size()) {
450 "One of the workers died with an unhandled exception ",
456 CHECK_SHOULD_STOP(step, shouldStop);
// Network step: run every net in order on each iteration. A failed Run() ends
// the step; that branch body is not visible here and presumably returns false.
// The stop blob is also checked inside this loop.
462 for (int64_t iter = 0; compiledStep->shouldContinue(iter); ++iter) {
463 VLOG(1) <<
"Executing networks " << step.name() <<
" iteration " << iter;
464 for (NetBase* network : compiledStep->networks) {
465 if (!network->Run()) {
468 CHECK_SHOULD_STOP(step, shouldStop);
// The `#undef` below ends CHECK_SHOULD_STOP's scope. The same line begins
// RunPlanOnWorkspace.
// NOTE(review): the function's leading parameters are not visible here. Per
// their uses, they are the plan and the workspace `ws`.
//
// RunPlanOnWorkspace executes every ExecutionStep of `plan` in order on `ws`,
// logging per-step and total wall-clock time. The failure and early-return
// lines are not fully visible; presumably it returns false at the first
// failed step.
475 #undef CHECK_SHOULD_STOP 478 bool RunPlanOnWorkspace(
481 ShouldContinue shouldContinue) {
482 LOG(INFO) <<
"Started executing plan " << plan.name();
483 if (plan.execution_step_size() == 0) {
484 LOG(WARNING) <<
"Nothing to run - did you define a correct plan?";
488 LOG(INFO) <<
"Initializing networks for plan " << plan.name();
// Index the plan's NetDefs by name; duplicate names are rejected. Nets that
// already exist in `ws` are flagged needsOverride, so the first step that
// needs such a net recreates it from this plan's definition.
491 for (
const NetDef& net_def : plan.network()) {
492 LOG(INFO) <<
"Processing net '" << net_def.name() <<
"', type: '" 493 << net_def.type() <<
"', #ops: " << net_def.op_size()
494 <<
", num_workers: " << net_def.num_workers();
496 net_defs.count(net_def.name()) == 0,
497 "Your plan contains networks of the same name \"",
499 "\", which should not happen. Check your plan to see " 500 "if you made a programming error in creating the plan.");
501 auto netAlreadyExists = ws->GetNet(net_def.name()) !=
nullptr;
502 net_defs[net_def.name()] = NetDefInfo{&net_def, netAlreadyExists};
// One injector per plan run, passed down to every step and substep, so its
// sequence counter is not reused within the run.
504 WorkspaceIdInjector ws_id_injector;
506 for (
const ExecutionStep& step : plan.execution_step()) {
508 ExecutionStepWrapper stepWrapper(
509 &step, ws, shouldContinue, &net_defs, &ws_id_injector);
// NOTE(review): this message says "initializing", but it is logged when
// executing the step fails.
510 if (!ExecuteStepRecursive(stepWrapper)) {
511 LOG(ERROR) <<
"Failed initializing step " << step.name();
514 LOG(INFO) <<
"Step " << step.name() <<
" in plan " << plan.name()
515 <<
" took " << step_timer.Seconds() <<
" seconds.";
517 LOG(INFO) <<
"Total plan " << plan.name() <<
" took " << plan_timer.Seconds()
519 LOG(INFO) <<
"Plan " << plan.name() <<
" executed successfully.";
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...