// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// tracing_state.h
#pragma once

#include <ATen/Backtrace.h>
#include <ATen/core/functional.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/utils/variadic.h>

#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
21 
22 namespace torch {
23 namespace jit {
24 namespace tracer {
25 
27 using variable_list = std::vector<Variable>;
28 
29 struct TORCH_API TracingState
30  : public std::enable_shared_from_this<TracingState> {
31  TracingState();
32  ~TracingState();
33 
34  using WeakTensor = at::WeakTensor;
35 
37  size_t operator()(const WeakTensor& t) const {
38  return std::hash<void*>()(t.unsafeGetTensorImpl());
39  }
40  };
41 
42  struct WeakTensorEq {
43  bool operator()(const WeakTensor& t1, const WeakTensor& t2) const {
44  return t1.is_same(t2);
45  }
46  };
47 
49  std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
50  value_map;
51  // TODO weak refcount
52  std::unordered_map<c10::intrusive_ptr<c10::ivalue::Future>, Value*>
53  future_map;
54  };
55 
56  using TracingEnvironmentStack = std::vector<TracingEnvironmentFrame>;
57 
58  TracingEnvironmentStack env_stack;
59  std::shared_ptr<Graph> graph;
60  bool warn = true;
61  bool force_outplace = false;
62  std::function<std::string(const Variable& var)> lookup_var_name_fn =
63  [](const Variable& var) { return ""; };
64 };
65 
66 // This is meant to be used as a thread local place, where we can store extra
67 // info that gets lost when we call into ATen from Python bindings. One example
68 // for when this happens is when we get an IntArrayRef argument with e.g. sizes for
69 // view. When tracing, those might be tensors, which let us encode extra data
70 // dependencies, but once they get to the ATen call where we actually have the
71 // tracing logic, they get converted into a raw IntArrayRef, and we loose all
72 // information. To prevent this, we temporarily stash it in here.
73 struct ArgumentStash {
74  struct IntArrayRefTrace : std::vector<Value*> {
75  IntArrayRefTrace(int size) : std::vector<Value*>(size, nullptr) {}
76  };
77 
78  static bool empty() {
79  return stash.intlists.empty();
80  }
81 
82  TORCH_API static void stashIntArrayRefElem(
83  const std::string& arg_name,
84  size_t size,
85  size_t idx,
86  const Variable& var);
87 
88  static bool hasIntArrayRef(const std::string& arg_name) {
89  return stash.intlists.count(arg_name) > 0;
90  }
91 
92  static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) {
93  auto info = std::move(stash.intlists.at(arg_name));
94  stash.intlists.erase(arg_name);
95  return info;
96  }
97 
98  // Value stashing: Use these methods to stash arguments which correspond
99  // to regular Value*'s in the graph. i.e. they don't require special
100  // handling like in the case of IntArrayRefs
101  TORCH_API static void stashValue(
102  const std::string& arg_name,
103  size_t idx,
104  const Variable& var,
105  const c10::TypePtr& type = nullptr);
106 
107  static bool hasValue(const std::string& arg_name) {
108  return stash.values.count(arg_name) > 0;
109  }
110 
111  static Value* popValue(const std::string& arg_name) {
112  auto info = stash.values.at(arg_name);
113  stash.values.erase(arg_name);
114  return info;
115  }
116 
117  private:
118  static thread_local ArgumentStash stash;
119  std::unordered_map<std::string, IntArrayRefTrace> intlists;
120  std::unordered_map<std::string, Value*> values;
121 };
122 
123 // Retrieve or set the current tracing state. Returns a nullptr if tracing is
124 // disabled.
125 TORCH_API const std::shared_ptr<TracingState>& getTracingState();
126 TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
127 
128 inline bool isTracing() {
129  return static_cast<bool>(getTracingState());
130 }
131 
132 using warn_fn_type = void (*)(const std::string& msg);
133 TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
134 TORCH_API extern const char* WARN_CONSTRUCTOR;
135 TORCH_API extern const char* WARN_RESIZE;
136 TORCH_API extern const char* LEGACY_CONSTRUCTOR;
137 TORCH_API void _do_warn(const char* _reason, const char* _kind);
138 inline void warn(const char* _reason, const char* _kind = nullptr) {
139  if (const auto& state = getTracingState()) {
140  if (!state->warn)
141  return;
142  _do_warn(_reason, _kind);
143  }
144 }
145 TORCH_API void setWarn(warn_fn_type fn);
146 
147 struct TORCH_API NoWarn {
148  NoWarn() : state(getTracingState()) {
149  if (state) {
150  prev = state->warn;
151  state->warn = false;
152  }
153  }
154  ~NoWarn() {
155  if (state) {
156  state->warn = prev;
157  }
158  }
159  std::shared_ptr<TracingState> state;
160  bool prev;
161 };
162 
165  getTracingState()->env_stack.emplace_back();
166  }
167 
169  getTracingState()->env_stack.pop_back();
170  }
171 };
172 
173 } // namespace tracer
174 } // namespace jit
175 } // namespace torch
// (Doxygen extraction residue — preserved as comments:)
// Variable: augments a Tensor with the ability to interact in our
// autograd machinery. Definition: variable.h:85
// Definition: jit_type.h:17