Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_parallel.cc
1 #include "caffe2/core/net_parallel.h"
2 
3 #include "caffe2/core/operator.h"
4 
5 #include <sstream>
6 
7 C10_DEFINE_string(
8  caffe2_task_graph_engine,
9  "futures",
10  "Task graph engine type used by net executor");
11 
12 namespace caffe2 {
13 
14 ParallelNet::ParallelNet(
15  const std::shared_ptr<const NetDef>& net_def,
16  Workspace* ws)
17  : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
18  num_workers_ = net_def->num_workers();
19  CAFFE_ENFORCE_GT(
20  num_workers_, 0, "Expected positive number of worker threads");
21 
22  helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(this);
23 
24  // initialize operators
25  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
26  operators_.reserve(operator_nodes_.size());
27  for (const auto& node : operator_nodes_) {
28  auto op = node.operator_.get();
29  op->SetExecutorHelper(helper_.get());
30  operators_.push_back(op);
31  }
32 
33  task_graph_ = TaskGraphRegistry()->Create(
34  FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
35  CAFFE_ENFORCE(task_graph_, "Couldn't initialize task graph");
36 
37  // compute chains
38  // TODO: inference mode for chaining
39  auto execution_chains = dag_utils::computeChains(operator_nodes_);
40  std::vector<std::vector<int>> chains;
41  chains.reserve(execution_chains.size());
42  for (const auto& kv : execution_chains) {
43  chains.push_back(kv.second);
44  }
45  auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
46  CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
47 
48  // disable unused events
49  for (const auto& chain : chains) {
50  for (const auto& op_id : chain) {
51  if (op_id == chain.back() || op_id == chain.front()) {
52  continue;
53  }
54  auto op = operators_[op_id];
55  if (IsCPUDeviceType(op->device_option().device_type()) &&
56  op->HasAsyncPart()) {
57  continue;
58  }
59  op->DisableEvent();
60  }
61  }
62 
63  // initialize task graph
64  for (auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
65  std::vector<OperatorBase*> ops;
66  ops.reserve(chains[chain_id].size());
67  for (auto op_id : chains[chain_id]) {
68  ops.push_back(operators_[op_id]);
69  }
70  CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
71  }
72  for (auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
73  if (!chain_nodes[chain_id].parents_.empty()) {
74  CAFFE_ENFORCE(
75  task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
76  }
77  }
78 
79  // Freeze graph and initialize graph execution future
80  task_graph_->FreezeGraph();
81  run_future_ = task_graph_->GetFuture();
82  run_future_->SetCallback([this](const AsyncTaskFuture* /* unused */) {
83  StopAllObservers();
84  finishRun();
85  });
86 
87  LOG(INFO) << "Initialized parallel net: '" << Name()
88  << "', #ops: " << net_def->op_size()
89  << ", #chains: " << chains.size() << ", #workers: " << num_workers_
90  << ", dfs scheduling: " << options_.use_dfs_scheduling_
91  << ", task graph engine: " << FLAGS_caffe2_task_graph_engine;
92 }
93 
94 bool ParallelNet::RunAsync() {
95  reset();
96  StartAllObservers();
97 
98  try {
99  task_graph_->ExecuteGraph();
100  } catch (const std::exception&) {
101  StopAllObservers();
102  return false;
103  }
104 
105  return true;
106 }
107 
108 void ParallelNet::Wait() {
109  CAFFE_ENFORCE(run_future_);
110  run_future_->Wait();
111 }
112 
113 void ParallelNet::reset() {
114  task_graph_->Reset();
115 }
116 
117 bool ParallelNet::handleRunError() {
118  CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
119  // TODO: throw saved exceptions
120  if (run_future_->IsFailed()) {
121  LOG(ERROR) << "Failed parallel run (" << Name()
122  << "): " << run_future_->ErrorMessage();
123  }
124  return !run_future_->IsFailed();
125 }
126 
127 TaskThreadPoolBase* ParallelNet::poolGetter(
128  PoolsMap& pools,
129  int device_type,
130  int device_id,
131  int pool_size) {
132  std::unique_lock<std::mutex> pools_lock(pools_mutex_);
133  auto pool = pools[device_id][pool_size];
134  if (!pool) {
135  pool = c10::ThreadPoolRegistry()->Create(
136  DeviceTypeName(device_type),
137  device_id,
138  pool_size,
139  options_.use_per_net_pools_);
140  pools[device_id][pool_size] = pool;
141  }
142  return pool.get();
143 }
144 
145 TaskThreadPoolBase* ParallelNet::Pool(const DeviceOption& device_option) {
146  if (options_.use_single_pool_) {
147  return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
148  }
149  const auto device_type = device_option.device_type();
150  if (IsCPUDeviceType(device_type)) {
151  auto numa_node_id = -1;
152  if (device_option.has_numa_node_id()) {
153  numa_node_id = device_option.numa_node_id();
154  CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
155  }
156  CAFFE_ENFORCE_LT(
157  numa_node_id,
158  FLAGS_caffe2_net_async_max_numa_nodes,
159  "Invalid NUMA node id: ",
160  numa_node_id);
161  return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
162  } else if (IsGPUDeviceType(device_type)) {
163  auto gpu_id = device_option.device_id();
164  CAFFE_ENFORCE(
165  gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
166  "Invalid GPU id: " + caffe2::to_string(gpu_id));
167  return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
168  } else {
169  CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
170  }
171 }
172 
173 bool ParallelNet::SupportsAsync() {
174  return true;
175 }
176 
177 void ParallelNet::finishRun() {}
178 
179 std::vector<OperatorBase*> ParallelNet::GetOperators() const {
180  return operators_;
181 }
182 
183 std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
184  ExecutorHelper* helper,
185  const ExecutionOptions& options) {
186  return std::make_shared<AsyncTaskGraph>(helper, options);
187 }
188 
189 C10_DEFINE_SHARED_REGISTRY(
190  TaskGraphRegistry,
191  AsyncTaskGraphBase,
192  ExecutorHelper*,
193  const ExecutionOptions&);
194 
195 C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
196 
197 REGISTER_NET(parallel, ParallelNet);
198 
199 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13