Caffe2 - C++ API
A deep learning, cross-platform ML framework
engine.h
#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/anomaly_mode.h>

#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch { namespace autograd {
struct ReadyQueue;
struct FunctionTask;
struct GraphTask;
}} // namespace torch::autograd

namespace torch { namespace autograd {
// A single instance of this struct should be created through the whole process lifetime.
// The worker thread creation logic and Engine's destructor rely on this.
struct TORCH_API Engine {
  static Engine& get_default_engine();

  Engine();
  virtual ~Engine();

  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
  using dependencies_type = std::unordered_map<Function*, int>;

  // Given a list of (Function, input number) pairs computes the value of the graph
  // by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      const edge_list& outputs = {});
  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return nullptr;
  }

  void queue_callback(std::function<void()> callback);

  bool is_checkpoint_valid();

protected:
  void compute_dependencies(Function* root, GraphTask& task);
  void evaluate_function(FunctionTask& task);
  ReadyQueue& ready_queue(at::Device device);
  ReadyQueue& ready_queue_by_index(int device_index);
  void start_threads();
  virtual void thread_init(int device);
  virtual void thread_main(GraphTask *graph_task);
  virtual void thread_on_exception(FunctionTask& task, std::exception& e);

  std::once_flag start_threads_flag;
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
  std::vector<std::function<void()>> final_callbacks;
  std::mutex post_callbacks_lock;
};

// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);

}} // namespace torch::autograd
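
The execute() entry point above is what torch::autograd::backward() ultimately drives. As a rough sketch of calling it directly, the code below assumes a scalar Variable `loss` produced by differentiable ops; the helpers used to build the root edge and the seed gradient (gradient_edge(), make_variable(), at::ones_like()) follow the Variable API in this codebase, but the function run_backward and its setup are illustrative only, not part of this header.

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>

using namespace torch::autograd;

// Sketch only: run backpropagation starting from a scalar output `loss`.
variable_list run_backward(const Variable& loss) {
  // A root is an Edge: the output's grad_fn plus the output index it feeds.
  edge_list roots{loss.gradient_edge()};

  // Seed gradient dL/dL = 1 for the scalar output.
  variable_list grad_outputs{make_variable(at::ones_like(loss.data()))};

  // keep_graph=false frees buffers once the pass finishes;
  // create_graph=false does not record a graph of the backward pass itself.
  return Engine::get_default_engine().execute(
      roots, grad_outputs, /*keep_graph=*/false, /*create_graph=*/false);
}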
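
The EngineStub hook at the end of the header is how the Python layer (python_engine) replaces the default engine when it loads. A hypothetical out-of-tree frontend could follow the same pattern; everything below (MyEngine, install()) is an illustrative assumption, not code from this repository.

#include <torch/csrc/autograd/engine.h>

namespace my_frontend {

// Hypothetical Engine subclass; a real frontend would override hooks such as
// thread_init() or make_anomaly_metadata() as needed.
struct MyEngine : torch::autograd::Engine {
  static torch::autograd::Engine& get_instance() {
    // Keep a single instance for the whole process, as the header requires.
    static MyEngine engine;
    return engine;
  }
};

void install() {
  // After this call, Engine::get_default_engine() resolves to MyEngine.
  torch::autograd::set_default_engine_stub(&MyEngine::get_instance);
}

} // namespace my_frontend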