1 #include <torch/csrc/jit/interpreter.h> 2 #include <torch/csrc/python_headers.h> 4 #include <torch/csrc/autograd/edge.h> 5 #include <torch/csrc/autograd/function.h> 6 #include <torch/csrc/autograd/profiler.h> 7 #include <torch/csrc/autograd/variable.h> 8 #include <torch/csrc/jit/custom_operator.h> 9 #include <torch/csrc/jit/graph_executor.h> 10 #include <torch/csrc/jit/ir.h> 11 #include <torch/csrc/jit/operator.h> 12 #include <torch/csrc/jit/pybind_utils.h> 16 #include <torch/csrc/Exceptions.h> 17 #include <torch/csrc/autograd/python_engine.h> 18 #include <torch/csrc/autograd/python_variable.h> 19 #include <torch/csrc/jit/pybind.h> 20 #include <torch/csrc/utils/auto_gil.h> 30 Operation createPythonOperation(
const Node* op_) {
// Builds the interpreter Operation that executes a prim::PythonOp node
// (this function is passed to RegisterOperators for prim::PythonOp in this file).
//
// op_: the graph node to execute. It is static_cast to PythonOp below with no
//   runtime type check, so it must really be a PythonOp.
// Returns: an Operation lambda that reads this op's tensor inputs from the
//   interpreter Stack, calls the node's stored Python callable, and pushes the
//   single converted result back onto the Stack.
//
// NOTE(review): this excerpt is incomplete. Several original source lines are
// elided, including:
//   - the num_inputs loop body;
//   - the declaration and increment of `i`;
//   - the assignment target of the tensor branch;
//   - the `try {`;
//   - the closing lines of the lambda and the function.
// The comments below describe only what is visible.
32 const PythonOp* op =
static_cast<const PythonOp*
>(op_);
// Wrap the node's stored Python callable (pyobj) as a py::function.
// reinterpret_borrow does not steal the reference, so the PythonOp keeps its own.
33 const py::function func = py::reinterpret_borrow<const py::function>(
35 py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
// Number of inputs this op consumes from the stack.
// NOTE(review): the loop body that computes this is elided here. Presumably it
// counts the 'd' entries of cconv, since the 'd' branch below reads tensors from
// the stack using num_inputs. Confirm.
37 size_t num_inputs = 0;
38 for (
auto arg_type : op->cconv) {
// The Operation below pushes exactly one result, so exactly one output is required.
43 AT_ASSERT(op->outputs().size() == 1);
// The Operation captures op, func and num_inputs by value ([=]).
// op is a raw pointer, so the node must outlive the returned Operation.
// NOTE(review): confirm that lifetime guarantee.
// NOTE(review): building py objects and calling func requires the GIL. No
// AutoGIL is visible in this excerpt, although auto_gil.h is included. Confirm
// it is acquired in the elided lines.
45 return [=](Stack& stack) {
// One Python argument slot per calling-convention (cconv) entry.
47 py::tuple py_inputs(op->cconv.size());
49 size_t next_scalar = 0;
50 size_t next_tensor = 0;
// Fill the argument tuple in cconv order:
//   'c' -> the next stored scalar argument;
//   'd' -> the next input taken from the stack.
51 for (
auto arg_type : op->cconv) {
52 if (arg_type ==
'c') {
// Borrow the next stored scalar Python object.
// NOTE(review): `i` (the tuple slot index) is declared and advanced in lines
// elided from this excerpt.
53 py_inputs[i] = py::reinterpret_borrow<const py::object>(
55 const_cast<PythonOp*
>(op)->scalar_args[next_scalar++].
get());
56 }
else if (arg_type ==
'd') {
// Convert a stack input to a Python object. peek presumably reads the
// next_tensor-th of the last num_inputs stack entries without popping; the
// entries are removed by drop() below.
// NOTE(review): the assignment target and the next_tensor increment for this
// branch appear to be elided here.
58 toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
// Remove the consumed inputs from the stack before calling into Python.
63 drop(stack, num_inputs);
// Call the Python function with the unpacked argument tuple. Convert its result
// to an IValue of the node's declared output type and push it onto the stack.
// (The opening `try {` for this region is elided from this excerpt.)
65 py::object py_output(func(*py_inputs));
66 stack.push_back(returnToIValue(op->output()->type(), py_output));
67 }
catch (py::error_already_set& e) {
// Rethrow Python errors as std::runtime_error carrying the Python message.
// NOTE(review): this discards the original Python exception type, so callers
// only see runtime_error. Consider preserving it if callers must distinguish
// between errors.
68 throw std::runtime_error(e.what());
// Hooks createPythonOperation up as the Operation factory for prim::PythonOp
// nodes. `reg` is a namespace-scope object, so it is constructed during static
// initialization. Presumably its constructor adds the listed Operator to the JIT
// operator registry (RegisterOperators is not shown here; confirm).
74 RegisterOperators reg({Operator(prim::PythonOp, createPythonOperation)});