Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_task.cc
1 #include "caffe2/core/net_async_task.h"
2 
3 #include "caffe2/core/net_async_task_graph.h"
4 
5 namespace caffe2 {
6 
7 AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
8  CAFFE_ENFORCE(!ops_.empty());
9  device_option_ = ops_.front()->device_option();
10  for (auto& op : ops_) {
11  CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
12  }
13  Reset();
14 }
15 
16 void AsyncTask::handleChainError(
17  OperatorBase* op,
18  const char* err_str,
19  bool save_exception) {
20  std::string err_msg = err_str;
21  if (op) {
22  err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
23  }
24  LOG(ERROR) << err_msg;
25 
26  // save error message and exception in chain's Event
27  auto last_op = ops_.back();
28  if (save_exception) {
29  last_op->event().SetFinishedWithException(err_msg.c_str());
30  } else {
31  last_op->event().SetFinished(err_msg.c_str());
32  }
33 
34  // set future as completed with an error
35  // TODO: exceptions in future
36  future_.SetCompleted(err_msg.c_str());
37 }
38 
39 bool AsyncTask::Run(const ExecutionOptions& options) {
40  // TODO: insert CUDA's async stream waits; tracing and counters
41  OperatorBase* op = nullptr;
42  try {
43  for (auto op_idx = 0; op_idx < ops_.size(); ++op_idx) {
44  op = ops_[op_idx];
45  int stream_id = 0; // TODO: thread local stream id
46  if (!op->RunAsync(stream_id)) {
47  handleChainError(op, "Failed to execute an op");
48  return false;
49  }
50  }
51 
52  if (options.finish_chain_) {
53  op = ops_.back();
54  op->Finish();
55  }
56 
57  // set the future as successfully completed or, in case of async CPU,
58  // use op's callback
59  if (IsCPUDeviceType(device_option_.device_type()) &&
60  ops_.back()->HasAsyncPart()) {
61  auto& event = ops_.back()->event();
62  event.SetCallback([this, &event]() {
63  CAFFE_ENFORCE(event.IsFinished());
64  if (event.Query() == EventStatus::EVENT_SUCCESS) {
65  future_.SetCompleted();
66  } else {
67  // TODO: support for exceptions
68  future_.SetCompleted(event.ErrorMessage().c_str());
69  }
70  });
71  } else {
72  future_.SetCompleted();
73  }
74  } catch (const std::exception& e) {
75  handleChainError(op, e.what(), /* save_exception */ true);
76  return false;
77  } catch (...) {
78  handleChainError(
79  op,
80  "Failed to execute task: unknown error",
81  /* save_exception */ true);
82  return false;
83  }
84 
85  return true;
86 }
87 
88 void AsyncTask::Reset() {
89  for (auto& op : ops_) {
90  op->ResetEvent();
91  }
92  future_.ResetState();
93 }
94 
95 DeviceOption AsyncTask::GetDeviceOption() const {
96  return device_option_;
97 }
98 
99 AsyncTaskFuture& AsyncTask::GetFuture() {
100  return future_;
101 }
102 
103 const AsyncTaskFuture& AsyncTask::GetFuture() const {
104  return future_;
105 }
106 
107 }; // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13