Caffe2 - C++ API
A deep-learning, cross-platform ML framework
tracer.h
1 #pragma once
2 
#include <ATen/Backtrace.h>
#include <ATen/core/functional.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/tracing_state.h>
#include <torch/csrc/utils/variadic.h>

#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
21 
22 namespace torch {
23 namespace jit {
24 namespace tracer {
25 
26 using ::c10::ivalue::List;
27 using ::c10::ivalue::Shared;
28 
29 using ::c10::IValue;
30 using ::c10::ivalue::Future;
31 using ::c10::ivalue::Tuple;
32 
33 using ::c10::ivalue::BoolList;
34 using ::c10::ivalue::DoubleList;
35 using ::c10::ivalue::GenericList;
36 using ::c10::ivalue::IntList;
37 using ::c10::ivalue::TensorList;
38 
39 using ::c10::ivalue::ConstantString;
40 
42 using variable_list = std::vector<Variable>;
43 
44 TORCH_API void recordSourceLocation(Node* n);
45 TORCH_API void setRecordSourceLocation(void (*v)(Node*));
46 
47 // Having finished adding a new 'node' to the graph IR 'setValueTrace'
48 // associates this node with an output variable, so that further operations
49 // involving this variable know which node in the IR to reference.
50 TORCH_API void setValueTrace(const IValue& v, Value* value);
51 
52 TORCH_API void delValueTrace(const Variable& var);
53 
54 TORCH_API std::function<void()> pauseTracing();
55 
56 TORCH_API Value* getValueTrace(const IValue& var);
57 
58 TORCH_API Value* getNestedValueTrace(const IValue& v);
59 
60 TORCH_API Value* getOutputTrace(
61  const std::shared_ptr<TracingState>& state,
62  const Variable& var);
63 
64 TORCH_API Value* getNestedOutputTrace(
65  const std::shared_ptr<TracingState>& state,
66  const IValue& iv);
67 
68 TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
69 
70 TORCH_API void exit(const Stack& outputs);
71 
72 TORCH_API void abandon();
73 
74 // NB: those serve both as an intermediate steps in addInputs below,
75 // as well as the overloads that terminate template recursion
76 TORCH_API void addInputs(Node* n, const char* name, int64_t value);
77 TORCH_API void addInputs(
78  Node* n,
79  const char* name,
81 TORCH_API void addInputs(Node* n, const char* name, bool value);
82 TORCH_API void addInputs(Node* n, const char* name, double value);
83 TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
84 TORCH_API void addInputs(
85  Node* n,
86  const char* name,
87  const c10::optional<at::Scalar>& value);
88 TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
89 TORCH_API void addInputs(Node* n, const char* name, at::IntArrayRef value);
90 TORCH_API void addInputs(
91  Node* n,
92  const char* name,
93  at::TensorList value,
94  bool allow_undefined = false);
95 TORCH_API void addInputs(
96  Node* n,
97  const char* name,
98  const ArrayRef<double>& value);
99 TORCH_API void addInputs(
100  Node* n,
101  const char* name,
102  const std::vector<double>& value);
103 TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
104 TORCH_API void addInputs(
105  Node* n,
106  const char* name,
107  const at::SparseTensorRef& value);
108 TORCH_API void addInputs(
109  Node* n,
110  const char* name,
111  const at::TensorOptions& value);
112 TORCH_API void addInputs(Node* n, const char* name, at::Device value);
113 TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
114 TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
115 TORCH_API void addInputs(
116  Node* n,
117  const char* name,
118  const c10::optional<at::ScalarType>& value);
119 TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
120 
121 template<typename T>
122 TORCH_API void addInputs(
123  Node* n,
124  const char* name,
125  const std::vector<T>& value);
126 
127 template<typename K, typename V>
128 TORCH_API void addInputs(
129  Node* n,
130  const char* name,
131  const std::unordered_map<K, V>& value);
132 
133 template<typename T>
134 void addInputs(
135  Node* n,
136  const char* name,
137  const std::vector<T>& value) {
138  AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
139 }
140 template<typename K, typename V>
141 void addInputs(
142  Node* n,
143  const char* name,
144  const std::unordered_map<K, V>& value) {
145  AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
146 }
147 
148 template <size_t N>
149 void addInputs(Node* n, const char* name, std::array<bool, N> value) {
150  throw std::runtime_error(
151  "Found an unsupported argument type in the JIT tracer. File a bug report.");
152 }
153 
154 TORCH_API void ensureUniqueIfOutOfPlaced(
155  const char* name,
156  const at::Tensor& tensor);
157 
158 template <
159  typename T,
160  typename = torch::enable_if_t<
161  (!std::is_convertible<torch::decay_t<T>, at::TensorList>::value &&
162  !std::is_convertible<torch::decay_t<T>, at::Tensor>::value)>>
163 void addOutput(Node* node, T&&) {
164  AT_ERROR(
165  "Found an unsupported argument type ",
166  c10::demangle_type<T>(),
167  " in the JIT tracer. File a bug report.");
168 }
169 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
170 TORCH_API void setOutput(Value* value, const at::Tensor& output);
171 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
172 
173 TORCH_API autograd::Variable getSizeOf(
174  const autograd::Variable& var,
175  int64_t dim);
176 
177 } // namespace tracer
178 } // namespace jit
179 } // namespace torch
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17