6 #include <torch/csrc/WindowsTorchApiMacro.h> 7 #include <torch/csrc/autograd/function.h> 8 #include <torch/csrc/autograd/input_buffer.h> 9 #include <torch/csrc/autograd/anomaly_mode.h> 15 #include <unordered_map> 19 namespace torch {
namespace autograd {
25 namespace torch {
namespace autograd {
30 static Engine& get_default_engine();
35 using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>,
InputBuffer>>;
36 using dependencies_type = std::unordered_map<Function*, int>;
40 virtual variable_list execute(
41 const edge_list& roots,
42 const variable_list& inputs,
45 const edge_list& outputs = {});
46 virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
50 void queue_callback(std::function<
void()> callback);
52 bool is_checkpoint_valid();
55 void compute_dependencies(Function* root,
GraphTask& task);
58 ReadyQueue& ready_queue_by_index(
int device_index);
60 virtual void thread_init(
int device);
61 virtual void thread_main(
GraphTask *graph_task);
62 virtual void thread_on_exception(
FunctionTask& task, std::exception& e);
64 std::once_flag start_threads_flag;
65 std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
66 std::vector<std::function<void()>> final_callbacks;
67 std::mutex post_callbacks_lock;
71 using EngineStub =
Engine& (*)();
72 TORCH_API
void set_default_engine_stub(EngineStub stub);
Represents a a compute device on which a tensor is located.