Caffe2 - C++ API
A deep learning, cross-platform ML framework
interpreter.h
1 #pragma once
2 #include <c10/util/Optional.h>
3 #include <memory>
4 #include <vector>
5 
6 #include <torch/csrc/WindowsTorchApiMacro.h>
7 #include <ATen/core/ivalue.h>
8 
// Forward declarations of types owned by other libraries. c10::IValue is
// used as the element type of the Stack alias in torch::jit.
namespace at {
class Tensor;
} // namespace at

namespace c10 {
struct IValue;
} // namespace c10
15 namespace torch {
16 namespace jit {
17 
18 // The interpreter run Graphs with Tensor inputs and Tensor outputs
19 // a separate component in the autograd handles unwrapping and wrapping
20 // variable objects for use in the interpreter.
21 
22 struct Node;
23 struct GraphExecutor;
24 struct CodeImpl;
25 struct InterpreterStateImpl;
26 struct Graph;
27 struct Node;
28 using Stack = std::vector<c10::IValue>;
29 using c10::ivalue::Future;
30 using c10::ivalue::Tuple;
31 
32 struct TORCH_API Code {
33  Code() : pImpl(nullptr) {}
34  explicit Code(const std::shared_ptr<Graph>& graph);
35  ~Code();
36 
37  const std::vector<GraphExecutor*>& grad_executors();
38 
39  explicit operator bool() const {
40  return pImpl != nullptr;
41  }
42 
43  private:
44  std::shared_ptr<CodeImpl> pImpl;
45  friend struct InterpreterStateImpl;
46  friend std::ostream& operator<<(std::ostream& out, const Code& code);
47 };
48 
50  InterpreterState(const Code& code);
51  void run(Stack& stack);
52  c10::intrusive_ptr<Future> runAsync(Stack& stack);
53  c10::intrusive_ptr<Future> getFuture();
55 
56  private:
58  // Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
59  // but intrusive_ptr requires full definition of InterpreterStateImpl,
60  // which we need to hide in the header.
62  friend struct InterpreterStateImpl;
63 };
64 
65 // Created by wait()
66 struct Suspend : public std::exception {
67  const char* what() const noexcept override {
68  return "Suspend";
69  }
70 
71  explicit Suspend(c10::intrusive_ptr<Future> future_)
72  : future(std::move(future_)) {}
73 
75 };
76 
78  InterpreterContinuation(InterpreterState state_, Stack stack_, bool grad_mode_enabled_)
79  : state(state_), stack(std::move(stack_)), grad_mode_enabled(grad_mode_enabled_) {}
80 
81  void operator()();
82 
83  private:
84  InterpreterState state;
85  Stack stack;
86  bool grad_mode_enabled;
87 };
88 } // namespace jit
89 } // namespace torch
Definition: jit_type.h:17
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
Flush-To-Zero and Denormals-Are-Zero mode.