Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_task_future.cc
#include "caffe2/core/net_async_task_future.h"

#include <utility>

#include "c10/util/Logging.h"
#include "caffe2/core/common.h"
5 
6 namespace caffe2 {
7 
8 AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}
9 
10 AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
11  : completed_(false), failed_(false) {
12  if (futures.size() > 1) {
13  parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
14  for (auto future : futures) {
15  future->SetCallback([this](const AsyncTaskFuture* f) {
16  if (f->IsFailed()) {
17  std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
18  if (parent_counter_->parent_failed) {
19  parent_counter_->err_msg += ", " + f->ErrorMessage();
20  } else {
21  parent_counter_->parent_failed = true;
22  parent_counter_->err_msg = f->ErrorMessage();
23  }
24  }
25  int count = --parent_counter_->parent_count;
26  if (count == 0) {
27  // thread safe to use parent_counter here
28  if (!parent_counter_->parent_failed) {
29  SetCompleted();
30  } else {
31  SetCompleted(parent_counter_->err_msg.c_str());
32  }
33  }
34  });
35  }
36  } else {
37  CAFFE_ENFORCE_EQ(futures.size(), 1);
38  auto future = futures.back();
39  future->SetCallback([this](const AsyncTaskFuture* f) {
40  if (!f->IsFailed()) {
41  SetCompleted();
42  } else {
43  SetCompleted(f->ErrorMessage().c_str());
44  }
45  });
46  }
47 }
48 
49 bool AsyncTaskFuture::IsCompleted() const {
50  return completed_;
51 }
52 
53 bool AsyncTaskFuture::IsFailed() const {
54  return failed_;
55 }
56 
57 std::string AsyncTaskFuture::ErrorMessage() const {
58  return err_msg_;
59 }
60 
61 void AsyncTaskFuture::Wait() const {
62  std::unique_lock<std::mutex> lock(mutex_);
63  while (!completed_) {
64  cv_completed_.wait(lock);
65  }
66 }
67 
68 void AsyncTaskFuture::SetCallback(
69  std::function<void(const AsyncTaskFuture*)> callback) {
70  std::unique_lock<std::mutex> lock(mutex_);
71 
72  callbacks_.push_back(callback);
73  if (completed_) {
74  callback(this);
75  }
76 }
77 
78 void AsyncTaskFuture::SetCompleted(const char* err_msg) {
79  std::unique_lock<std::mutex> lock(mutex_);
80 
81  CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
82  completed_ = true;
83 
84  if (err_msg) {
85  failed_ = true;
86  err_msg_ = err_msg;
87  }
88 
89  for (auto& callback : callbacks_) {
90  callback(this);
91  }
92 
93  cv_completed_.notify_all();
94 }
95 
96 // ResetState is called on a completed future,
97 // does not reset callbacks to keep task graph structure
98 void AsyncTaskFuture::ResetState() {
99  std::unique_lock<std::mutex> lock(mutex_);
100  if (parent_counter_) {
101  parent_counter_->Reset();
102  }
103  completed_ = false;
104  failed_ = false;
105  err_msg_ = "";
106 }
107 
108 AsyncTaskFuture::~AsyncTaskFuture() {}
109 
110 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13