1 #include "caffe2/core/net_parallel.h"     3 #include "caffe2/core/operator.h"     8     caffe2_task_graph_engine,
    10     "Task graph engine type used by net executor");
    14 ParallelNet::ParallelNet(
    15     const std::shared_ptr<const NetDef>& net_def,
    17     : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
    18   num_workers_ = net_def->num_workers();
    20       num_workers_, 0, 
"Expected positive number of worker threads");
    22   helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(
this);
    25   operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
    26   operators_.reserve(operator_nodes_.size());
    27   for (
const auto& node : operator_nodes_) {
    28     auto op = node.operator_.get();
    29     op->SetExecutorHelper(helper_.get());
    30     operators_.push_back(op);
    33   task_graph_ = TaskGraphRegistry()->Create(
    34       FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
    35   CAFFE_ENFORCE(task_graph_, 
"Couldn't initialize task graph");
    39   auto execution_chains = dag_utils::computeChains(operator_nodes_);
    40   std::vector<std::vector<int>> chains;
    41   chains.reserve(execution_chains.size());
    42   for (
const auto& kv : execution_chains) {
    43     chains.push_back(kv.second);
    45   auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
    46   CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
    49   for (
const auto& chain : chains) {
    50     for (
const auto& op_id : chain) {
    51       if (op_id == chain.back() || op_id == chain.front()) {
    54       auto op = operators_[op_id];
    55       if (IsCPUDeviceType(op->device_option().device_type()) &&
    64   for (
auto chain_id = 0; chain_id < chains.size(); ++chain_id) {
    65     std::vector<OperatorBase*> ops;
    66     ops.reserve(chains[chain_id].size());
    67     for (
auto op_id : chains[chain_id]) {
    68       ops.push_back(operators_[op_id]);
    70     CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
    72   for (
auto chain_id = 0; chain_id < chain_nodes.size(); ++chain_id) {
    73     if (!chain_nodes[chain_id].parents_.empty()) {
    75           task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
    80   task_graph_->FreezeGraph();
    81   run_future_ = task_graph_->GetFuture();
    82   run_future_->SetCallback([
this](
const AsyncTaskFuture* ) {
    87   LOG(INFO) << 
"Initialized parallel net: '" << Name()
    88             << 
"', #ops: " << net_def->op_size()
    89             << 
", #chains: " << chains.size() << 
", #workers: " << num_workers_
    90             << 
", dfs scheduling: " << options_.use_dfs_scheduling_
    91             << 
", task graph engine: " << FLAGS_caffe2_task_graph_engine;
    94 bool ParallelNet::RunAsync() {
    99     task_graph_->ExecuteGraph();
   100   } 
catch (
const std::exception&) {
   108 void ParallelNet::Wait() {
   109   CAFFE_ENFORCE(run_future_);
   113 void ParallelNet::reset() {
   114   task_graph_->Reset();
   117 bool ParallelNet::handleRunError() {
   118   CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
   120   if (run_future_->IsFailed()) {
   121     LOG(ERROR) << 
"Failed parallel run (" << Name()
   122                << 
"): " << run_future_->ErrorMessage();
   124   return !run_future_->IsFailed();
   127 TaskThreadPoolBase* ParallelNet::poolGetter(
   132   std::unique_lock<std::mutex> pools_lock(pools_mutex_);
   133   auto pool = pools[device_id][pool_size];
   135     pool = c10::ThreadPoolRegistry()->Create(
   136         DeviceTypeName(device_type),
   139         options_.use_per_net_pools_);
   140     pools[device_id][pool_size] = pool;
   145 TaskThreadPoolBase* ParallelNet::Pool(
const DeviceOption& device_option) {
   146   if (options_.use_single_pool_) {
   147     return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
   149   const auto device_type = device_option.device_type();
   150   if (IsCPUDeviceType(device_type)) {
   151     auto numa_node_id = -1;
   152     if (device_option.has_numa_node_id()) {
   153       numa_node_id = device_option.numa_node_id();
   154       CAFFE_ENFORCE_GE(numa_node_id, 0, 
"Invalid NUMA node id: ", numa_node_id);
   158         FLAGS_caffe2_net_async_max_numa_nodes,
   159         "Invalid NUMA node id: ",
   161     return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
   162   } 
else if (IsGPUDeviceType(device_type)) {
   163     auto gpu_id = device_option.device_id();
   165         gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
   166         "Invalid GPU id: " + caffe2::to_string(gpu_id));
   167     return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
   169     CAFFE_THROW(
"Unsupported device type " + caffe2::to_string(device_type));
   173 bool ParallelNet::SupportsAsync() {
   177 void ParallelNet::finishRun() {}
   179 std::vector<OperatorBase*> ParallelNet::GetOperators()
 const {
   183 std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
   184     ExecutorHelper* helper,
   185     const ExecutionOptions& options) {
   186   return std::make_shared<AsyncTaskGraph>(helper, options);
   189 C10_DEFINE_SHARED_REGISTRY(
   193     const ExecutionOptions&);
   195 C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
   197 REGISTER_NET(parallel, ParallelNet);
// NOTE(review): stray fragment carried over from unrelated caffe2 module
// documentation ("A global dictionary that holds information about what
// Caffe2 modules have been loaded in the current runtime"); it does not
// belong to this translation unit — confirm and remove.