Caffe2 - C++ API
A deep learning, cross platform ML framework
python_interpreter.cpp
1 #include <torch/csrc/jit/interpreter.h>
2 #include <torch/csrc/python_headers.h>
3 
4 #include <torch/csrc/autograd/edge.h>
5 #include <torch/csrc/autograd/function.h>
6 #include <torch/csrc/autograd/profiler.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <torch/csrc/jit/custom_operator.h>
9 #include <torch/csrc/jit/graph_executor.h>
10 #include <torch/csrc/jit/ir.h>
11 #include <torch/csrc/jit/operator.h>
12 #include <torch/csrc/jit/pybind_utils.h>
13 
14 #include <typeinfo>
15 
16 #include <torch/csrc/Exceptions.h>
17 #include <torch/csrc/autograd/python_engine.h>
18 #include <torch/csrc/autograd/python_variable.h>
19 #include <torch/csrc/jit/pybind.h>
20 #include <torch/csrc/utils/auto_gil.h>
21 
22 namespace py = pybind11;
23 
24 namespace torch {
25 namespace jit {
26 
27 namespace {
28 
29 // Note: const_cast is used twice below to acquire a handle to a pyobject.
30 Operation createPythonOperation(const Node* op_) {
31  AutoGIL gil;
32  const PythonOp* op = static_cast<const PythonOp*>(op_);
33  const py::function func = py::reinterpret_borrow<const py::function>(
34  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
35  py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
36 
37  size_t num_inputs = 0;
38  for (auto arg_type : op->cconv) {
39  if (arg_type == 'd')
40  num_inputs++;
41  }
42 
43  AT_ASSERT(op->outputs().size() == 1);
44 
45  return [=](Stack& stack) {
46  AutoGIL gil;
47  py::tuple py_inputs(op->cconv.size());
48  size_t i = 0;
49  size_t next_scalar = 0;
50  size_t next_tensor = 0;
51  for (auto arg_type : op->cconv) {
52  if (arg_type == 'c') {
53  py_inputs[i] = py::reinterpret_borrow<const py::object>(
54  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
55  const_cast<PythonOp*>(op)->scalar_args[next_scalar++].get());
56  } else if (arg_type == 'd') {
57  py_inputs[i] =
58  toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
59  next_tensor++;
60  }
61  i++;
62  }
63  drop(stack, num_inputs);
64  try {
65  py::object py_output(func(*py_inputs));
66  stack.push_back(returnToIValue(op->output()->type(), py_output));
67  } catch (py::error_already_set& e) {
68  throw std::runtime_error(e.what());
69  }
70  return 0;
71  };
72 }
73 
74 RegisterOperators reg({Operator(prim::PythonOp, createPythonOperation)});
75 
76 } // namespace
77 } // namespace jit
78 } // namespace torch
Definition: jit_type.h:17