1 #include <torch/csrc/autograd/engine.h> 3 #include <torch/csrc/autograd/function.h> 4 #include <torch/csrc/autograd/functions/basic_ops.h> 5 #include <torch/csrc/autograd/grad_mode.h> 6 #include <torch/csrc/autograd/anomaly_mode.h> 7 #include <torch/csrc/autograd/variable.h> 8 #include <torch/csrc/utils/memory.h> 10 #include <ATen/DeviceGuard.h> 11 #include <ATen/ExpandUtils.h> 12 #include <c10/util/Exception.h> 15 #include <condition_variable> 24 #include <unordered_set> 30 namespace torch {
namespace autograd {
// Sentinel device index meaning "no device": used both for GraphTask::owner
// (graph launched from a non-worker thread) and for worker_device before
// thread_init() runs.
33 static constexpr
int NO_DEVICE = -2;
// Device index serviced by the current autograd worker thread; set in
// Engine::thread_init (-1 denotes the CPU worker, see start_threads).
39 static thread_local
int worker_device = NO_DEVICE;
// Per-thread flag consulted by Engine::is_checkpoint_valid(); cleared in
// call_function while executing a graph that uses selective execution
// (exec_info non-empty), since checkpointing is not valid there.
44 static thread_local
bool checkpoint_valid =
true;
// NOTE(review): fragment of the FunctionTask struct — the surrounding
// declaration lines are missing from this excerpt.
// The function to execute for this task.
55 std::shared_ptr<Function> fn;
// Constructor initializer fragment: moves the accumulated gradients in.
64 , inputs(std::move(inputs)) {}
// Comparator fragment (ready-queue priority heap): order tasks by the
// sequence number of their function.
76 return t1.fn->sequence_nr() < t2.fn->sequence_nr();
// ReadyQueue member fragment: signalled by push() when a task arrives.
83 std::condition_variable not_empty;
// GraphTask member fragments (struct header missing from this excerpt):
// first exception raised while executing this graph, if any.
129 std::exception_ptr exception;
// Set once an exception is recorded (see thread_on_exception); checked
// before running further tasks.
132 std::atomic_bool has_error;
// Number of tasks still queued or running for this graph.
133 std::atomic<uint64_t> outstanding_tasks;
// Signalled when outstanding_tasks reaches zero (see Engine::execute).
140 std::condition_variable not_done;
// Partial input buffers for functions whose gradients are not all in yet.
141 std::unordered_map<Function*, InputBuffer> not_ready;
// Remaining inbound-edge count per function before it becomes ready.
142 std::unordered_map<Function*, int> dependencies;
// Records that input `input_idx` of this function must be copied into
// slot `output_idx` of GraphTask::captured_vars.
146 Capture(
int input_idx,
int output_idx) : input_idx(input_idx), output_idx(output_idx) {}
// A function executes if it lies on a path to a requested output
// (`needed`) or if any of its inputs must be captured.
151 bool should_execute()
const {
152 return needed || captures;
// Captures requested for this function's inputs; null when none.
156 std::unique_ptr<std::vector<Capture>> captures;
// Per-function execution metadata; an empty map means "execute all".
163 std::unordered_map<Function*, ExecInfo> exec_info;
// Gradients captured for the user-requested outputs, by output index.
164 std::vector<Variable> captured_vars;
166 void init_to_execute(Function& graph_root,
const edge_list& outputs);
// Checkpointing is only valid when the whole graph executes (no
// selective exec_info filtering).
172 bool can_checkpoint() {
173 return exec_info.empty();
// GraphTask constructor fragment — NOTE(review): the initializer for
// has_error (and possibly other members) is missing from this excerpt.
176 GraphTask(
bool keep_graph,
bool grad_mode)
178 , outstanding_tasks(0)
179 , keep_graph(keep_graph)
180 , grad_mode(grad_mode)
181 , owner(NO_DEVICE) {}
// ReadyQueue::push fragment (signature missing from this excerpt):
// enqueue under the queue lock and wake one waiting worker.
186 std::lock_guard<std::mutex> lock(mutex);
// Account for the new task before any worker can observe it.
187 ++item.base->outstanding_tasks;
188 heap.push(std::move(item));
190 not_empty.notify_one();
// ReadyQueue::pop fragment: block until a task is available, then take
// the highest-priority task off the heap.
194 std::unique_lock<std::mutex> lock(mutex);
195 not_empty.wait(lock, [
this]{
return !heap.empty(); });
// const_cast is required because priority_queue::top() returns const&;
// the element is immediately popped, so the move is safe.
197 auto task = std::move(const_cast<FunctionTask&>(heap.top())); heap.pop();
// Engine has no construction-time work; worker threads and ready queues
// are created lazily in start_threads() (via std::call_once in execute).
201 Engine::Engine() =
default;
204 Engine::~Engine() =
default;
// Per-worker-thread setup: set this thread's current device to `device`
// for every registered backend, record it in worker_device, then enter
// the worker loop. device == -1 is the CPU worker (see start_threads).
206 auto Engine::thread_init(
int device) ->
void {
// NOTE(review): the declaration of `guards` is missing from this
// excerpt; it appears to be an array sized by the number of device types.
239 static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
242 for (
size_t i = 0; i < static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); i++) {
243 auto* impl = c10::impl::device_guard_impl_registry[i].load();
// Only pin backends that are registered and have at least device+1 devices.
244 if (impl && device < impl->deviceCount()) {
245 guards[i].reset_device(
at::Device(static_cast<c10::DeviceType>(i), device));
249 worker_device = device;
// nullptr graph_task => run as a standing worker, looping forever.
250 thread_main(
nullptr);
// Worker loop. With graph_task == nullptr this thread services its ready
// queue forever; with a non-null graph_task (a worker thread re-entering
// the engine from Engine::execute) it runs only until that graph's
// outstanding work drains.
267 auto Engine::thread_main(
GraphTask *graph_task) ->
void {
// Slot 0 is the CPU queue; device i uses slot i + 1.
268 auto queue = ready_queues[worker_device + 1];
271 while (!graph_task || graph_task->outstanding_tasks > 0) {
// NOTE(review): the pop of `task` from `queue` is missing from this
// excerpt; `task` below comes from those missing lines.
// Skip execution once an error has been recorded for this graph.
273 if (task.fn && !task.base->has_error.load()) {
// Restore the grad mode the backward pass was launched with.
274 GradMode::set_enabled(task.base->grad_mode);
276 evaluate_function(task);
277 }
catch (std::exception& e) {
// Record the first failure so execute() can rethrow it.
278 thread_on_exception(task, e);
// Completion accounting depends on which thread owns the graph task.
284 auto base_owner = task.base->owner;
// No owner: launched from a non-worker thread; wake it when the last
// task finishes.
286 if (base_owner == NO_DEVICE) {
287 if (--task.base->outstanding_tasks == 0) {
288 std::lock_guard<std::mutex> lock(task.base->mutex);
289 task.base->not_done.notify_all();
// Owned by this very thread: the loop condition above observes zero.
294 if (base_owner == worker_device) {
295 --task.base->outstanding_tasks;
300 }
else if (base_owner != worker_device) {
// Owned by another worker: publish the decrement; the wake-up of that
// worker is in lines missing from this excerpt.
301 if (--task.base->outstanding_tasks == 0) {
303 std::atomic_thread_fence(std::memory_order_release);
311 auto Engine::thread_on_exception(
FunctionTask& task, std::exception& e) ->
void {
312 std::lock_guard<std::mutex> lock(task.base->mutex);
313 if (!task.base->has_error.load()) {
314 if (AnomalyMode::is_enabled()) {
315 task.fn->metadata()->print_stack();
317 task.base->exception = std::current_exception();
318 task.base->has_error =
true;
322 static variable_list call_pre_hooks(Function& fn, variable_list inputs) {
323 for (
const auto& hook : fn.pre_hooks()) {
324 inputs = (*hook)(inputs);
329 static variable_list call_post_hooks(Function& fn, variable_list outputs,
const variable_list& inputs) {
330 for (
const auto& hook : fn.post_hooks()) {
331 outputs = (*hook)(outputs, inputs);
336 static bool is_compatible_type(
const at::Type& expected,
const at::Type& actual) {
339 return expected == actual || (actual.is_sparse() &&
340 expected == actual.toBackend(toDense(actual.backend())));
// Check each gradient in `grads` against the metadata of the matching
// edge: count, shape (summing where broadcast-compatible), type, and
// device. `format_error` decorates the message on failure.
// NOTE(review): the `template <typename F>` header for this function is
// missing from this excerpt.
344 static void validate_outputs(
const edge_list& edges, variable_list& grads,
const F& format_error) {
345 if (grads.size() != edges.size()) {
346 std::stringstream ss;
347 ss <<
"invalid number of gradients - expected ";
348 ss << edges.size() <<
", but got " << grads.size();
349 AT_ERROR(format_error(ss.str()));
351 for (
size_t i = 0; i < grads.size(); i++) {
352 const auto& edge = edges[i];
// Edges not connected to a function need no gradient check.
353 if (!edge.is_valid())
continue;
355 const auto& metadata = edge.function->input_metadata(edge.input_nr);
356 const auto& output = grads[i];
// NOTE(review): the body handling undefined gradients is missing from
// this excerpt.
357 if (!output.defined()) {
// Shape check: a mismatched gradient is allowed only if it can be
// reduced (summed) down to the expected shape.
364 if (!grads[i].sizes().equals(metadata.shape())) {
365 if (!at::is_expandable_to(metadata.shape(), grads[i].sizes())) {
366 std::stringstream ss;
367 ss <<
"invalid gradient at index " << i <<
" - got ";
368 ss << grads[i].sizes() <<
" but expected shape compatible with ";
369 ss << metadata.shape();
370 AT_ERROR(format_error(ss.str()));
372 grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
// Type check (exact match or sparse counterpart, see is_compatible_type).
374 if (!is_compatible_type(metadata.type(), grads[i].type())) {
375 std::stringstream ss;
376 ss <<
"invalid gradient at index " << i <<
" - expected type ";
377 ss << metadata.type() <<
" but got " << grads[i].type();
378 AT_ERROR(format_error(ss.str()));
// Device check: the gradient must live on the input's device.
380 auto output_device = output.device();
381 if (output_device != metadata.device()) {
382 std::stringstream ss;
383 ss <<
"invalid gradient at index " << i <<
" - expected device ";
384 ss << metadata.device() <<
" but got " << output_device;
385 AT_ERROR(format_error(ss.str()));
// Run one autograd function: pre-hooks, the function itself, output
// validation, then post-hooks. Also maintains the per-thread
// checkpoint_valid flag for the duration of the call.
390 static variable_list call_function(
FunctionTask& task) {
// Checkpointing is invalid while executing a selectively-executed graph;
// combine with the caller's state so nesting behaves correctly.
391 bool prev_checkpoint_valid_state = checkpoint_valid;
392 checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
// NOTE(review): the declaration `auto& fn = *task.fn;` appears to be
// missing from this excerpt; `fn` below refers to the task's function.
394 auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs)));
// When the graph will not be reused, let the function know it may free
// its saved variables after this call.
396 if(!task.base->keep_graph) {
397 fn.will_release_variables();
400 const auto has_post_hooks = !fn.post_hooks().empty();
401 variable_list outputs;
// With post-hooks present, keep `inputs` alive by running the function
// on a copy; otherwise move the inputs straight in.
416 auto inputs_copy = inputs;
417 outputs = fn(std::move(inputs_copy));
419 outputs = fn(std::move(inputs));
// Validate gradient count/shape/type/device for every outgoing edge.
422 validate_outputs(fn.next_edges(), outputs, [&](
const std::string& msg) {
423 std::ostringstream ss;
424 ss <<
"Function " << fn.name() <<
" returned an " << msg;
// Restore the caller's checkpoint state before returning.
427 checkpoint_valid = prev_checkpoint_valid_state;
431 return call_post_hooks(fn, std::move(outputs), inputs);
// Execute one ready task: capture requested inputs, run the function
// (unless filtered out by exec_info), then route each output gradient to
// its consumer — queueing consumers whose dependency count reaches zero.
436 auto Engine::evaluate_function(
FunctionTask& task) ->
void {
// Selective execution: when exec_info is non-empty, only flagged
// functions run, and requested inputs are copied into captured_vars.
438 auto & exec_info = task.base->exec_info;
439 if (!exec_info.empty()) {
440 auto & fn_info = exec_info.at(task.fn.get());
441 if (
auto *capture_vec = fn_info.captures.get()) {
442 std::lock_guard<std::mutex> lock(task.base->mutex);
443 for (
auto capture : *capture_vec) {
444 task.base->captured_vars[capture.output_idx] = task.inputs[capture.input_idx];
// Not on a path to any requested output: skip execution entirely.
447 if (!fn_info.needed)
return;
450 auto outputs = call_function(task);
// NOTE(review): the declaration `auto& fn = *task.fn;` appears to be
// missing from this excerpt; `fn` below refers to the task's function.
// Free saved variables eagerly when the graph will not be reused.
453 if (!task.base->keep_graph) {
454 fn.release_variables();
457 int num_outputs = outputs.size();
458 if (num_outputs == 0)
return;
// Anomaly mode: scan every output for NaNs and fail loudly
// (x != x is true only for NaN).
460 if (AnomalyMode::is_enabled()) {
462 for (
int i = 0; i < num_outputs; ++i) {
463 auto& output = outputs[i];
465 if (output.defined() && output.ne(output).any().item<uint8_t>()) {
466 std::stringstream ss;
467 ss <<
"Function '" << fn.name() <<
"' returned nan values in its " << i <<
"th output.";
468 throw std::runtime_error(ss.str());
// Route outputs to their consumers under the graph-task lock.
473 std::lock_guard<std::mutex> lock(task.base->mutex);
474 for (
int i = 0; i < num_outputs; ++i) {
475 auto& output = outputs[i];
476 const auto& next = fn.next_edge(i);
478 if (!next.is_valid())
continue;
// Decrement the consumer's dependency count; it becomes ready at zero.
481 bool is_ready =
false;
482 auto& dependencies = task.base->dependencies;
483 auto it = dependencies.find(next.function.get());
484 if (it == dependencies.end()) {
485 auto name = next.function->name();
486 throw std::runtime_error(std::string(
"dependency not found for ") + name);
487 }
else if (--it->second == 0) {
488 dependencies.erase(it);
// NOTE(review): `is_ready = true;` appears to be missing from this
// excerpt after the erase above.
// First gradient for this consumer creates a fresh InputBuffer;
// otherwise accumulate into the existing partial buffer.
492 auto& not_ready = task.base->not_ready;
493 auto not_ready_it = not_ready.find(next.function.get());
494 if (not_ready_it == not_ready.end()) {
// Skip consumers filtered out by selective execution.
496 if (!exec_info.empty()) {
497 auto it = exec_info.find(next.function.get());
498 if (it == exec_info.end() || !it->second.should_execute()) {
503 InputBuffer input_buffer(next.function->num_inputs());
504 input_buffer.add(next.input_nr, std::move(output));
// Ready: dispatch to the queue of the device the gradient lives on.
506 auto& queue = ready_queue(input_buffer.device());
507 queue.push(
FunctionTask(task.base, next.function, std::move(input_buffer)));
509 not_ready.emplace(next.function.get(), std::move(input_buffer));
// Existing partial buffer: accumulate, and dispatch if now ready.
513 auto &input_buffer = not_ready_it->second;
514 input_buffer.add(next.input_nr, std::move(output));
516 auto& queue = ready_queue(input_buffer.device());
517 queue.push(
FunctionTask(task.base, next.function, std::move(input_buffer)));
518 not_ready.erase(not_ready_it);
525 auto Engine::compute_dependencies(Function* root,
GraphTask& task) ->
void {
527 std::unordered_set<Function*> seen;
528 std::vector<Function*> queue { root };
532 auto& dependencies = task.dependencies;
533 while (!queue.empty()) {
534 auto fn = queue.back(); queue.pop_back();
535 for (
const auto& edge : fn->next_edges()) {
536 if (
auto next_ptr = edge.function.get()) {
537 dependencies[next_ptr] += 1;
538 const bool was_inserted = seen.insert(next_ptr).second;
539 if (was_inserted) queue.push_back(next_ptr);
// ClearCallbacks fragments (struct header and destructor missing from
// this excerpt): an RAII helper holding references to the engine's
// callback list and its lock; it calls clear() on construction and
// presumably again on destruction — TODO confirm against full source.
547 std::mutex &callbacks_lock)
548 : callbacks(callbacks)
549 , callbacks_lock(callbacks_lock) { clear(); }
// clear() fragment: take the lock before mutating the callback list.
553 std::lock_guard<std::mutex> lock(callbacks_lock);
// Non-owning references to the engine's callback state.
557 std::vector<std::function<void()>>& callbacks;
558 std::mutex& callbacks_lock;
// Entry point for a backward pass: validate the seed gradients, build a
// GraphTask, seed execution from a GraphRoot, wait for completion (or
// help execute when called re-entrantly from a worker thread), rethrow
// any captured error, run queued callbacks, and return captured grads.
561 auto Engine::execute(
const edge_list& roots,
562 const variable_list& inputs,
565 const edge_list& outputs) -> variable_list {
// Lazily spawn the worker threads exactly once per process.
566 std::call_once(start_threads_flag, &Engine::start_threads,
this);
// Check the seed gradients against the root edges before doing work.
569 validate_outputs(roots, const_cast<variable_list&>(inputs), [](
const std::string& msg) {
576 GraphTask graph_task(keep_graph, create_graph);
577 std::unique_lock<std::mutex> lock(graph_task.mutex);
580 auto graph_root = std::make_shared<GraphRoot>(roots, inputs);
581 compute_dependencies(graph_root.get(), graph_task);
// With explicit outputs, mark only the needed subgraph for execution.
582 if (!outputs.empty()) {
583 graph_task.init_to_execute(*graph_root, outputs);
// Non-worker thread: enqueue the root (in lines missing from this
// excerpt) and block until every task completes.
588 if (worker_device == NO_DEVICE) {
590 graph_task.not_done.wait(lock, [&graph_task]{
591 return graph_task.outstanding_tasks.load() == 0;
// Re-entrant call from a worker thread: own the task and help execute
// it inline instead of blocking the worker.
597 graph_task.owner = worker_device;
599 thread_main(&graph_task);
// Propagate the first exception captured by thread_on_exception.
603 if (graph_task.has_error.load()) {
604 std::rethrow_exception(graph_task.exception);
607 if (!graph_task.not_ready.empty()) {
608 throw std::runtime_error(
"could not compute gradients for some functions");
// Run callbacks registered via queue_callback during this pass.
614 std::unique_lock<std::mutex> cb_lock(post_callbacks_lock);
618 for (
size_t i = 0; i < final_callbacks.size(); ++i) {
620 final_callbacks[i]();
624 return graph_task.captured_vars;
// Default-engine plumbing: engine_stub holds a factory returning the
// Engine to use; it defaults to get_base_engine and can be overridden
// via set_default_engine_stub.
// NOTE(review): the body of get_base_engine is missing from this excerpt.
630 static Engine& get_base_engine() {
635 std::atomic<EngineStub> engine_stub(get_base_engine);
637 void set_default_engine_stub(EngineStub stub) {
638 engine_stub.store(stub);
// get_default_engine fragment: dispatch through the installed stub.
643 return engine_stub.load()();
646 void Engine::queue_callback(std::function<
void()> callback) {
647 std::lock_guard<std::mutex> lock(post_callbacks_lock);
648 final_callbacks.emplace_back(std::move(callback));
651 bool Engine::is_checkpoint_valid() {
652 return checkpoint_valid;
// Engine::ready_queue fragment (signature line missing from this
// excerpt): map a device to its queue — slot 0 for CPU, slot
// index + 1 for device queues.
657 if (device.
type() == at::kCPU) {
658 return *ready_queues.at(0);
660 return *ready_queues.at(device.
index() + 1);
667 auto Engine::ready_queue_by_index(
int device_index) ->
ReadyQueue& {
668 return *ready_queues.at(device_index + 1);
// Spawn one worker thread per detected device plus one CPU thread.
// Invoked exactly once via std::call_once from Engine::execute.
671 auto Engine::start_threads() ->
void {
// NOTE(review): the declaration/initialization of num_devices is
// missing from this excerpt.
// Take the max device count over all registered backends.
674 for (
const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
675 auto* impl = impl_atomic.load();
677 num_devices = std::max(num_devices, impl->deviceCount());
// One queue per device plus the CPU queue at slot 0.
683 int num_threads = num_devices + 1;
684 ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads);
685 for (
auto& queue : ready_queues)
687 for (
int i = 0; i < num_threads; ++i) {
// Thread 0 gets device -1 (CPU); thread i services device i - 1.
688 std::thread t(&Engine::thread_init,
this, i - 1);
// Mark which functions must run to produce the requested `outputs`: the
// root is always needed; each output edge gets a Capture slot; then an
// iterative DFS propagates `needed` bottom-up — a function is needed iff
// any child must execute.
693 void GraphTask::init_to_execute(Function& graph_root,
const edge_list& outputs) {
694 exec_info[&graph_root].needed =
true;
// NOTE(review): the declaration of output_idx is missing from this
// excerpt.
697 for (
auto & output_edge : outputs) {
698 Function *output = output_edge.function.get();
699 auto & info = exec_info[output];
701 info.captures = make_unique<std::vector<ExecInfo::Capture>>();
702 info.captures->emplace_back(output_edge.input_nr, output_idx++);
// One capture slot per requested output, in request order.
704 captured_vars.resize(output_idx);
// DFS frame: a function plus the index of its next unvisited edge.
714 Frame (Function *fn) : fn(fn), next_next_fn(0) {}
// Return the next not-yet-visited child function, or (in lines missing
// from this excerpt) null when exhausted.
718 Function* get_next_fn() {
719 const auto & next = fn->next_edges();
720 auto num_next = next.size();
721 while (next_next_fn < num_next) {
// Note: this local `fn` shadows the member Frame::fn.
722 auto fn = next[next_next_fn++].function.get();
// Iterative post-order DFS from each root edge; `seen` deduplicates.
728 std::vector<Frame> stack;
729 std::unordered_set<Function*> seen;
730 for (
const auto & input : graph_root.next_edges()) {
731 if (seen.count(input.function.get()) > 0)
continue;
732 stack.emplace_back(input.function.get());
733 while (!stack.empty()) {
734 auto &frame = stack.back();
735 if (Function *next_fn = frame.get_next_fn()) {
736 if ( seen.emplace(next_fn).second) {
737 stack.emplace_back(next_fn);
// Frame exhausted: decide `needed` from the children's results.
744 const auto & next_edges = frame.fn->next_edges();
745 const bool needed = std::any_of(
746 next_edges.begin(), next_edges.end(), [&](
const Edge& edge) {
747 auto it = exec_info.find(edge.function.get());
748 return it != exec_info.end() && it->second.should_execute();
750 exec_info[frame.fn].needed = needed;
optional< Device > device_of(Tensor t)
Returns the Device of a Tensor, if the Tensor is defined.
Represents a particular input of a function.
static Engine & get_default_engine()
Returns a reference to a static Engine instance.
Represents a compute device on which a tensor is located.
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
A OptionalDeviceGuard is an RAII class that sets a device to some value on initialization, and resets the device to its original value on destruction.
DeviceIndex index() const noexcept
Returns the optional index.
DeviceType type() const noexcept
Returns the type of device this is.