Caffe2 - C++ API
A deep learning, cross platform ML framework
graph_executor.h
1 #pragma once
2 
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/variable_tensor_list.h>
#include <memory>
#include <unordered_map>
9 
10 namespace torch {
11 namespace jit {
12 
13 struct GraphExecutorState;
14 
// Note that these structs do not manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
19  Code* code = nullptr;
20  const Graph* graph = nullptr;
21 };
22 
24  const Graph* graph = nullptr;
25  ExecutionPlanState fallback; // XXX: members of this field are optional
26  std::unordered_map<ArgumentSpec, ExecutionPlanState> execution_plans;
27 };
28 
29 struct GraphExecutorImpl;
30 struct TORCH_API GraphExecutor {
31  GraphExecutor() = default;
32  GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
33  void run(Stack& inputs);
34  explicit operator bool() const {
35  return pImpl != nullptr;
36  }
37  std::shared_ptr<Graph> graph() const;
38  std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
39  GraphExecutorState getDebugState();
40  void debugDisableAutodiffSubgraphInlining();
41 
42  private:
43  std::shared_ptr<GraphExecutorImpl> pImpl;
44 };
45 
46 // These passes need to run before it is valid to pass to the interpreter
47 // regardless of whether sizes have been specialized or not.
48 TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
49 
50 namespace detail {
51 
52 GraphExecutor* getGradExecutor(Operation& op);
53 
54 } // namespace detail
55 
56 } // namespace jit
57 } // namespace torch
Definition: jit_type.h:17