// NOTE(review): this copy of the file appears to have been scraped with the
// original line numbers fused into the code and several original lines
// dropped; it will not compile as-is. Comments added in this file describe
// only what is visible.
//
// The fragment below is (presumably) the definition of the string flag
// `caffe2_task_graph_engine` ("Task graph engine type used by net executor");
// the macro opening and default value are not visible here. Its value,
// FLAGS_caffe2_task_graph_engine, is the key passed to
// TaskGraphRegistry()->Create in the ParallelNet constructor.
1 #include "caffe2/core/net_parallel.h" 3 #include "caffe2/core/operator.h" 8 caffe2_task_graph_engine,
10 "Task graph engine type used by net executor");
// Constructs a ParallelNet from a NetDef:
//  - reads the worker count from the NetDef and checks it is positive;
//  - installs a shared ParallelNetExecutorHelper on every operator;
//  - creates the task graph from TaskGraphRegistry using the
//    caffe2_task_graph_engine flag;
//  - groups operators into chains (dag_utils::computeChains), creates one
//    task-graph node per chain, adds dependencies between chains, and
//    freezes the graph;
//  - obtains the run future from the task graph and attaches a callback.
// NOTE(review): several original lines of this definition (the remaining
// parameter(s) such as `ws`, enforce-macro openings, some loop bodies, the
// callback body, closing braces) are missing from this copy.
14 ParallelNet::ParallelNet(
15 const std::shared_ptr<const NetDef>& net_def,
17 : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
// Worker count comes from the NetDef; the following check (macro name not
// visible) rejects a non-positive value.
18 num_workers_ = net_def->num_workers();
20 num_workers_, 0,
"Expected positive number of worker threads");
// A single helper object is created for this net and handed (as a raw
// pointer) to every operator and to the task graph below.
22 helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(
this);
// operators_ holds non-owning pointers into operator_nodes_, in node order.
25 operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
26 operators_.reserve(operator_nodes_.size());
27 for (
const auto& node : operator_nodes_) {
28 auto op = node.operator_.get();
29 op->SetExecutorHelper(helper_.get());
30 operators_.push_back(op);
// The task-graph implementation is chosen at runtime by registry key.
33 task_graph_ = TaskGraphRegistry()->Create(
34 FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
35 CAFFE_ENFORCE(task_graph_,
"Couldn't initialize task graph");
// Flatten the computed chains into a vector of op-id lists; a chain's
// index in `chains` is used below as its task-graph node id.
39 auto execution_chains = dag_utils::computeChains(operator_nodes_);
40 std::vector<std::vector<int>> chains;
41 chains.reserve(execution_chains.size());
42 for (
const auto& kv : execution_chains) {
43 chains.push_back(kv.second);
// One graph node per chain is expected.
45 auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
46 CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
// Per-op pass over every chain. The first/last op of a chain and CPU-device
// ops matching a further condition are singled out; the rest of that
// condition and the action taken are on lines missing from this copy.
49 for (
const auto& chain : chains) {
50 for (
const auto& op_id : chain) {
51 if (op_id == chain.back() || op_id == chain.front()) {
54 auto op = operators_[op_id];
55 if (IsCPUDeviceType(op->device_option().device_type()) &&
// Register each chain as a task-graph node (id = chain index) holding that
// chain's operators in order; node creation must succeed.
// NOTE(review): `auto chain_id = 0` deduces int and is compared against a
// size_t; consider size_t to avoid the signed/unsigned comparison.
64 for (
auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
65 std::vector<OperatorBase*> ops;
66 ops.reserve(chains[chain_id].size());
67 for (
auto op_id : chains[chain_id]) {
68 ops.push_back(operators_[op_id]);
70 CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
// Add dependencies from each chain on its parent chains (only for chains
// that have parents). The enclosing check's opening line is not visible.
// NOTE(review): same int vs size_t comparison as the loop above.
72 for (
auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
73 if (!chain_nodes[chain_id].parents_.empty()) {
75 task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
// The graph is frozen after construction; the run future is fetched once
// and a callback is attached (callback body not visible in this copy).
80 task_graph_->FreezeGraph();
81 run_future_ = task_graph_->GetFuture();
82 run_future_->SetCallback([
this](
const AsyncTaskFuture* ) {
// Summary of the constructed net's configuration.
87 LOG(INFO) <<
"Initialized parallel net: '" << Name()
88 <<
"', #ops: " << net_def->op_size()
89 <<
", #chains: " << chains.size() <<
", #workers: " << num_workers_
90 <<
", dfs scheduling: " << options_.use_dfs_scheduling_
91 <<
", task graph engine: " << FLAGS_caffe2_task_graph_engine;
// Starts a run by executing the task graph inside a try block; exceptions
// derived from std::exception are caught. The lines before ExecuteGraph(),
// the handler body, and the return value are missing from this copy.
94 bool ParallelNet::RunAsync() {
99 task_graph_->ExecuteGraph();
100 }
catch (
const std::exception&) {
// Visibly, enforces that a run future exists. Presumably it then waits on
// that future; the remaining body is not visible in this copy, so confirm.
108 void ParallelNet::Wait() {
109 CAFFE_ENFORCE(run_future_);
// Resets the net's task graph by delegating to task_graph_->Reset().
// NOTE(review): any further body lines are not visible in this copy.
113 void ParallelNet::reset() {
114 task_graph_->Reset();
// Inspects the outcome of a finished run.
// Requires the run future to exist and to be completed (enforced).
// If the future reports failure, logs the net name and the future's error
// message at ERROR level.
// Returns true iff the run did not fail.
117 bool ParallelNet::handleRunError() {
118 CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
120 if (run_future_->IsFailed()) {
121 LOG(ERROR) <<
"Failed parallel run (" << Name()
122 <<
"): " << run_future_->ErrorMessage();
124 return !run_future_->IsFailed();
// Returns a thread pool from the cache `pools`, keyed by (device_id,
// pool_size), creating one via c10::ThreadPoolRegistry and storing it in the
// cache. The cache is accessed while holding pools_mutex_.
// NOTE(review): the parameter list, the condition guarding creation
// (presumably "pool not yet created"), and the return are missing from this
// copy. If `pools` is a map-like type, operator[] on line 133 default-inserts
// an empty entry for a missing key (type not visible; confirm).
127 TaskThreadPoolBase* ParallelNet::poolGetter(
132 std::unique_lock<std::mutex> pools_lock(pools_mutex_);
133 auto pool = pools[device_id][pool_size];
// Pool creation: the registry key is the device type name; the last visible
// argument is the per-net-pools option.
135 pool = c10::ThreadPoolRegistry()->Create(
136 DeviceTypeName(device_type),
139 options_.use_per_net_pools_);
140 pools[device_id][pool_size] = pool;
// Selects the thread pool for an operator's DeviceOption. Every branch
// requests a pool of size num_workers_ via poolGetter.
//  - use_single_pool_: one CPU pool keyed by id -1, regardless of device.
//  - CPU devices: CPU pools keyed by NUMA node id (-1 when none is set).
//  - GPU devices: GPU pools keyed by device_id.
//  - any other device type: throws.
// NOTE(review): some original lines (closing braces, enforce openings) are
// missing from this copy.
145 TaskThreadPoolBase* ParallelNet::Pool(
const DeviceOption& device_option) {
146 if (options_.use_single_pool_) {
147 return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
149 const auto device_type = device_option.device_type();
150 if (IsCPUDeviceType(device_type)) {
151 auto numa_node_id = -1;
152 if (device_option.has_numa_node_id()) {
153 numa_node_id = device_option.numa_node_id();
154 CAFFE_ENFORCE_GE(numa_node_id, 0,
"Invalid NUMA node id: ", numa_node_id);
// Upper-bound check of the NUMA node id against
// FLAGS_caffe2_net_async_max_numa_nodes (macro opening not visible).
158 FLAGS_caffe2_net_async_max_numa_nodes,
159 "Invalid NUMA node id: ",
161 return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
162 }
else if (IsGPUDeviceType(device_type)) {
// GPU id must lie in [0, FLAGS_caffe2_net_async_max_gpus).
163 auto gpu_id = device_option.device_id();
165 gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
166 "Invalid GPU id: " + caffe2::to_string(gpu_id));
167 return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
// Device types that are neither CPU nor GPU are rejected.
169 CAFFE_THROW(
"Unsupported device type " + caffe2::to_string(device_type));
// Reports whether this net supports asynchronous execution.
// NOTE(review): the returned value is on lines missing from this copy.
173 bool ParallelNet::SupportsAsync() {
177 void ParallelNet::finishRun() {}
// Returns the net's operators as raw OperatorBase pointers (const member).
// NOTE(review): the body is missing from this copy. Presumably it returns
// operators_, which the constructor fills with non-owning pointers taken
// from operator_nodes_. Confirm.
179 std::vector<OperatorBase*> ParallelNet::GetOperators()
const {
183 std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
184 ExecutorHelper* helper,
185 const ExecutionOptions& options) {
186 return std::make_shared<AsyncTaskGraph>(helper, options);
// Shared-pointer registry of task-graph factories. Creators take arguments
// ending in `const ExecutionOptions&`. The registry name and remaining
// signature are on lines missing from this copy; presumably this defines
// TaskGraphRegistry, which the ParallelNet constructor queries with
// FLAGS_caffe2_task_graph_engine.
189 C10_DEFINE_SHARED_REGISTRY(
193 const ExecutionOptions&);
// GetAsyncTaskGraph is the creator for engine key `futures`.
195 C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
// Makes ParallelNet available as net type `parallel`.
197 REGISTER_NET(parallel, ParallelNet);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...