1 #include <torch/csrc/python_headers.h> 3 #include <torch/csrc/jit/export.h> 4 #include <torch/csrc/jit/passes/dead_code_elimination.h> 5 #include <torch/csrc/jit/passes/lower_tuples.h> 6 #include <torch/csrc/jit/pybind.h> 7 #include <torch/csrc/jit/python_tracer.h> 8 #include <torch/csrc/jit/tracer.h> 9 #include <torch/csrc/utils/python_strings.h> 11 #include <c10/util/Exception.h>
// Builds a textual Python stack trace by walking the interpreter's frame
// chain. It starts at the frame returned by PyEval_GetFrame() and follows
// f_back until it reaches nullptr. Each frame contributes one line of the
// form "<filename>(<line>): <function name>\n", innermost frame first.
//
// NOTE(review): this listing skips original lines 27, 35 and 37, so any GIL
// acquisition and the closing braces are not visible here. PyEval_GetFrame()
// and THPUtils_unpackString() operate on Python objects and need the GIL.
// Confirm the caller (or a skipped line) holds it.
// NOTE(review): frame->f_code, frame->f_lasti and frame->f_back read
// PyFrameObject fields directly. Those fields became private in CPython 3.11,
// where PyFrame_GetCode / PyFrame_GetLineNumber / PyFrame_GetBack are the
// public accessors. Confirm which Python versions this must build against.
25 std::string getPythonInterpreterStackTrace() {
26 std::stringstream stack_trace;
// PyEval_GetFrame() returns a borrowed reference to the currently executing
// frame (nullptr when none is executing), so no decref is needed while walking.
28 PyFrameObject* frame = PyEval_GetFrame();
29 while (
nullptr != frame) {
// Map the frame's last-executed bytecode offset back to a source line number.
30 int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
31 std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
32 std::string funcname = THPUtils_unpackString(frame->f_code->co_name);
33 stack_trace << filename <<
"(" << line <<
"): " << funcname <<
"\n";
// Step outward to the calling frame.
34 frame = frame->f_back;
36 return stack_trace.str();
// Traces the Python callable `func` and returns the recorded graph.
// Visible steps, in order:
//   1. Enter tracing with `trace_inputs`.
//   2. Install `var_name_lookup_fn` as the tracer's variable-naming callback.
//   3. Set the force_outplace flag on the tracing state.
//   4. Call `func` with the first `num_func_inputs` traced values.
//   5. Reject a None result.
//   6. Exit tracing with the returned value.
//   7. Run EliminateDeadCode and LowerSimpleTuples on the graph.
//
// NOTE(review): this listing skips original lines 41, 43-44, 49, 51, 53, 57,
// 60, 63 and everything after 67. Those lines hold the remaining parameter
// declarations (trace_inputs, force_outplace and num_real_inputs are used but
// not declared below), the call that consumes the message string on line 61,
// closing braces, and the end of the function. Any exception handling around
// the call to `func` is therefore not visible. Confirm that tracing state is
// abandoned if `func` throws.
39 std::shared_ptr<torch::jit::Graph> createGraphByTracing(
40 const py::function& func,
42 const py::function& var_name_lookup_fn,
// When num_real_inputs is unset, every traced input is passed to `func`.
// NOTE(review): no visible check that num_real_inputs <= trace_inputs.size().
// A larger value would index enter_info.second out of range in the loop below.
45 size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size());
46 auto enter_info = tracer::enter(std::move(trace_inputs));
// The tracer calls back into Python to name traced variables.
// NOTE(review): this callback invokes a Python function. Confirm the GIL is
// acquired inside it (original line 49 is not visible here).
47 getTracingState()->lookup_var_name_fn =
48 [var_name_lookup_fn](
const Variable& var) -> std::string {
50 return py::cast<std::string>(var_name_lookup_fn(var));
52 getTracingState()->force_outplace = force_outplace;
// Pack the first num_func_inputs values from enter_info.second into a Python
// tuple so they can be splatted into the call to `func`.
54 py::tuple py_inputs(num_func_inputs);
55 for (
size_t i = 0; i < num_func_inputs; ++i) {
56 py_inputs[i] = py::cast(enter_info.second[i]);
58 auto out = func(*py_inputs);
// A function that returns None leaves nothing to record as graph outputs.
59 if (out.ptr() == Py_None) {
61 "The traced function didn't return any values! Side-effects are not " 62 "captured in traces, so it would be a no-op.");
64 tracer::exit({toIValue(out)});
65 auto graph = enter_info.first->graph;
// Clean up the recorded graph before it is handed back.
66 EliminateDeadCode(graph);
67 LowerSimpleTuples(graph);
// Records a PythonOp node for a Python-side operation during tracing:
//   - looks up the `apply` attribute of `pyobj`;
//   - creates the node with graph->createPythonOp from `apply`, `arg_types`
//     and `scalar_args`;
//   - records the node's source location;
//   - adds the traced value of each Variable in `inputs` as a node input.
//
// NOTE(review): this listing skips original lines 77, 79, 82-85, 87, 91 and
// everything after 93. `pyobj` and `inputs` are not declared in the visible
// lines (presumably they are parameters on the skipped lines). The rest of the
// body, including the return, is also not visible.
76 Node* preRecordPythonTrace(
78 const std::string& arg_types,
80 pyobj_list scalar_args) {
// PyObject_GetAttrString returns a new reference, or nullptr with a Python
// error set on failure.
// NOTE(review): no null check is visible before `apply` is used (lines 82-85
// are skipped). Confirm one exists.
81 THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(),
"apply"));
86 auto& graph = getTracingState()->graph;
// `apply` and `scalar_args` are moved into the new node.
88 Node* n = graph->createPythonOp(
89 std::move(apply), arg_types, std::move(scalar_args));
90 recordSourceLocation(n);
// Wire each input's traced Value into the node.
92 for (
const Variable& input : inputs) {
93 n->addInput(getValueTrace(input));
// Attaches the current Python stack trace (from
// getPythonInterpreterStackTrace) to `n` as its source location, wrapped in
// a StringSourceLocation.
// NOTE(review): original line 103 (presumably the declaration of `sl`, which
// receives the make_shared result) and line 106 are skipped in this listing.
102 void pythonRecordSourceLocation(
Node* n) {
104 std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
105 n->setSourceLocation(sl);
// Emits `reason` as a Python warning of category torch.jit.TracerWarning,
// using PyErr_WarnEx with stacklevel 1.
// NOTE(review): original lines 109 and 112 are skipped in this listing.
// py::module::import and PyErr_WarnEx need the GIL. Confirm it is held.
108 void pythonWarn(
const std::string& reason) {
110 auto warn_class = py::module::import(
"torch.jit").attr(
"TracerWarning");
// NOTE(review): the return value of PyErr_WarnEx is ignored. It returns -1
// when the warning is turned into an exception (e.g. warnings configured as
// errors), which leaves a Python error set. Confirm the caller handles that.
111 PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1);
// Registers the tracer's Python bindings on `module`:
//   - installs pythonRecordSourceLocation as the source-location recorder;
//   - exposes the TracingState class;
//   - defines the _tracer_* and tracing-state helper functions below.
//
// NOTE(review): many original lines are skipped in this listing (e.g. 116,
// 120-123, 126-130, 132-137, 139, 141-142, 144-145, 149, 152, 157, 160, 163,
// 169, 171-172 and everything after 176). Several .def(...) calls are
// therefore only partially visible.
114 void initPythonTracerBindings(PyObject* module) {
115 setRecordSourceLocation(pythonRecordSourceLocation);
117 auto m = py::handle(module).cast<py::module>();
// TracingState is exposed with shared_ptr holder semantics and accepts
// dynamic Python attributes (py::dynamic_attr).
118 py::class_<TracingState, std::shared_ptr<TracingState>>(
119 m,
"TracingState", py::dynamic_attr())
// Formats the state as "<TracingState <address>>" (presumably its __repr__;
// the enclosing .def is on skipped lines).
124 std::ostringstream ss;
125 ss <<
"<TracingState " << (
const void*)&s <<
">";
131 std::ostringstream ss;
// Scope management on the traced graph.
138 s.graph->push_scope(scope_name);
140 .def(
"pop_scope", [](
TracingState& s) { s.graph->pop_scope(); })
// Replaces the state's graph (the property/method name is on skipped lines).
143 [](
TracingState& s, std::shared_ptr<Graph> g) { s.graph = g; })
// Routes tracer warnings through pythonWarn so they surface as Python warnings.
146 m.def(
"_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
// Starts tracing with the given inputs, converted to a Stack.
147 m.def(
"_tracer_enter", [](py::args trace_inputs) {
148 return tracer::enter(toStack(trace_inputs));
// Ends tracing, registering `var_outputs` as the trace outputs.
150 m.def(
"_tracer_exit", [](py::tuple var_outputs) {
151 tracer::exit(toStack(var_outputs));
153 m.def(
"_tracer_abandon", []() { tracer::abandon(); });
// Accessors for the current tracing state (shared_ptr).
154 m.def(
"_get_tracing_state", []() {
return getTracingState(); });
155 m.def(
"_set_tracing_state", [](std::shared_ptr<TracingState> state) {
156 return setTracingState(state);
// Accessors for the Value a Variable is traced as.
158 m.def(
"_get_value_trace", [](
const Variable& var) {
159 return getValueTrace(var);
161 m.def(
"_set_value_trace", [](
const Variable& var,
Value* value) {
162 return setValueTrace(var, value);
// Installs a Python callable as the tracer's variable-naming callback.
// Requires an active tracing state (asserts getTracingState() is non-null).
// NOTE(review): the callback calls into Python. Confirm the GIL is acquired
// inside it (original line 169 is not visible here).
164 m.def(
"_tracer_set_get_unique_name_fn", [](py::function func) {
165 const auto& tracing_state = getTracingState();
166 AT_ASSERT(tracing_state);
167 tracing_state->lookup_var_name_fn =
168 [func](
const Variable& var) -> std::string {
170 return py::cast<std::string>(func(var));
// Sets force_outplace on the active tracing state (asserts one exists).
173 m.def(
"_tracer_set_force_outplace", [](
bool force_outplace) {
174 const auto& tracing_state = getTracingState();
175 AT_ASSERT(tracing_state);
176 tracing_state->force_outplace = force_outplace;
Variable: A Variable augments a Tensor with the ability to interact with our autograd machinery...
ArrayRef - Represents a constant reference to an array (0 or more elements consecutively in memory)...