1 #include <torch/csrc/jit/tracer.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/autograd/engine.h> 5 #include <torch/csrc/autograd/function.h> 6 #include <torch/csrc/autograd/variable.h> 7 #include <torch/csrc/jit/passes/dead_code_elimination.h> 8 #include <torch/csrc/jit/passes/remove_expands.h> 24 void genericAddInput(Node* n,
T value) {
// Inserts `value` as a constant into the graph that owns `n` and records the
// current source location on the constant's node.
// NOTE(review): the rest of this function is elided from this listing
// (presumably it attaches `v` as an input of `n`) — confirm in full source.
25 Value* v = n->owningGraph()->insertConstant(value);
26 recordSourceLocation(v->node());
// Builds an error for an argument whose C++ type T the tracer cannot record.
// The message includes the demangled name of T. The visible code never reads
// the argument `v`; it only serves to deduce T.
// NOTE(review): the error-raising call that receives these message pieces is
// on an elided line — presumably it throws.
31 void badArgType(
const T& v) {
33 "Found an unsupported argument type in the JIT tracer: ",
34 c10::demangle_type<T>(),
35 ". File a bug report.");
// Per-thread tracing state (thread_local: each thread has its own pointer).
// It is set to nullptr when tracing is paused or exited.
38 thread_local std::shared_ptr<TracingState> tracing_state;
// Temporarily suspends tracing on the current thread.
// Saves the current tracing state, clears it, and returns a callback that
// restores the saved state when invoked.
42 TORCH_API std::function<void()> pauseTracing() {
44 std::shared_ptr<tracer::TracingState> state = getTracingState();
45 tracer::setTracingState(
nullptr);
// `state` is captured by value, so the paused state stays alive for as long
// as the resume callback exists.
47 return [state]() { tracer::setTracingState(state); };
// Drops the tracer's record of `var` (which must be defined).
// Frames of env_stack are visited from the back (innermost) toward the front.
// NOTE(review): the body of the search loop (original lines 55-61) is mostly
// elided here; the visible effect is erasing `var` from the innermost frame's
// value_map.
50 void delValueTrace(
const Variable& var) {
51 AT_ASSERT(var.defined());
52 auto& env_stack = getTracingState()->env_stack;
// i == 0 addresses env_stack.back(), i.e. the innermost frame.
53 for (
size_t i = 0; i < env_stack.size(); ++i) {
54 auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
56 auto it = value_map.find(var);
57 if (it == value_map.end()) {
62 getTracingState()->env_stack.back().value_map.erase(var);
// Returns the graph Value that stands for `var` in the current trace.
//
// Tensors:
//  - An undefined tensor is presumably traced as a fresh None node (its guard
//    condition is elided from this listing).
//  - Otherwise env_stack frames are searched innermost-first. If a match has
//    no unique name yet, it is named via lookup_var_name_fn when that returns
//    a non-empty string.
//  - The code after the search inserts the tensor as a graph constant and
//    records it in the innermost frame.
//
// Futures: each frame's future_map is searched. std::runtime_error is thrown
// if the tracer never saw the future.
//
// Any other IValue kind: std::runtime_error.
79 Value* getValueTrace(
const IValue& var) {
80 auto& state = getTracingState();
81 auto& env_stack = getTracingState()->env_stack;
84 auto ten = var.toTensor();
// Undefined tensor: represent it as None in the graph.
86 Node* n = state->graph->createNone(TensorType::get());
87 return state->graph->insertNode(n)->output();
// Search frames from innermost (env_stack.back()) outward.
89 for (
size_t i = 0; i < env_stack.size(); ++i) {
90 auto& value_map = env_stack.at(env_stack.size() - 1 - i).value_map;
91 auto it = value_map.find(ten);
92 if (it == value_map.end()) {
// Found (the miss handling is elided): lazily give the Value a readable name.
95 if (!it->second->hasUniqueName()) {
96 auto unique_name = getTracingState()->lookup_var_name_fn(ten);
97 if (!unique_name.empty()) {
98 it->second->setUniqueName(unique_name);
// Not found in any frame: bake the tensor in as a constant and remember it in
// the innermost frame, so later lookups reuse the same constant Value.
105 Value* constant = state->graph->insertConstant(ten);
106 recordSourceLocation(constant->node());
107 constant->inferTypeFrom(ten);
108 auto it = env_stack.back().value_map.find(ten);
109 it = env_stack.back().value_map.emplace_hint(it, ten, constant);
111 }
// Futures are tracked in per-frame future_map tables.
else if (var.isFuture()) {
112 auto fut = var.toFuture();
113 for (
size_t i = 0; i < env_stack.size(); ++i) {
114 auto& future_map = env_stack.at(env_stack.size() - 1 - i).future_map;
115 auto it = future_map.find(fut);
116 if (it == future_map.end()) {
// No frame knows this future.
122 std::ostringstream oss;
123 oss <<
"Tried to trace Future that the tracer was not aware of.";
124 throw std::runtime_error(oss.str());
// Neither a tensor nor a future.
126 std::ostringstream oss;
127 oss <<
"Unknown type used in value trace lookup!";
128 throw std::runtime_error(oss.str());
// Returns a graph Value for a possibly nested IValue:
//  - tensor lists become a List node of per-element traces;
//  - tuples become a Tuple node of per-element traces (recursively);
//  - anything else is looked up as a single tensor via getValueTrace.
// NOTE(review): parts of the createList call (original lines 140-142) are
// elided from this listing.
135 Value* getNestedValueTrace(
const IValue& v) {
136 auto& state = getTracingState();
137 if (v.isTensorList()) {
139 ->insertNode(state->graph->createList(
143 [](
const IValue& val) {
return getNestedValueTrace(val); })))
145 }
// Tuples: recurse into each element.
else if (v.isTuple()) {
147 ->insertNode(state->graph->createTuple(fmap(
148 v.toTuple()->elements(),
149 [](
const IValue& val) {
return getNestedValueTrace(val); })))
// Fallback: treat `v` as a single tensor.
152 return getValueTrace(v.toTensor());
// Returns the graph Value recorded for an output variable of a traced region.
// An undefined `var` is traced as a None node.
// Only the innermost env_stack frame is searched, unlike getValueTrace. If
// `var` is not found there, std::runtime_error is thrown, because the output
// has no observable data dependence on the trace inputs.
// NOTE(review): the lookup below reads getTracingState() rather than the
// `state` parameter — confirm they are always the same object.
155 Value* getOutputTrace(
156 const std::shared_ptr<TracingState>& state,
157 const Variable& var) {
158 if (!var.defined()) {
159 Node* n = state->graph->createNone(TensorType::get());
160 return state->graph->insertNode(n)->output();
163 auto& value_map = getTracingState()->env_stack.back().value_map;
164 auto it = value_map.find(var);
165 if (it == value_map.end()) {
166 std::ostringstream os;
167 os <<
"output of traced region did not have observable " 168 <<
"data dependence with trace inputs; this probably indicates your program " 169 <<
"cannot be understood by the tracer.";
170 throw std::runtime_error(os.str());
// Returns the graph Value for a traced-region output `iv`, which may be a
// tensor or a (nested) tuple of tensors.
//  - Tensors are resolved via getOutputTrace.
//  - Tuples become a Tuple node whose elements are resolved recursively.
//  - Any other kind reaches the error at the end of this block.
// NOTE(review): the `iv` parameter and the tensor check (original lines
// 177-178) are elided from this listing.
175 Value* getNestedOutputTrace(
176 const std::shared_ptr<TracingState>& state,
179 return getOutputTrace(state, iv.toTensor());
180 }
// Tuples: build a Tuple node from the element traces.
else if (iv.isTuple()) {
181 const auto& elems = iv.toTuple()->elements();
183 state->graph->createTuple(fmap(elems, [&state](
const IValue& ival) {
184 return getNestedOutputTrace(state, ival);
186 state->graph->insertNode(tuple_node);
187 return tuple_node->output();
// Unsupported output kind.
190 "Only tensors or tuples of tensors can be output from traced functions");
// Starts a new trace on this thread. It creates a fresh TracingState and
// installs it, then registers every element of `inputs` as a graph input.
// Returns the new state together with the (possibly rewritten) inputs.
197 std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
// Guard condition elided; presumably fires when a trace is already active.
199 AT_ERROR(
"Tracing can't be nested");
201 auto state = std::make_shared<TracingState>();
202 setTracingState(state);
// Declared as a std::function so the [&]-capturing lambda can call itself
// recursively for nested tuple inputs.
204 const std::function<IValue(IValue, TypePtr, Value*)> add_input =
205 [&](IValue input, TypePtr type, Value* value) -> IValue {
206 value->setType(type);
207 if (type->isSubtypeOf(TensorType::get())) {
208 auto input_tensor = input.toTensor();
209 auto name = Variable(input_tensor).name();
210 auto& value_map = state->env_stack.back().value_map;
// If this tensor is already mapped (e.g. the same tensor passed as several
// inputs), use a fresh view so each graph input gets its own value_map key.
211 if (value_map.find(input_tensor) != value_map.end()) {
212 input_tensor = input_tensor.view(input_tensor.sizes());
214 value->setUniqueName(name);
215 value_map[input_tensor] = value;
217 }
// Tuple inputs: unpack in the graph and register each element recursively.
// NOTE(review): the declaration of `unpack_node` and the size assertion's
// macro name (original lines 218 and 224) are elided.
else if (
auto tuple_type = type->cast<TupleType>()) {
219 state->graph->insertNode(state->graph->createTupleUnpack(value));
220 auto elem_values = unpack_node->outputs();
221 auto elem_types = tuple_type->elements();
222 Stack elems = input.toTuple()->elements();
223 size_t num_elems = elems.size();
225 elem_values.size() == num_elems && elem_types.size() == num_elems);
226 for (
size_t i = 0; i < num_elems; ++i) {
227 elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
229 return Tuple::create(std::move(elems));
// Any other input type is rejected.
232 "Only tensors or tuples of tensors can be inputs to traced functions");
// Register each input as a new graph input. The assignment target (original
// line 236) is elided; presumably it stores add_input's result back into
// `input`.
235 for (IValue& input : inputs) {
237 input, incompleteInferTypeFrom(input), state->graph->addInput());
239 return std::make_pair(state, inputs);
// Finishes the current trace. Each output's (nested) trace is registered as a
// graph output, and then the thread's tracing state is cleared.
245 void exit(
const Stack& outputs) {
246 auto& state = getTracingState();
248 for (
auto& output : outputs) {
249 state->graph->registerOutput(getNestedOutputTrace(state, output));
252 setTracingState(
nullptr);
// NOTE(review): the next statement belongs to another function whose header
// (around original line 256) is elided here. It also clears the tracing
// state — presumably to abandon an in-progress trace.
257 setTracingState(
nullptr);
// Records `value` as the trace of `v` in the innermost env_stack frame.
//  - Tensor: must be defined; mapped directly in value_map.
//  - Tensor list / generic list: a ListUnpack node is inserted and each
//    element is mapped recursively to its unpacked output.
//  - Tuple: a TupleUnpack node is inserted and each element is mapped
//    recursively.
//  - Future: mapped in future_map.
//  - Any other kind: std::runtime_error.
// NOTE(review): the error message below lists only tensor, tensor list and
// tuple, although generic lists and futures are also accepted.
// NOTE(review): the initial `isTensor()` check and the `unpack_node`
// declarations (original lines 261, 268, 283) are elided.
260 void setValueTrace(
const IValue& v, Value* value) {
262 auto var = v.toTensor();
263 AT_ASSERT(var.defined());
264 getTracingState()->env_stack.back().value_map[var] = value;
265 }
else if (v.isTensorList()) {
266 auto& outputs = v.toTensorList()->elements();
267 auto graph = getTracingState()->graph;
269 graph->insertNode(graph->createListUnpack(value, outputs.size()));
270 for (
size_t i = 0; i < outputs.size(); ++i) {
271 setValueTrace(outputs[i], unpack_node->outputs()[i]);
273 }
else if (v.isTuple()) {
274 auto& outputs = v.toTuple()->elements();
275 auto graph = getTracingState()->graph;
276 Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
277 for (
size_t i = 0; i < outputs.size(); ++i) {
278 setValueTrace(outputs[i], unpack_node->outputs()[i]);
280 }
else if (v.isGenericList()) {
// NOTE(review): `auto` (not `auto&`) may copy the list if toGenericListRef
// returns a reference — confirm whether a copy is intended.
281 auto elements = v.toGenericListRef();
282 auto graph = getTracingState()->graph;
284 graph->insertNode(graph->createListUnpack(value, elements.size()));
285 for (
size_t i = 0; i < elements.size(); ++i) {
286 setValueTrace(elements[i], unpack_node->outputs()[i]);
288 }
else if (v.isFuture()) {
289 auto fut = v.toFuture();
290 getTracingState()->env_stack.back().future_map[fut] = value;
// Unsupported IValue kind.
292 std::ostringstream os;
293 os <<
"Tracer cannot set value trace for type " << v.tagKind() <<
". " 294 <<
"Supported types are tensor, tensor list, and tuple of tensors.";
295 throw std::runtime_error(os.str());
// Adds an int64_t argument to node `n`. If a traced Value was stashed for
// `name` in the ArgumentStash, it is popped and used (its use is on an elided
// line). Otherwise the literal is inserted as a constant.
299 void addInputs(Node* n,
const char* name, int64_t value) {
300 using ArgumentStash = jit::tracer::ArgumentStash;
301 if (ArgumentStash::hasValue(name)) {
302 Value* v = ArgumentStash::popValue(name);
305 detail::genericAddInput(n, value);
// NOTE(review): the following lines belong to another overload whose header
// is elided (presumably the optional<int64_t> one). A present value becomes a
// constant; an absent one becomes an Int-typed None node.
311 detail::genericAddInput(n, *value);
313 Graph* g = n->owningGraph();
314 Value* none = g->insertNode(g->createNone(IntType::get()))->output();
// Adds a bool argument to `n` as a graph constant.
318 void addInputs(Node* n,
const char* name,
bool value) {
319 detail::genericAddInput(n, value);
// Adds a double argument to `n` as a graph constant.
321 void addInputs(Node* n,
const char* name,
double value) {
322 detail::genericAddInput(n, value);
// Adds an at::Scalar argument to `n` as a graph constant.
324 void addInputs(Node* n,
const char* name,
const at::Scalar& value) {
325 detail::genericAddInput(n, value);
// NOTE(review): the following lines belong to an overload whose header is
// elided (presumably optional<Scalar>). A present value becomes a constant;
// an absent one becomes a Number-typed None node.
332 detail::genericAddInput(n, *value);
334 Graph* g = n->owningGraph();
335 Value* none = g->insertNode(g->createNone(NumberType::get()))->output();
// Adds a string argument to `n` as a graph constant.
339 void addInputs(Node* n,
const char* name,
const std::string& value) {
340 detail::genericAddInput(n, value);
// Adds a tensor argument to `n` using its current trace (see getValueTrace).
342 void addInputs(Node* n,
const char* name,
const at::Tensor& value) {
343 n->addInput(getValueTrace(value));
// NOTE(review): the next line belongs to an overload whose header is elided;
// it rejects its argument type via badArgType.
346 detail::badArgType(value);
// Adds an at::Generator* argument to `n`. The visible code rejects a generator
// via badArgType (its guard, presumably `if (value)`, is elided). Otherwise it
// adds a Generator-typed None node as the input.
// NOTE(review): the `undef_gen` declaration (original line 353) is elided.
348 void addInputs(Node* n,
const char* name,
at::Generator* value) {
350 detail::badArgType(value);
352 Graph* g = n->owningGraph();
354 g->insertNode(g->createNone(GeneratorType::get()))->output();
355 n->addInput(undef_gen);
// Adds an at::Device argument to `n` as a graph constant.
357 void addInputs(Node* n,
const char* name,
at::Device value) {
358 detail::genericAddInput(n, value);
// Adds an at::Layout argument to `n`, encoded as an int64_t constant.
360 void addInputs(Node* n,
const char* name, at::Layout value) {
361 detail::genericAddInput(n, static_cast<int64_t>(value));
// Adds an at::ScalarType argument to `n`, encoded as an int64_t constant.
363 void addInputs(Node* n,
const char* name, at::ScalarType value) {
364 detail::genericAddInput(n, static_cast<int64_t>(value));
// NOTE(review): the following lines belong to an overload whose header is
// elided (presumably an optional ScalarType). A present value becomes an
// int64_t constant; an absent one becomes an Int-typed None node.
371 detail::genericAddInput(n, static_cast<int64_t>(*value));
373 Graph* g = n->owningGraph();
374 Value* none = g->insertNode(g->createNone(IntType::get()))->output();
// Tail of a tensor-list addInputs overload (its leading parameters are elided
// from this listing). It builds a List node from each element's trace and adds
// the list as an input of `n`. With `allow_undefined` the list element type is
// Optional[Tensor]; otherwise it is Tensor.
383 bool allow_undefined) {
384 Graph* g = n->owningGraph();
385 Node* list_node =
nullptr;
386 if (allow_undefined) {
388 list_node = g->insertNode(
389 g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace)));
// (else branch; its `else` line is elided)
391 list_node = g->insertNode(
392 g->createList(TensorType::get(), fmap(value, getValueTrace)));
394 n->addInput(list_node->output());
// Body of a TensorOptions addInputs overload (header elided). It expands the
// options into three inputs in this order: dtype (as ScalarType), layout,
// device.
400 addInputs(n, name, at::typeMetaToScalarType(options.
dtype()));
401 addInputs(n, name, options.
layout());
402 addInputs(n, name, options.
device());
// Body of an IntArrayRef addInputs overload (header elided).
// Per-element traces stashed for `name` are used if present; otherwise an
// empty trace of value.size() entries is used. Elements without a trace are
// filled with integer constants. Every element must be Int-typed, and the
// result is added to `n` as an int List node.
406 using ArgumentStash = jit::tracer::ArgumentStash;
407 std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)
408 ? ArgumentStash::popIntArrayRef(name)
409 : ArgumentStash::IntArrayRefTrace(value.size());
411 auto& g = getTracingState()->graph;
412 for (
size_t i = 0; i < info.size(); ++i) {
// NOTE(review): the statement under this `if` (original line 414) is elided;
// presumably it is `continue;`, so only missing entries become constants.
413 if (info[i] !=
nullptr)
415 info[i] = g->insertConstant(value[i]);
416 recordSourceLocation(info[i]->node());
// Stashed traces must all be Int-typed, or the traced list would be ill-typed.
418 for (jit::Value* v : info) {
419 if (*v->type() != *jit::IntType::get()) {
420 throw std::runtime_error(
421 "Type mismatch in setposattr for IntArrayRef. Check that your program " 422 "is valid without tracing, and please file a bug report if it is.");
426 g->insertNode(g->createList(jit::IntType::get(), info))->output());
// Float-list arguments cannot be traced; both the ArrayRef<double> overload
// and the std::vector<double> overload (its leading parameters are elided)
// raise an error.
429 void addInputs(Node* n,
const char* name,
const ArrayRef<double>& value) {
430 AT_ERROR(
"Tracing float lists currently not supported!");
435 const std::vector<double>& value) {
436 AT_ERROR(
"Tracing float lists currently not supported!");
// Appends a new output to `node` and records it as the trace of `output`.
439 void addOutput(Node* node,
const at::Tensor& output) {
440 setOutput(node->addOutput(), output);
// Binds graph `value` to tensor `output`. For a defined tensor, the Value's
// type is inferred from the tensor and the tensor's trace is set to `value`.
// NOTE(review): handling of an undefined `output` (if any) is on elided lines.
443 void setOutput(Value* value,
const at::Tensor& output) {
444 if (output.defined()) {
445 value->inferTypeFrom(output);
446 setValueTrace(autograd::as_variable_ref(output), value);
// Appends a Tensor[]-typed output to `node` and unpacks it with a ListUnpack
// node. Each unpacked element is then typed from, and recorded as the trace
// of, the corresponding tensor.
// NOTE(review): unlike setOutput, elements are not checked with defined();
// also this builds prim::ListUnpack directly, whereas setValueTrace uses
// createListUnpack — confirm both are intended.
450 void addOutput(Node* node,
const std::vector<at::Tensor>& outputs) {
451 Value* value = node->addOutput()->setType(ListType::ofTensors());
452 Graph* graph = node->owningGraph();
453 Node* unpack_node = graph->insertNode(
454 graph->create(prim::ListUnpack, {value}, outputs.size()));
455 for (
size_t i = 0; i < outputs.size(); ++i) {
456 Value* output_val = unpack_node->outputs()[i];
457 output_val->inferTypeFrom(outputs[i]);
458 setValueTrace(outputs[i], output_val);
// Returns a reference to the calling thread's tracing state (nullptr when not
// tracing).
462 const std::shared_ptr<TracingState>& getTracingState() {
463 return detail::tracing_state;
// Replaces the calling thread's tracing state; pass nullptr to stop tracing.
466 void setTracingState(std::shared_ptr<TracingState> state) {
467 detail::tracing_state = std::move(state);
// A new TracingState starts with a single (outermost) environment frame and a
// fresh, empty Graph.
470 TracingState::TracingState()
471 : env_stack{TracingEnvironmentFrame()}, graph(
new Graph()) {}
473 TracingState::~TracingState() =
default;
// Returns var.size(dim) as a scalar tensor variable whose trace is a dynamic
// aten::size(var, dim) call rather than a baked-in constant.
// The Int result of aten::size is converted to a tensor via NumToTensor.
// NOTE(review): the `size_var` and `ten` declarations and the return
// statement are on elided lines (479, 488, 491+).
475 autograd::Variable getSizeOf(
const autograd::Variable& var, int64_t dim) {
476 auto& tracing_state = getTracingState();
477 auto& graph = tracing_state->graph;
480 autograd::make_variable(scalar_to_tensor(
at::Scalar(var.size(dim))));
481 auto* value = getValueTrace(var);
482 auto dim_val = graph->insertConstant(dim);
483 recordSourceLocation(dim_val->node());
484 auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val}));
485 recordSourceLocation(node);
486 node->output()->setType(jit::IntType::get());
489 graph->insertNode(graph->createNumToTensor(node->output()))->output();
490 setValueTrace(size_var, ten);
// Warns when an in-place operator (named by `name`) is traced while its
// tensor's storage has more than one live reference. Other views of that
// storage would not see the change in the trace.
// NOTE(review): what happens when state->force_outplace is false (original
// lines 497-500) and where `name` is inserted into the message (line 506) is
// elided from this listing.
// NOTE(review): the runtime message contains the typo "will not not".
494 void ensureUniqueIfOutOfPlaced(
const char* name,
const at::Tensor& tensor) {
495 auto& state = getTracingState();
496 if (state && state->force_outplace ==
false) {
// Number of references to the tensor's storage.
501 auto aliases = tensor.storage().use_count();
502 if (isTracing() && aliases > 1) {
503 std::stringstream ss;
504 ss <<
"There are " << aliases
505 <<
" live references to the data region being modified when tracing in-place operator " 507 <<
". This might cause the trace to be incorrect, because all other views " 508 <<
"that also reference this data will not not reflect this change in the trace! " 509 <<
"On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. " 510 <<
"are outputs of torch.split), this might still be safe.";
511 warn(ss.str().c_str());
// Per-thread storage for argument traces stashed before an op is recorded.
518 thread_local ArgumentStash ArgumentStash::stash;
// Stashes the trace of element `idx` of an int-list argument `arg_name` with
// `size` elements. The element's value comes from tensor `var`.
// NOTE(review): the `size`/`idx` parameters and some preamble (original lines
// 522-527) are elided from this listing.
520 void ArgumentStash::stashIntArrayRefElem(
521 const std::string& arg_name,
524 const Variable& var) {
// emplace only creates a new list trace the first time `arg_name` is seen
// (assuming intlists is a map); later calls reuse the existing one.
528 auto& list_trace = stash.intlists.emplace(arg_name, size).first->second;
529 AT_ASSERT(size == list_trace.size());
530 AT_ASSERT(idx < list_trace.size());
// Each element may be stashed only once.
531 AT_ASSERT(list_trace[idx] ==
nullptr);
533 Value* ten = getValueTrace(var);
534 auto& g = *ten->owningGraph();
// Insert the prim::Int conversion right after the node that produces `ten`.
535 WithInsertPoint guard(ten->node()->next());
536 auto prim = g.insert(prim::Int, {ten});
537 list_trace[idx] = prim;
// Stashes the trace of a scalar argument `arg_name` that is computed from
// tensor `var`. If `type` is Int or Float, the trace is converted with
// prim::Int / prim::Float. The conversion is inserted right after the node
// producing the tensor.
// NOTE(review): some parameters and preamble (original lines 542-547) are
// elided. `type == IntType::get()` compares type pointers — presumably those
// types are singletons; confirm.
// NOTE(review): emplace does not overwrite an existing entry for `arg_name`
// (assuming stash.values is a map).
540 void ArgumentStash::stashValue(
541 const std::string& arg_name,
544 const TypePtr& type) {
548 Value* ten = getValueTrace(var);
549 WithInsertPoint guard(ten->node()->next());
550 auto& g = *ten->owningGraph();
552 if (type == IntType::get()) {
553 ten = g.insert(prim::Int, {ten});
554 }
else if (type == FloatType::get()) {
555 ten = g.insert(prim::Float, {ten});
558 stash.values.emplace(arg_name, ten);
// Pluggable hook that attaches source-location info to newly created nodes.
// The default hook does nothing. The hook is a function pointer held in a
// std::atomic, so it can be swapped with setRecordSourceLocation while other
// threads read it.
565 void defaultRecordSourceLocation(Node* n) {}
566 std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
567 defaultRecordSourceLocation);
// Invokes the currently installed hook on `n`.
568 void recordSourceLocation(Node* n) {
569 return record_source_location.load()(n);
// Installs `v` as the source-location hook.
571 void setRecordSourceLocation(
void (*v)(Node*)) {
572 record_source_location.store(v);
// Default tracer warning handler (its body is elided from this listing).
575 void defaultWarn(
const std::string& str) {
// Currently installed warning handler; atomic so it can be replaced via
// setWarn.
578 std::atomic<warn_fn_type> warn_callback{defaultWarn};
// Reason strings for tracer warnings. Each begins with a space; presumably it
// is appended after the name of the offending operation (the composition in
// _do_warn is elided from this listing).
580 const char* WARN_PYTHON_DATAFLOW =
581 " might cause the trace to be incorrect. We can't record the data flow of " 582 "Python values, so this value will be treated as a constant in the future. " 583 "This means that the trace might not generalize to other inputs!";
584 const char* WARN_CONSTRUCTOR =
585 " results are registered as constants in the trace. You can safely ignore this " 586 "warning if you use this function to create tensors out of constant variables " 587 "that would be the same every time you call this function. In any other case, " 588 "this might cause the trace to be incorrect.";
589 const char* WARN_RESIZE =
590 " can't be represented in the JIT at the moment, so we won't connect any uses of " 591 "this value with its current trace. If you happen to use it again, it will show " 592 "up as a constant in the graph.";
593 const char* LEGACY_CONSTRUCTOR =
594 " is a legacy constructor and is not supported in the JIT.";
// Formats a tracer warning from `_reason` and `_kind` (a null `_kind` is
// treated as "") and passes it to the installed warning callback.
// NOTE(review): `_reason` is not null-checked; constructing std::string from
// nullptr is undefined behavior, so callers must pass a valid string.
// NOTE(review): the line that writes into `s` (original line 601) is elided.
597 void _do_warn(
const char* _reason,
const char* _kind) {
598 std::string reason{_reason};
599 std::string kind{_kind ? _kind :
""};
600 std::ostringstream s;
602 warn_callback.load()(s.str());
// Installs `fn` as the handler for tracer warnings.
605 void setWarn(warn_fn_type fn) {
606 warn_callback.store(fn);
C10_NODISCARD TensorOptions device(c10::optional< Device > device) const noexcept
Return a copy of TensorOptions with device set to the given one, or cleared if device is nullopt...
Scalar represents a 0-dimensional tensor which contains a single element.
Represents a compute device on which a tensor is located.
C10_NODISCARD TensorOptions dtype(c10::optional< caffe2::TypeMeta > dtype) const noexcept
Return a copy of TensorOptions with dtype set to the given one.
ArrayRef - Represents a constant reference to an array (0 or more elements consecutively in memory)...
C10_NODISCARD TensorOptions layout(c10::optional< Layout > layout) const noexcept
Sets the layout of the TensorOptions.