3 #include <ATen/core/functional.h> 4 #include <ATen/core/jit_type.h> 5 #include <ATen/core/stack.h> 6 #include <c10/util/Exception.h> 7 #include <torch/csrc/WindowsTorchApiMacro.h> 8 #include <torch/csrc/autograd/function_hook.h> 9 #include <torch/csrc/autograd/variable.h> 10 #include <torch/csrc/jit/constants.h> 11 #include <torch/csrc/utils/variadic.h> 13 #include <ATen/Backtrace.h> 19 #include <unordered_map> 27 using variable_list = std::vector<Variable>;
30 :
public std::enable_shared_from_this<TracingState> {
38 return std::hash<void*>()(t.unsafeGetTensorImpl());
44 return t1.is_same(t2);
49 std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
52 std::unordered_map<c10::intrusive_ptr<c10::ivalue::Future>,
Value*>
56 using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
58 TracingEnvironmentStack env_stack;
59 std::shared_ptr<Graph> graph;
61 bool force_outplace =
false;
62 std::function<std::string(const Variable& var)> lookup_var_name_fn =
63 [](
const Variable& var) {
return ""; };
79 return stash.intlists.empty();
82 TORCH_API
static void stashIntArrayRefElem(
83 const std::string& arg_name,
88 static bool hasIntArrayRef(
const std::string& arg_name) {
89 return stash.intlists.count(arg_name) > 0;
93 auto info = std::move(stash.intlists.at(arg_name));
94 stash.intlists.erase(arg_name);
101 TORCH_API
static void stashValue(
102 const std::string& arg_name,
105 const c10::TypePtr& type =
nullptr);
107 static bool hasValue(
const std::string& arg_name) {
108 return stash.values.count(arg_name) > 0;
111 static Value* popValue(
const std::string& arg_name) {
112 auto info = stash.values.at(arg_name);
113 stash.values.erase(arg_name);
119 std::unordered_map<std::string, IntArrayRefTrace> intlists;
120 std::unordered_map<std::string, Value*> values;
// NOTE(review): the leading numbers on some lines in this region (e.g. "125")
// look like line-number residue from extraction, not C++ — verify against the
// upstream header before building.
//
// Returns the currently installed tracing state. A null pointer means no trace
// is being recorded: isTracing() is implemented as a plain null check on this
// value. Returned by const reference — presumably aliasing the stored pointer,
// so copy it (as NoWarn's constructor does) if it must survive a later
// setTracingState() call.
125 TORCH_API
const std::shared_ptr<TracingState>& getTracingState();
// Installs `state` as the current tracing state. The shared_ptr is taken by
// value, so callers can std::move an owner in. Passing nullptr presumably
// turns tracing off — TODO confirm against the definition.
126 TORCH_API
void setTracingState(std::shared_ptr<TracingState> state);
128 inline bool isTracing() {
129 return static_cast<bool>(getTracingState());
// Signature of a tracer-warning sink: a plain function pointer that receives
// the warning message text. A function of this type is installed via setWarn().
132 using warn_fn_type = void (*)(
const std::string& msg);
// Reason strings for tracer warnings. Presumably passed as the `_reason`
// argument of warn() — confirm at the call sites.
133 TORCH_API
extern const char* WARN_PYTHON_DATAFLOW;
134 TORCH_API
extern const char* WARN_CONSTRUCTOR;
135 TORCH_API
extern const char* WARN_RESIZE;
136 TORCH_API
extern const char* LEGACY_CONSTRUCTOR;
// Out-of-line slow path of warn(): warn() calls this only inside its
// `if (getTracingState())` branch, forwarding its own arguments unchanged.
// `_kind` may therefore be nullptr (warn()'s default). Presumably formats the
// message and hands it to the sink installed by setWarn() — TODO confirm
// against the definition.
137 TORCH_API
void _do_warn(
const char* _reason,
const char* _kind);
138 inline void warn(
const char* _reason,
const char* _kind =
nullptr) {
139 if (
const auto& state = getTracingState()) {
142 _do_warn(_reason, _kind);
// Installs `fn` (see warn_fn_type) as the callback used to emit tracer
// warnings. Presumably this is the sink _do_warn() forwards messages to —
// TODO confirm against the definition.
145 TORCH_API
void setWarn(warn_fn_type fn);
148 NoWarn() : state(getTracingState()) {
159 std::shared_ptr<TracingState> state;
165 getTracingState()->env_stack.emplace_back();
169 getTracingState()->env_stack.pop_back();
// Variable: a Variable augments a Tensor with the ability to participate in the autograd machinery...