Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_task_graph.h
1 #ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
2 #define CAFFE2_NET_ASYNC_TASK_GRAPH_H
3 
4 #include "caffe2/core/net_async_base.h"
5 #include "caffe2/core/net_async_task.h"
6 #include "caffe2/core/net_async_task_future.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 // AsyncTaskGraph represents an execution of a net, it owns the tasks and
12 // associated futures, sets up future callbacks and propagates errors.
13 // Usage steps:
14 // - Adding graph nodes and edges through CreateNode/AddDependency;
15 // - Freezing the graph (FreezeGraph), after the freezing a future
16 // can be obtained using GetFuture;
17 // - Execution of the graph is scheduled through ExecuteGraph, after each
18 // execution Reset must be called to prepare the graph for the next run
19 
21  public:
22  virtual bool CreateNode(
23  int node_id,
24  const std::vector<OperatorBase*>& ops) = 0;
25 
26  virtual bool AddDependency(
27  int child_node_id,
28  const std::vector<int>& parent_node_ids) = 0;
29 
30  virtual void FreezeGraph() = 0;
31 
32  virtual AsyncTaskFuture* ExecuteGraph() = 0;
33 
34  virtual AsyncTaskFuture* GetFuture() = 0;
35 
36  virtual void Reset() = 0;
37 
38  virtual ~AsyncTaskGraphBase() noexcept {}
39 };
40 
42  public:
43  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);
44 
45  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
46 
47  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
48  override;
49 
50  void FreezeGraph() override;
51 
52  AsyncTaskFuture* ExecuteGraph() override;
53 
54  AsyncTaskFuture* GetFuture() override;
55 
56  void Reset() override;
57 
58  private:
59  // used to, e.g., get access to executor's thread pools
60  // TODO: pass tracer and counters through ExecutorHelper
61  ExecutorHelper* helper_;
62  ExecutionOptions options_;
63 
64  bool frozen_;
65 
66  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
67  std::unordered_map<int, std::unordered_set<int>> parents_;
68  std::unordered_map<int, std::unordered_set<int>> children_;
69  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
70 
71  std::vector<AsyncTask*> root_tasks_;
72 
73  std::unique_ptr<AsyncTaskFuture> run_future_;
74 };
75 
76 } // namespace caffe2
77 
78 #endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13