1 #ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H 2 #define CAFFE2_NET_ASYNC_TASK_GRAPH_H 4 #include "caffe2/core/net_async_base.h" 5 #include "caffe2/core/net_async_task.h" 6 #include "caffe2/core/net_async_task_future.h" 7 #include "caffe2/core/operator.h" 22 virtual bool CreateNode(
24 const std::vector<OperatorBase*>& ops) = 0;
26 virtual bool AddDependency(
28 const std::vector<int>& parent_node_ids) = 0;
// NOTE(review): the leading numeric tokens on the lines of this header
// (e.g. "30 ") appear to be source line numbers fused in during
// extraction, and several original lines are missing; this text is not
// valid C++ as-is and should be restored from the upstream header.
//
// Pure-virtual (= 0): every concrete graph must override this. Takes no
// arguments and returns nothing. Presumably called once all nodes and
// dependencies have been added, to finalize the graph's structure —
// TODO confirm against the implementation.
30 virtual void FreezeGraph() = 0;
// Pure-virtual (= 0): every concrete graph must override this. Takes no
// arguments and returns nothing. The name suggests it returns the graph to
// a state where it can be run again — assumption, confirm against the
// implementation.
36 virtual void Reset() = 0;
// Override of the base-class CreateNode hook.
//   node_id - integer identifier for the node being created.
//   ops     - operators to associate with that node, passed as raw
//             OperatorBase* (ownership is not shown here; presumably the
//             caller retains it — verify).
// Returns a bool; presumably true on success — confirm in the .cc file.
45 bool CreateNode(
int node_id,
const std::vector<OperatorBase*>& ops)
override;
47 bool AddDependency(
int child_node_id,
const std::vector<int>& parent_node_ids)
// Override of the base-class FreezeGraph hook. Takes no arguments and
// returns nothing; presumably finalizes the node/dependency structure
// before execution — TODO confirm against the implementation.
50 void FreezeGraph()
override;
// Override of the base-class Reset hook. Takes no arguments and returns
// nothing; presumably prepares the graph for another run — confirm
// against the implementation.
56 void Reset()
override;
// Node storage: maps an integer node id to the AsyncTask for that node.
// The unique_ptr means this object owns the tasks.
66 std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
// Adjacency sets keyed by node id. Presumably parents_[id] holds the ids
// this node depends on and children_[id] the ids that depend on it
// (populated by AddDependency) — verify against the implementation.
67 std::unordered_map<int, std::unordered_set<int>> parents_;
68 std::unordered_map<int, std::unordered_set<int>> children_;
// Owned AsyncTaskFuture objects; presumably one per dependency edge —
// TODO confirm.
69 std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
// Non-owning AsyncTask pointers. Presumably the tasks with no parents,
// pointing into nodes_ — confirm.
71 std::vector<AsyncTask*> root_tasks_;
// Owned AsyncTaskFuture; presumably represents completion of a graph run —
// confirm against the implementation.
73 std::unique_ptr<AsyncTaskFuture> run_future_;
78 #endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...