Caffe2 - C++ API
A deep learning, cross platform ML framework
python_tracer.cpp
1 #include <torch/csrc/python_headers.h>
2 
3 #include <torch/csrc/jit/export.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/lower_tuples.h>
6 #include <torch/csrc/jit/pybind.h>
7 #include <torch/csrc/jit/python_tracer.h>
8 #include <torch/csrc/jit/tracer.h>
9 #include <torch/csrc/utils/python_strings.h>
10 
11 #include <c10/util/Exception.h>
12 
13 #include <sstream>
14 
15 using namespace torch::autograd;
16 using namespace torch::jit;
17 using namespace torch::jit::tracer;
18 
19 namespace torch {
20 namespace jit {
21 namespace tracer {
22 
23 // Python interpreter retrieval routine adapted from
24 // https://stackoverflow.com/a/8706144
25 std::string getPythonInterpreterStackTrace() {
26  std::stringstream stack_trace;
27  AutoGIL gil;
28  PyFrameObject* frame = PyEval_GetFrame();
29  while (nullptr != frame) {
30  int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
31  std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
32  std::string funcname = THPUtils_unpackString(frame->f_code->co_name);
33  stack_trace << filename << "(" << line << "): " << funcname << "\n";
34  frame = frame->f_back;
35  }
36  return stack_trace.str();
37 }
38 
39 std::shared_ptr<torch::jit::Graph> createGraphByTracing(
40  const py::function& func,
41  Stack trace_inputs,
42  const py::function& var_name_lookup_fn,
43  bool force_outplace,
44  const c10::optional<size_t>& num_real_inputs) {
45  size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size());
46  auto enter_info = tracer::enter(std::move(trace_inputs));
47  getTracingState()->lookup_var_name_fn =
48  [var_name_lookup_fn](const Variable& var) -> std::string {
49  AutoGIL ag;
50  return py::cast<std::string>(var_name_lookup_fn(var));
51  };
52  getTracingState()->force_outplace = force_outplace;
53  try {
54  py::tuple py_inputs(num_func_inputs);
55  for (size_t i = 0; i < num_func_inputs; ++i) {
56  py_inputs[i] = py::cast(enter_info.second[i]);
57  }
58  auto out = func(*py_inputs);
59  if (out.ptr() == Py_None) {
60  AT_ERROR(
61  "The traced function didn't return any values! Side-effects are not "
62  "captured in traces, so it would be a no-op.");
63  }
64  tracer::exit({toIValue(out)});
65  auto graph = enter_info.first->graph;
66  EliminateDeadCode(graph);
67  LowerSimpleTuples(graph);
68 
69  return graph;
70  } catch (...) {
71  tracer::abandon();
72  throw;
73  }
74 }
75 
76 Node* preRecordPythonTrace(
77  THPObjectPtr pyobj,
78  const std::string& arg_types,
80  pyobj_list scalar_args) {
81  THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
82  if (!apply) {
83  throw python_error();
84  }
85 
86  auto& graph = getTracingState()->graph;
87 
88  Node* n = graph->createPythonOp(
89  std::move(apply), arg_types, std::move(scalar_args));
90  recordSourceLocation(n);
91 
92  for (const Variable& input : inputs) {
93  n->addInput(getValueTrace(input));
94  }
95 
96  // NB: Order matters. This must append after inputs but before outputs.
97  graph->appendNode(n);
98 
99  return n;
100 }
101 
102 void pythonRecordSourceLocation(Node* n) {
103  auto sl =
104  std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
105  n->setSourceLocation(sl);
106 }
107 
108 void pythonWarn(const std::string& reason) {
109  AutoGIL gil;
110  auto warn_class = py::module::import("torch.jit").attr("TracerWarning");
111  PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1);
112 }
113 
114 void initPythonTracerBindings(PyObject* module) {
115  setRecordSourceLocation(pythonRecordSourceLocation);
116 
117  auto m = py::handle(module).cast<py::module>();
118  py::class_<TracingState, std::shared_ptr<TracingState>>(
119  m, "TracingState", py::dynamic_attr())
120  // NB: no constructor; you have to get it from C++ code
121  .def(
122  "__repr__",
123  [](const TracingState& s) {
124  std::ostringstream ss;
125  ss << "<TracingState " << (const void*)&s << ">";
126  return ss.str();
127  })
128  .def(
129  "__str__",
130  [](const TracingState& s) -> std::string {
131  std::ostringstream ss;
132  ss << *s.graph;
133  return ss.str();
134  })
135  .def(
136  "push_scope",
137  [](TracingState& s, const std::string& scope_name) {
138  s.graph->push_scope(scope_name);
139  })
140  .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
141  .def(
142  "set_graph",
143  [](TracingState& s, std::shared_ptr<Graph> g) { s.graph = g; })
144  .def("graph", [](TracingState& s) { return s.graph; });
145 
146  m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
147  m.def("_tracer_enter", [](py::args trace_inputs) {
148  return tracer::enter(toStack(trace_inputs));
149  });
150  m.def("_tracer_exit", [](py::tuple var_outputs) {
151  tracer::exit(toStack(var_outputs));
152  });
153  m.def("_tracer_abandon", []() { tracer::abandon(); });
154  m.def("_get_tracing_state", []() { return getTracingState(); });
155  m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
156  return setTracingState(state);
157  });
158  m.def("_get_value_trace", [](const Variable& var) {
159  return getValueTrace(var);
160  });
161  m.def("_set_value_trace", [](const Variable& var, Value* value) {
162  return setValueTrace(var, value);
163  });
164  m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
165  const auto& tracing_state = getTracingState();
166  AT_ASSERT(tracing_state);
167  tracing_state->lookup_var_name_fn =
168  [func](const Variable& var) -> std::string {
169  AutoGIL ag;
170  return py::cast<std::string>(func(var));
171  };
172  });
173  m.def("_tracer_set_force_outplace", [](bool force_outplace) {
174  const auto& tracing_state = getTracingState();
175  AT_ASSERT(tracing_state);
176  tracing_state->force_outplace = force_outplace;
177  });
178 }
179 
180 } // namespace tracer
181 } // namespace jit
182 } // namespace torch
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
ArrayRef: Represents a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41