engine.cpp
1 #include <torch/csrc/autograd/engine.h>
2 
3 #include <torch/csrc/autograd/function.h>
4 #include <torch/csrc/autograd/functions/basic_ops.h>
5 #include <torch/csrc/autograd/grad_mode.h>
6 #include <torch/csrc/autograd/anomaly_mode.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <torch/csrc/utils/memory.h>
9 
10 #include <ATen/DeviceGuard.h>
11 #include <ATen/ExpandUtils.h>
12 #include <c10/util/Exception.h>
13 
14 #include <atomic>
15 #include <condition_variable>
16 #include <cstdint>
17 #include <functional>
18 #include <iostream>
19 #include <memory>
20 #include <mutex>
21 #include <set>
22 #include <string>
23 #include <thread>
24 #include <unordered_set>
25 #include <typeinfo>
26 #include <sstream>
27 #include <queue>
28 #include <TH/TH.h>
29 
30 namespace torch { namespace autograd {
31 
32 // NB: -1 indicates the CPU worker!
33 static constexpr int NO_DEVICE = -2;
34 
35 // Threads spawned by the engine are assigned a constant 'worker_device'
36 // specifying what device they process work for. This variable is initialized
37 // at thread creation time and is constant afterwards. This is used when
38 // handling reentrant backwards calls; see Note [Reentrant backwards]
39 static thread_local int worker_device = NO_DEVICE;
40 
41 // This variable is true if ALL invocations in the stack of re-entrant engine
42 // invocations are imperative backwards. This special variable is needed for the
43 // gradient checkpointing feature only.
44 static thread_local bool checkpoint_valid = true;
45 
46 // XXX: Changes to the way multithreading works in execute should be done with
47 // great care. Right now the implementation guarantees that a single function's
48 // apply will never be entered concurrently (even if multiple graphs are
49 // executed at the same time). Adding multiple threads per-device or removing
50 // engine thread affinity to the device can break this invariant, and we depend
51 // on it in a few places (e.g. AccumulateGrad function).
52 
53 struct FunctionTask {
54  GraphTask* base;
55  std::shared_ptr<Function> fn;
56  // This buffer serves as an implicit "addition" node for all of the
57  // gradients flowing here. Once all the dependencies are finished, we
58  // use the contents of this buffer to run the function.
59  InputBuffer inputs;
60 
61  FunctionTask(GraphTask* base, std::shared_ptr<Function> fn, InputBuffer inputs)
62  : base(base)
63  , fn(std::move(fn))
64  , inputs(std::move(inputs)) {}
65 };
66 
67 // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
68 // Empty FunctionTasks are first.
69 struct CompareFunctionTaskTime {
70  bool operator()(FunctionTask const & t1, FunctionTask const & t2) {
71  if (!t1.fn) {
72  return false;
73  } else if (!t2.fn) {
74  return true;
75  } else {
76  return t1.fn->sequence_nr() < t2.fn->sequence_nr();
77  }
78  }
79 };
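// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): how a comparator like the one
// above behaves inside a std::priority_queue. Tasks with the highest
// sequence_nr are popped first (nodes created later in the forward pass run
// earlier in the backward pass), and "empty" tasks (null fn, used as wake-up
// dummies) jump the queue entirely. MiniTask, CompareMiniTask and the literal
// sequence numbers below are invented for the example; only the ordering rule
// mirrors CompareFunctionTaskTime.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <memory>
#include <queue>
#include <vector>

struct MiniTask {
  std::shared_ptr<uint64_t> fn;  // stands in for std::shared_ptr<Function>
  uint64_t sequence_nr() const { return *fn; }
};

struct CompareMiniTask {
  bool operator()(const MiniTask& t1, const MiniTask& t2) {
    if (!t1.fn) return false;
    if (!t2.fn) return true;
    return t1.sequence_nr() < t2.sequence_nr();
  }
};

int main() {
  std::priority_queue<MiniTask, std::vector<MiniTask>, CompareMiniTask> q;
  q.push({std::make_shared<uint64_t>(3)});
  q.push({std::make_shared<uint64_t>(7)});
  q.push({nullptr});                        // dummy wake-up task
  q.push({std::make_shared<uint64_t>(5)});

  assert(q.top().fn == nullptr); q.pop();   // empty tasks come out first
  assert(q.top().sequence_nr() == 7); q.pop();
  assert(q.top().sequence_nr() == 5); q.pop();
  assert(q.top().sequence_nr() == 3); q.pop();
  return 0;
}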
80 
81 struct ReadyQueue {
82  std::priority_queue<FunctionTask, std::vector<FunctionTask>, CompareFunctionTaskTime> heap;
83  std::condition_variable not_empty;
84  std::mutex mutex;
85 
86  void push(FunctionTask item);
87  FunctionTask pop();
88 };
89 
90 // Note [Reentrant backwards]
91 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
92 // To understand the reentrant backwards problem, we have to notice two
93 // aspects of how the autograd engine is implemented today:
94 //
95 // 1. When you call Engine::execute(), you want to block until
96 // differentiation finishes so that you can get the final result variables
97 // of the backwards pass.
98 //
99 // 2. The engine operates by having a single worker thread per work queue,
100 // and every work queue is pinned to a specific device where the
101 // operation is executed.
102 //
103 // The problem is, suppose that you call backward() inside of a worker
104 // thread. By property (1), we're supposed to block until the nested task
105 // finishes. However, by property (2), this worker thread is on the
106 // hook for processing the tasks assigned to it; we better not block,
107 // because then all of our backward executions (including the one we
108 // just started) will deadlock!
109 //
110 // Here's our cunning idea: instead of blocking, just get back to work
111 // on whatever task queue you should have been working on previously
112 // (this is saved via the thread local variable worker_device)! There are
113 // "simply" two things you have to arrange for:
114 //
115 // - We have to promptly kick ourselves out of the thread_main() loop
116 // when our graph_task completes, because we need to unblock the
117 // parent function tasks that started the reentrant execution in
118 // the first place. This is why thread_main() takes an optional
119 // graph_task as input.
120 //
121 // - When we finish a GraphTask, we have to make sure we wake up the worker
122 // thread so that it actually has a chance to exit the thread_main()
123 // loop. Thus the faffing about in thread_main() after
124 // evaluate_function() completes.
125 
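// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): a stripped-down, standalone
// model of the strategy in Note [Reentrant backwards]. Instead of blocking on
// a nested graph task, a worker keeps draining its own queue until the nested
// task's outstanding counter hits zero. MiniQueue, wait_by_working and the
// counter below are invented for the example; the real engine does this in
// thread_main() when it is handed a non-null graph_task.
// ---------------------------------------------------------------------------
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>

struct MiniQueue {
  std::queue<std::function<void()>> q;
  std::mutex m;
  std::condition_variable cv;

  void push(std::function<void()> fn) {
    { std::lock_guard<std::mutex> lock(m); q.push(std::move(fn)); }
    cv.notify_one();
  }
  std::function<void()> pop() {
    std::unique_lock<std::mutex> lock(m);
    cv.wait(lock, [&] { return !q.empty(); });
    auto fn = std::move(q.front());
    q.pop();
    return fn;
  }
};

// "Blocking" on a nested task by doing more work: keep popping and running
// tasks from this worker's own queue until the nested task is done. This is
// what keeps a backward()-inside-backward() from deadlocking the worker.
void wait_by_working(MiniQueue& my_queue, std::atomic<int>& nested_outstanding) {
  while (nested_outstanding.load() > 0) {
    auto task = my_queue.pop();
    if (task) task();  // tasks decrement nested_outstanding as they complete
  }
}

int main() {
  MiniQueue my_queue;
  std::atomic<int> nested_outstanding{1};

  // The nested "graph task" lands on this worker's own queue...
  my_queue.push([&] { --nested_outstanding; });

  // ...and the worker finishes it by working, not by blocking.
  wait_by_working(my_queue, nested_outstanding);
  return 0;
}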
126 
127 // GraphTask holds metadata needed for a single execution of backward()
128 struct GraphTask {
129  std::exception_ptr exception;
130  // Indicates if an error occurred while executing any task. When this is
131  // true, it signals all threads to stop executing.
132  std::atomic_bool has_error;
133  std::atomic<uint64_t> outstanding_tasks;
134  bool keep_graph;
135  bool grad_mode;
136 
137  std::mutex mutex;
138  // Notified when a task finishes executing. Check outstanding_tasks to see
139  // if all tasks are done.
140  std::condition_variable not_done;
141  std::unordered_map<Function*, InputBuffer> not_ready;
142  std::unordered_map<Function*, int> dependencies;
143 
144  struct ExecInfo {
145  struct Capture {
146  Capture(int input_idx, int output_idx) : input_idx(input_idx), output_idx(output_idx) {}
147  int input_idx; // within Function inputs
148  int output_idx; // within the output vector of a GraphTask
149  };
150 
151  bool should_execute() const {
152  return needed || captures;
153  }
154 
155  bool needed = false;
156  std::unique_ptr<std::vector<Capture>> captures;
157  };
158  // Exec info has somewhat complicated semantics. If it's empty, it means the task is
159  // run in a "default" mode, which means that all next_edges we encounter should
160  // get executed. If it's not empty, only functions that have an entry, and whose
161  // entry has needed == true, should be executed.
162  // exec_info.empty() means it's .backward(), otherwise it's .grad().
163  std::unordered_map<Function*, ExecInfo> exec_info;
164  std::vector<Variable> captured_vars;
165 
166  void init_to_execute(Function& graph_root, const edge_list& outputs);
167 
168  // The value of worker_device in the thread that created this task.
169  // See Note [Reentrant backwards]
170  int owner;
171 
172  bool can_checkpoint() {
173  return exec_info.empty();
174  }
175 
176  GraphTask(bool keep_graph, bool grad_mode)
177  : has_error(false)
178  , outstanding_tasks(0)
179  , keep_graph(keep_graph)
180  , grad_mode(grad_mode)
181  , owner(NO_DEVICE) {}
182 };
183 
184 auto ReadyQueue::push(FunctionTask item) -> void {
185  {
186  std::lock_guard<std::mutex> lock(mutex);
187  ++item.base->outstanding_tasks;
188  heap.push(std::move(item));
189  }
190  not_empty.notify_one();
191 }
192 
193 auto ReadyQueue::pop() -> FunctionTask {
194  std::unique_lock<std::mutex> lock(mutex);
195  not_empty.wait(lock, [this]{ return !heap.empty(); });
196  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
197  auto task = std::move(const_cast<FunctionTask&>(heap.top())); heap.pop();
198  return task;
199 }
200 
201 Engine::Engine() = default;
202 
203 // This Engine's ReadyQueues and their corresponding threads are leaked here
204 Engine::~Engine() = default;
205 
206 auto Engine::thread_init(int device) -> void {
207  THInferNumThreads();
208  // Note [Allocating GPUs to autograd threads]
209  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
210  // What's our strategy here? Originally, the autograd engine was written
211  // with only CUDA in mind. We allocate one thread to handle all CPU
212  // operations, and a thread per CUDA device.
213  //
214  // But what if we have OTHER devices? There are two plausible
215  // strategies:
216  //
217  // - We can allocate threads equal to max(num_cuda_devices, num_xla_devices,
218  // ...) and colocate cuda device 0 with xla device 0
219  // - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices,
220  // ...) keeping everyone separate.
221  //
222  // We don't have any good reason to prefer one or the other, so we've
223  // arbitrarily picked to colocate devices. Maybe the other approach is
224  // better.
225  //
226  // NB: We MUST NOT construct the guard for device -1,
227  // as in some settings we compile with cuda, but
228  // have lazy stubs for CUDA functionality (so actually
229  // attempting to set up a guard(-1) will cause an
230  // error, because it will still query cudaGetDevice).
231  //
232  // NB: These are not OptionalCUDAGuard/etc because engine.cpp
233  // is built as part of the CPU-only library, so we need
234  // dynamic dispatch.
235  //
236  // NB: We need an array here since neither DeviceGuard nor OptionalDeviceGuard
237  // are movable.
238  std::array<c10::OptionalDeviceGuard,
239  static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
240  guards; // Guards! Guards!
241  if (device != -1) {
242  for (size_t i = 0; i < static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); i++) {
243  auto* impl = c10::impl::device_guard_impl_registry[i].load();
244  if (impl && device < impl->deviceCount()) {
245  guards[i].reset_device(at::Device(static_cast<c10::DeviceType>(i), device));
246  }
247  }
248  }
249  worker_device = device;
250  thread_main(nullptr);
251 }
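// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): the two thread-allocation
// strategies discussed in Note [Allocating GPUs to autograd threads], worked
// through on a made-up device census. The engine uses the "colocate" (max)
// strategy; with, say, 2 CUDA devices and 3 XLA devices that means 3 device
// worker threads plus the CPU thread, and cuda:0 work shares a thread with
// xla:0 work. The device counts here are hypothetical.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> device_counts = {2, 3};  // e.g. {CUDA, XLA}; invented numbers

  int colocated = *std::max_element(device_counts.begin(), device_counts.end());
  int separate  = std::accumulate(device_counts.begin(), device_counts.end(), 0);

  assert(colocated + 1 == 4);  // max(2, 3) device threads + 1 CPU thread
  assert(separate + 1 == 6);   // sum(2, 3) device threads + 1 CPU thread
  return 0;
}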
252 
253 // NOTE: graph_tasks do not necessarily form a stack. Imagine this
254 // case:
255 //
256 //    +----> Eval1
257 //  Root
258 //    +----> Eval2
259 //
260 // Once Root is executed, both Eval1 and Eval2 are added to the ready queue.
261 // Next, Eval1 is run and this causes the worker to enter thread_main again.
262 // Then, it pops the next task from the queue, but at this point it is Eval2.
263 // It enters thread_main once again, but now with graph_task of Eval2, which is
264 // completely unrelated to that of Eval1 (it's not a recursive call).
265 // This is all handled correctly right now, but it should be kept in mind
266 // if this code is ever changed.
267 auto Engine::thread_main(GraphTask *graph_task) -> void {
268  auto queue = ready_queues[worker_device + 1];
269  // Why the test on graph_task->outstanding_tasks? See
270  // Note [Reentrant backwards]
271  while (!graph_task || graph_task->outstanding_tasks > 0) {
272  FunctionTask task = queue->pop();
273  if (task.fn && !task.base->has_error.load()) {
274  GradMode::set_enabled(task.base->grad_mode);
275  try {
276  evaluate_function(task);
277  } catch (std::exception& e) {
278  thread_on_exception(task, e);
279  }
280  }
281  // Notify downstream about the completion of tasks depending
282  // on both where the task was executed, and who owned the overall
283 // graph (in case of reentrant execution). See Note [Reentrant backwards].
284  auto base_owner = task.base->owner;
285  // Task from a non-worker thread. Easy case.
286  if (base_owner == NO_DEVICE) {
287  if (--task.base->outstanding_tasks == 0) {
288  std::lock_guard<std::mutex> lock(task.base->mutex);
289  task.base->not_done.notify_all();
290  }
291  } else {
292  // If it's a task initiated from this thread, decrease the counter, but
293  // don't do anything - loop condition will do all checks for us next.
294  if (base_owner == worker_device) {
295  --task.base->outstanding_tasks;
296  // Otherwise send a dummy function task to the owning thread just to
297  // ensure that it's not sleeping. If it has work, it might see that
298  // graph_task->outstanding_tasks == 0 before it gets to the task, but
299  // it's a no-op anyway.
300  } else if (base_owner != worker_device) {
301  if (--task.base->outstanding_tasks == 0) {
302  // Synchronize outstanding_tasks with queue mutex
303  std::atomic_thread_fence(std::memory_order_release);
304  ready_queue_by_index(base_owner).push(FunctionTask(task.base, nullptr, InputBuffer(0)));
305  }
306  }
307  }
308  }
309 }
310 
311 auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void {
312  std::lock_guard<std::mutex> lock(task.base->mutex);
313  if (!task.base->has_error.load()) {
314  if (AnomalyMode::is_enabled()) {
315  task.fn->metadata()->print_stack();
316  }
317  task.base->exception = std::current_exception();
318  task.base->has_error = true;
319  }
320 }
321 
322 static variable_list call_pre_hooks(Function& fn, variable_list inputs) {
323  for (const auto& hook : fn.pre_hooks()) {
324  inputs = (*hook)(inputs);
325  }
326  return inputs;
327 }
328 
329 static variable_list call_post_hooks(Function& fn, variable_list outputs, const variable_list& inputs) {
330  for (const auto& hook : fn.post_hooks()) {
331  outputs = (*hook)(outputs, inputs);
332  }
333  return outputs;
334 }
335 
336 static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
337  // Types are compatible if they exactly match or if the gradient is a sparse
338  // version of the expected type.
339  return expected == actual || (actual.is_sparse() &&
340  expected == actual.toBackend(toDense(actual.backend())));
341 }
342 
343 template<typename F>
344 static void validate_outputs(const edge_list& edges, variable_list& grads, const F& format_error) {
345  if (grads.size() != edges.size()) {
346  std::stringstream ss;
347  ss << "invalid number of gradients - expected ";
348  ss << edges.size() << ", but got " << grads.size();
349  AT_ERROR(format_error(ss.str()));
350  }
351  for (size_t i = 0; i < grads.size(); i++) {
352  const auto& edge = edges[i];
353  if (!edge.is_valid()) continue;
354 
355  const auto& metadata = edge.function->input_metadata(edge.input_nr);
356  const auto& output = grads[i];
357  if (!output.defined()) {
358  // FIXME: TestJit.test_ge_optimized fails this assertion.
359  // std::stringstream ss;
360  // ss << "undefined gradient at index " << i;
361  // AT_ERROR(format_error(ss.str()));
362  continue;
363  }
364  if (!grads[i].sizes().equals(metadata.shape())) {
365  if (!at::is_expandable_to(metadata.shape(), grads[i].sizes())) {
366  std::stringstream ss;
367  ss << "invalid gradient at index " << i << " - got ";
368  ss << grads[i].sizes() << " but expected shape compatible with ";
369  ss << metadata.shape();
370  AT_ERROR(format_error(ss.str()));
371  }
372  grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
373  }
374  if (!is_compatible_type(metadata.type(), grads[i].type())) {
375  std::stringstream ss;
376  ss << "invalid gradient at index " << i << " - expected type ";
377  ss << metadata.type() << " but got " << grads[i].type();
378  AT_ERROR(format_error(ss.str()));
379  }
380  auto output_device = output.device();
381  if (output_device != metadata.device()) {
382  std::stringstream ss;
383  ss << "invalid gradient at index " << i << " - expected device ";
384  ss << metadata.device() << " but got " << output_device;
385  AT_ERROR(format_error(ss.str()));
386  }
387  }
388 }
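// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): the broadcasting rule behind
// the is_expandable_to()/sum_to() fallback above, written out over plain
// std::vector<int64_t> shapes. A gradient whose shape the expected shape can
// broadcast to is accepted and summed back down to the expected shape; any
// other mismatch raises the "invalid gradient" error. expandable_to() below is
// a re-implementation for illustration only, not ATen's.
// ---------------------------------------------------------------------------
#include <cassert>
#include <cstdint>
#include <vector>

// True if `shape` can be expanded (broadcast) to `desired`: align trailing
// dimensions; every dimension of `shape` must either match or be 1.
bool expandable_to(const std::vector<int64_t>& shape, const std::vector<int64_t>& desired) {
  if (shape.size() > desired.size()) return false;
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t s = shape[shape.size() - 1 - i];
    int64_t d = desired[desired.size() - 1 - i];
    if (s != d && s != 1) return false;
  }
  return true;
}

int main() {
  assert(expandable_to({3, 1}, {2, 3, 4}));   // ok: grad of shape {2,3,4} is summed down to {3,1}
  assert(expandable_to({4},    {2, 3, 4}));   // ok
  assert(!expandable_to({3, 2}, {2, 3, 4}));  // mismatch -> "invalid gradient" error
  return 0;
}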
389 
390 static variable_list call_function(FunctionTask& task) {
391  bool prev_checkpoint_valid_state = checkpoint_valid;
392  checkpoint_valid = task.base->can_checkpoint() && prev_checkpoint_valid_state;
393  auto& fn = *task.fn;
394  auto inputs = call_pre_hooks(fn, InputBuffer::variables(std::move(task.inputs)));
395 
396  if(!task.base->keep_graph) {
397  fn.will_release_variables();
398  }
399 
400  const auto has_post_hooks = !fn.post_hooks().empty();
401  variable_list outputs;
402 
403  if(has_post_hooks){
404  // In functions/accumulate_grad.cpp, there is some logic to check the conditions under which
405  // the incoming gradient can be stolen directly (which elides a deep copy) instead of cloned.
406  // One of these conditions is that the incoming gradient's refcount must be 1 (nothing else
407  // is referencing the same data). Stashing inputs_copy here bumps the refcount, so if post hooks
408  // are employed, it's actually still ok for accumulate_grad.cpp to steal the gradient if the
409  // refcount is 2.
410  //
411  // "new_grad.use_count() <= 1 + !post_hooks().empty()" in accumulate_grad.cpp accounts for this,
412  // but also creates a silent dependency between engine.cpp (ie, this particular engine
413  // implementation) and accumulate_grad.cpp.
414  //
415  // If you change the logic here, make sure it's compatible with accumulate_grad.cpp.
416  auto inputs_copy = inputs;
417  outputs = fn(std::move(inputs_copy));
418  }else{
419  outputs = fn(std::move(inputs));
420  }
421 
422  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
423  std::ostringstream ss;
424  ss << "Function " << fn.name() << " returned an " << msg;
425  return ss.str();
426  });
427  checkpoint_valid = prev_checkpoint_valid_state;
428 
429  if(has_post_hooks){
430  // NOLINTNEXTLINE(bugprone-use-after-move)
431  return call_post_hooks(fn, std::move(outputs), inputs);
432  }
433  return outputs;
434 }
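// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): the refcount effect that the
// has_post_hooks branch above accounts for, shown with a plain std::shared_ptr
// standing in for a Variable. Keeping `inputs_copy` alive while the function
// runs bumps the use count by one, which is why the companion check in
// accumulate_grad.cpp tolerates "1 + !post_hooks().empty()" owners when
// deciding whether it may steal the incoming gradient.
// ---------------------------------------------------------------------------
#include <cassert>
#include <memory>

int main() {
  auto grad = std::make_shared<int>(42);  // stands in for an incoming gradient
  assert(grad.use_count() == 1);          // stealable: nobody else references it

  {
    auto inputs_copy = grad;              // what stashing inputs for post hooks does
    assert(grad.use_count() == 2);        // still stealable *if* the extra owner is expected
  }

  assert(grad.use_count() == 1);
  return 0;
}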
435 
436 auto Engine::evaluate_function(FunctionTask& task) -> void {
437  // If exec_info is not empty, we have to instrument the execution
438  auto & exec_info = task.base->exec_info;
439  if (!exec_info.empty()) {
440  auto & fn_info = exec_info.at(task.fn.get());
441  if (auto *capture_vec = fn_info.captures.get()) {
442  std::lock_guard<std::mutex> lock(task.base->mutex);
443  for (auto capture : *capture_vec) {
444  task.base->captured_vars[capture.output_idx] = task.inputs[capture.input_idx];
445  }
446  }
447  if (!fn_info.needed) return;
448  }
449 
450  auto outputs = call_function(task);
451 
452  auto& fn = *task.fn;
453  if (!task.base->keep_graph) {
454  fn.release_variables();
455  }
456 
457  int num_outputs = outputs.size();
458  if (num_outputs == 0) return; // Don't even acquire the mutex
459 
460  if (AnomalyMode::is_enabled()) {
461  AutoGradMode grad_mode(false);
462  for (int i = 0; i < num_outputs; ++i) {
463  auto& output = outputs[i];
464  at::OptionalDeviceGuard guard(device_of(output));
465  if (output.defined() && output.ne(output).any().item<uint8_t>()) {
466  std::stringstream ss;
467  ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
468  throw std::runtime_error(ss.str());
469  }
470  }
471  }
472 
473  std::lock_guard<std::mutex> lock(task.base->mutex);
474  for (int i = 0; i < num_outputs; ++i) {
475  auto& output = outputs[i];
476  const auto& next = fn.next_edge(i);
477 
478  if (!next.is_valid()) continue;
479 
480  // Check if the next function is ready to be computed
481  bool is_ready = false;
482  auto& dependencies = task.base->dependencies;
483  auto it = dependencies.find(next.function.get());
484  if (it == dependencies.end()) {
485  auto name = next.function->name();
486  throw std::runtime_error(std::string("dependency not found for ") + name);
487  } else if (--it->second == 0) {
488  dependencies.erase(it);
489  is_ready = true;
490  }
491 
492  auto& not_ready = task.base->not_ready;
493  auto not_ready_it = not_ready.find(next.function.get());
494  if (not_ready_it == not_ready.end()) {
495  // Skip functions that aren't supposed to be executed
496  if (!exec_info.empty()) {
497  auto it = exec_info.find(next.function.get());
498  if (it == exec_info.end() || !it->second.should_execute()) {
499  continue;
500  }
501  }
502  // No buffers have been allocated for the function
503  InputBuffer input_buffer(next.function->num_inputs());
504  input_buffer.add(next.input_nr, std::move(output));
505  if (is_ready) {
506  auto& queue = ready_queue(input_buffer.device());
507  queue.push(FunctionTask(task.base, next.function, std::move(input_buffer)));
508  } else {
509  not_ready.emplace(next.function.get(), std::move(input_buffer));
510  }
511  } else {
512  // The function already has a buffer
513  auto &input_buffer = not_ready_it->second;
514  input_buffer.add(next.input_nr, std::move(output));
515  if (is_ready) {
516  auto& queue = ready_queue(input_buffer.device());
517  queue.push(FunctionTask(task.base, next.function, std::move(input_buffer)));
518  not_ready.erase(not_ready_it);
519  }
520  }
521  }
522 }
523 
524 /* Computes the number of dependencies for each function which requires grad */
525 auto Engine::compute_dependencies(Function* root, GraphTask& task) -> void {
526  // Just to make sure that they will never be added to the queue again
527  std::unordered_set<Function*> seen;
528  std::vector<Function*> queue { root };
529 
530  // Queue contains all nodes that will start propagating gradients.
531  // We no longer have to expand functions that don't require grad.
532  auto& dependencies = task.dependencies;
533  while (!queue.empty()) {
534  auto fn = queue.back(); queue.pop_back();
535  for (const auto& edge : fn->next_edges()) {
536  if (auto next_ptr = edge.function.get()) {
537  dependencies[next_ptr] += 1;
538  const bool was_inserted = seen.insert(next_ptr).second;
539  if (was_inserted) queue.push_back(next_ptr);
540  }
541  }
542  }
543 }
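// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): dependency counting on a toy
// graph, mirroring compute_dependencies() above and the decrement performed in
// evaluate_function(). Each node's count is the number of edges reaching it
// from the root's side; a node becomes "ready" once its count drops to zero.
// Node and the little diamond graph here are invented for the example.
// ---------------------------------------------------------------------------
#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> next_edges;
};

int main() {
  // root -> a -> c
  // root -> b -> c
  Node c, a{{&c}}, b{{&c}}, root{{&a, &b}};

  // compute_dependencies: count how many edges point at each reachable node.
  std::unordered_map<Node*, int> dependencies;
  std::unordered_set<Node*> seen;
  std::vector<Node*> queue{&root};
  while (!queue.empty()) {
    Node* fn = queue.back(); queue.pop_back();
    for (Node* next : fn->next_edges) {
      dependencies[next] += 1;
      if (seen.insert(next).second) queue.push_back(next);
    }
  }
  assert(dependencies[&a] == 1 && dependencies[&b] == 1 && dependencies[&c] == 2);

  // evaluate_function: after running a node, decrement each successor's count.
  dependencies[&c] -= 1;          // after running `a`
  assert(dependencies[&c] == 1);  // `c` not ready yet; `b` still owes a gradient
  dependencies[&c] -= 1;          // after running `b`
  assert(dependencies[&c] == 0);  // now `c` would be pushed onto a ReadyQueue
  return 0;
}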
544 
545 struct ClearCallbacks {
546  ClearCallbacks(std::vector<std::function<void()>>& callbacks,
547  std::mutex &callbacks_lock)
548  : callbacks(callbacks)
549  , callbacks_lock(callbacks_lock) { clear(); }
550  ~ClearCallbacks() { clear(); }
551 
552  void clear() {
553  std::lock_guard<std::mutex> lock(callbacks_lock);
554  callbacks.clear();
555  }
556 
557  std::vector<std::function<void()>>& callbacks;
558  std::mutex& callbacks_lock;
559 };
560 
561 auto Engine::execute(const edge_list& roots,
562  const variable_list& inputs,
563  bool keep_graph,
564  bool create_graph,
565  const edge_list& outputs) -> variable_list {
566  std::call_once(start_threads_flag, &Engine::start_threads, this);
567 
568  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
569  validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
570  return msg;
571  });
572 
573  // Callbacks are only valid for the duration of this run and should always be cleared
574  ClearCallbacks _cb_guard(final_callbacks, post_callbacks_lock);
575 
576  GraphTask graph_task(keep_graph, create_graph);
577  std::unique_lock<std::mutex> lock(graph_task.mutex);
578 
579  // Now compute the dependencies for all executable functions and queue the root
580  auto graph_root = std::make_shared<GraphRoot>(roots, inputs);
581  compute_dependencies(graph_root.get(), graph_task);
582  if (!outputs.empty()) {
583  graph_task.init_to_execute(*graph_root, outputs);
584  }
585  ready_queue(at::kCPU).push(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));
586 
587  // Not a worker
588  if (worker_device == NO_DEVICE) {
589  // Wait for all tasks to complete
590  graph_task.not_done.wait(lock, [&graph_task]{
591  return graph_task.outstanding_tasks.load() == 0;
592  });
593  } else {
594  // Get back to work while we wait for our new graph_task to
595  // complete!
596  // See Note [Reentrant backwards]
597  graph_task.owner = worker_device;
598  lock.unlock();
599  thread_main(&graph_task);
600  }
601 
602  // Check for an exception while running backwards
603  if (graph_task.has_error.load()) {
604  std::rethrow_exception(graph_task.exception);
605  }
606 
607  if (!graph_task.not_ready.empty()) {
608  throw std::runtime_error("could not compute gradients for some functions");
609  }
610 
611  // Unlocking is necessary, because the callback can register
612  // more callbacks (or they can be registered from other threads
613 // while it's waiting).
614  std::unique_lock<std::mutex> cb_lock(post_callbacks_lock);
615  // WARNING: Don't use a range-for loop here because more callbacks may be
616  // added in between callback calls, so iterators may become invalidated.
617  // NOLINTNEXTLINE(modernize-loop-convert)
618  for (size_t i = 0; i < final_callbacks.size(); ++i) {
619  cb_lock.unlock();
620  final_callbacks[i]();
621  cb_lock.lock();
622  }
623 
624  return graph_task.captured_vars;
625 }
626 
627 // Note that when Python is present, this base engine will be overridden
628 // with a PythonEngine. Because this typically happens before get_default_engine
629 // is called, this base engine will never be created.
630 static Engine& get_base_engine() {
631  static Engine engine;
632  return engine;
633 }
634 
635 std::atomic<EngineStub> engine_stub(get_base_engine);
636 
637 void set_default_engine_stub(EngineStub stub) {
638  engine_stub.store(stub);
639 }
640 
641 
642 Engine& Engine::get_default_engine() {
643  return engine_stub.load()();
644 }
645 
646 void Engine::queue_callback(std::function<void()> callback) {
647  std::lock_guard<std::mutex> lock(post_callbacks_lock);
648  final_callbacks.emplace_back(std::move(callback));
649 }
650 
651 bool Engine::is_checkpoint_valid() {
652  return checkpoint_valid;
653 }
654 
655 auto Engine::ready_queue(at::Device device) -> ReadyQueue& {
656  // See Note [Allocating GPUs to autograd threads]
657  if (device.type() == at::kCPU) {
658  return *ready_queues.at(0);
659  } else {
660  return *ready_queues.at(device.index() + 1);
661  }
662 }
663 
664 // See Note [Allocating GPUs to autograd threads]
665 // NB: This would become obsolete if we truly allocated a CPU thread
666 // per device, rather than colocate.
667 auto Engine::ready_queue_by_index(int device_index) -> ReadyQueue& {
668  return *ready_queues.at(device_index + 1);
669 }
670 
671 auto Engine::start_threads() -> void {
672  // See Note [Allocating GPUs to autograd threads]
673  c10::DeviceIndex num_devices = 0;
674  for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
675  auto* impl = impl_atomic.load();
676  if (impl) {
677  num_devices = std::max(num_devices, impl->deviceCount());
678  }
679  }
680 
681  // One for CPU, plus one for every GPU device (but colocate GPUs of different
682  // types)
683  int num_threads = num_devices + 1;
684  ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads);
685  for (auto& queue : ready_queues)
686  queue.reset(new ReadyQueue());
687  for (int i = 0; i < num_threads; ++i) {
688  std::thread t(&Engine::thread_init, this, i - 1);
689  t.detach();
690  }
691 }
692 
693 void GraphTask::init_to_execute(Function& graph_root, const edge_list& outputs) {
694  exec_info[&graph_root].needed = true;
695 
696  int output_idx = 0;
697  for (auto & output_edge : outputs) {
698  Function *output = output_edge.function.get();
699  auto & info = exec_info[output];
700  if (!info.captures)
701  info.captures = make_unique<std::vector<ExecInfo::Capture>>();
702  info.captures->emplace_back(output_edge.input_nr, output_idx++);
703  }
704  captured_vars.resize(output_idx);
705 
706  // NB: this is an uglier version (recursion replaced with iteration) of the following code:
707  // is_needed = {}
708  // def compute_is_needed(fn):
709  // if fn not in is_needed:
710  // is_needed[fn] = any(compute_is_needed(next_edge)
711  // for next_edge in fn.next_edges)
712  // return is_needed[fn]
713  struct Frame {
714  Frame (Function *fn) : fn(fn), next_next_fn(0) {}
715  Function *fn;
716  size_t next_next_fn;
717 
718  Function* get_next_fn() {
719  const auto & next = fn->next_edges();
720  auto num_next = next.size();
721  while (next_next_fn < num_next) {
722  auto fn = next[next_next_fn++].function.get();
723  if (fn) return fn;
724  }
725  return nullptr;
726  }
727  };
728  std::vector<Frame> stack;
729  std::unordered_set<Function*> seen;
730  for (const auto & input : graph_root.next_edges()) {
731  if (seen.count(input.function.get()) > 0) continue;
732  stack.emplace_back(input.function.get());
733  while (!stack.empty()) {
734  auto &frame = stack.back();
735  if (Function *next_fn = frame.get_next_fn()) {
736  if (/* bool unseen = */ seen.emplace(next_fn).second) {
737  stack.emplace_back(next_fn);
738  continue; // recurse
739  }
740  } else {
741  // NB: if we were using real recursion we could have saved some lookups
742  // using a return value from recursive call. It would make this manually unrolled
743  // version a lot more complicated, so I skipped that.
744  const auto & next_edges = frame.fn->next_edges();
745  const bool needed = std::any_of(
746  next_edges.begin(), next_edges.end(), [&](const Edge& edge) {
747  auto it = exec_info.find(edge.function.get());
748  return it != exec_info.end() && it->second.should_execute();
749  });
750  exec_info[frame.fn].needed = needed;
751  stack.pop_back();
752  }
753  }
754  }
755 }
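// ---------------------------------------------------------------------------
// Illustrative sketch (not part of engine.cpp): the recursive formulation that
// init_to_execute() unrolls into the explicit Frame stack above, written over
// a toy node type. A node is "needed" if any path out of it reaches a node
// already marked for execution (a capture target). ToyNode, compute_is_needed
// and the tiny graph below are invented for the example.
// ---------------------------------------------------------------------------
#include <cassert>
#include <unordered_map>
#include <vector>

struct ToyNode {
  std::vector<ToyNode*> next_edges;
};

bool compute_is_needed(ToyNode* fn,
                       std::unordered_map<ToyNode*, bool>& is_needed) {
  auto it = is_needed.find(fn);
  if (it != is_needed.end()) return it->second;  // memoized (or pre-marked output)
  bool needed = false;
  for (ToyNode* next : fn->next_edges) {
    if (next && compute_is_needed(next, is_needed)) needed = true;
  }
  return is_needed[fn] = needed;
}

int main() {
  // root -> a -> target   (a is needed: it reaches the requested output)
  // root -> b             (b is not)
  ToyNode target, b, a{{&target}}, root{{&a, &b}};

  std::unordered_map<ToyNode*, bool> is_needed;
  is_needed[&target] = true;  // plays the role of an exec_info entry with captures

  assert(compute_is_needed(&root, is_needed));
  assert(is_needed[&a] == true);
  assert(is_needed[&b] == false);
  return 0;
}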
756 
757 }} // namespace torch::autograd