// NOTE(review): this chunk is a reflowed/elided paste; the integers embedded
// in the lines below are the ORIGINAL file's line numbers, and gaps between
// consecutive numbers mean source lines were dropped.  Code left byte-identical.
1 #include <torch/csrc/python_headers.h> 3 #include <torch/csrc/jit/argument_spec.h> 4 #include <torch/csrc/jit/export.h> 5 #include <torch/csrc/jit/ir.h> 6 #include <torch/csrc/jit/passes/alias_analysis.h> 7 #include <torch/csrc/jit/passes/python_print.h> 8 #include <torch/csrc/jit/passes/shape_analysis.h> 9 #include <torch/csrc/jit/pybind.h> 10 #include <torch/csrc/jit/python_tracer.h> 11 #include <torch/csrc/utils/auto_gil.h> 12 #include <torch/csrc/utils/pybind.h> 13 #include <torch/csrc/utils/python_strings.h> 23 std::string getPythonName(
// getPythonName: best-effort human-readable name for a Python object.
// Takes a const PyObject* but pybind's getattr needs a mutable handle,
// hence the const_cast (the object is not mutated here).
const PyObject* obj_) {
26 PyObject* obj =
const_cast<PyObject*
>(obj_);
// Reads obj.__name__, falling back to the literal "<python_value>" when the
// attribute is absent.  The tail of this function (orig lines 28-31, which
// presumably converts `v` to std::string and returns it) is elided here.
27 auto v = py::getattr(obj,
"__name__", py::str(
"<python_value>"));
// printPyObject: stream a Python object (held via THPObjectPtr) into `out`.
// Tuples get a dedicated path (orig lines 36-66, partially elided) that
// prints each element; every other object falls through to str(obj).
// Interior lines 33-34, 37-50, 52-58, 60-67 are elided in this paste.
32 std::ostream& printPyObject(std::ostream& out,
const THPObjectPtr& obj) {
// const_cast: THPObjectPtr::get() yields a pointer pybind treats as mutable;
// no mutation happens in this function.
35 auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
36 if (py::isinstance<py::tuple>(pyobj)) {
51 auto pytuple = pyobj.cast<py::tuple>();
54 for (
const auto& o : pytuple) {
// `str` here refers to an elided local (orig lines 55-58) — presumably the
// str() of element `o`; TODO confirm against the full source.
59 out << THPUtils_unpackString(str.get());
// Fallback: print str(pyobj) for non-tuple objects.
68 return out << THPUtils_unpackString(py::str(pyobj).ptr());
// findAllNodes (list-of-blocks overload): collect every Node in `blocks`
// whose kind() matches `kind`; when `recurse` is true, also descend into each
// node's nested blocks (visible at orig lines 83-84).  The parameter lines
// (orig 73-74, presumably the blocks list and the Symbol) and the tail of the
// function (orig 85-90) are elided in this paste.
72 std::vector<Node*> findAllNodes(
75 bool recurse =
true) {
76 std::vector<Node*> ret;
77 for (Block* block : blocks) {
78 for (Node* n : block->nodes()) {
// Matching node — the push into `ret` (orig lines 80-82) is elided.
79 if (n->kind() == kind) {
// Recursive descent: append matches found in this node's sub-blocks.
83 auto nodes = findAllNodes(n->blocks(), kind, recurse);
84 ret.insert(ret.end(), nodes.begin(), nodes.end());
// findAllNodes (single-block convenience overload): wraps one block in a
// vector and delegates to the list overload above.  The parameter lines
// (orig 92-93) and the closing brace are elided in this paste.
91 std::vector<Node*> findAllNodes(
94 bool recurse =
true) {
95 std::vector<Block*> blocks = {block};
96 return findAllNodes(blocks, kind, recurse);
// findNode (list-of-blocks overload): return the FIRST node matching `kind`,
// or (presumably) nullptr when none is found — the signature (orig lines
// 99-101) and the tail (orig 111-118) are elided in this paste, so the exact
// parameter list and the miss-path return are not visible here.
102 bool recurse =
true) {
103 for (Block* block : blocks) {
104 for (Node* n : block->nodes()) {
// First match wins — the `return n;` (orig lines 106-108) is elided.
105 if (n->kind() == kind) {
// Recurse into nested blocks and propagate the first hit, if any.
109 auto node = findNode(n->blocks(), kind, recurse);
110 if (node !=
nullptr) {
119 Node* findNode(Block* block, Symbol kind,
bool recurse =
true) {
120 std::vector<Block*> blocks = {block};
121 return findNode(blocks, kind, recurse);
// PythonOp::name (override): report the autograd function's Python name when
// this op wraps one, otherwise the name of the raw Python callable `pyobj`.
// Interior lines (orig 129, 132) are elided in this paste.
128 std::string name()
const override {
130 if (
auto autograd = autogradFunction()) {
131 return getPythonName(autograd->get());
133 return getPythonName(pyobj.get());
// PythonOp::cloneFrom (override): copy node state from `other_`, then copy
// the Python-specific payload: the calling convention string, a new strong
// reference to the wrapped Python object (Py_INCREF — the assignment of the
// pointer itself, orig line 141, is elided), and each scalar argument.
// Orig lines 141, 143, 145-146 are elided in this paste.
136 void cloneFrom(
Node* other_)
override {
137 Node::cloneFrom(other_);
138 auto other = other_->cast<
PythonOp>();
139 this->cconv = other->cconv;
140 Py_INCREF(other->pyobj.get());
142 for (
auto& sa : other->scalar_args) {
// scalar_args presumably holds owning THPObjectPtr-style handles; the
// matching Py_INCREF for sa (orig line 143) is elided — TODO confirm.
144 this->scalar_args.emplace_back(sa.get());
// PythonOp::allocNewInstance (override): factory used by Node::cloneFrom to
// allocate a fresh instance in graph `g`.  The entire body (orig line 148+)
// is elided in this paste.
147 Node* allocNewInstance(
Graph* g)
override {
// Body fragment of (presumably) autogradFunction(): determine whether
// `pyobj` is the bound `apply` method of an autograd Function subclass.
// The enclosing signature (orig lines ~149-155) is elided in this paste.
156 py::handle obj =
const_cast<PyObject*
>(pyobj.get());
// obj.__self__ — present when obj is a bound method; py::none() otherwise.
158 auto r = py::getattr(obj,
"__self__", py::none());
// type(r).apply — the autograd entry point, when r is a Function class/obj.
162 auto apply = py::getattr(r,
"apply", py::none());
// PyObject_RichCompareBool(Py_NE) may run user __ne__, hence the explicit
// error check: propagate any pending Python exception into C++.
166 auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
167 if (PyErr_Occurred())
168 throw py::error_already_set();
// PythonOp::writeScalars (override): pretty-print each captured Python
// scalar argument into `out` via printPyObject.  Surrounding punctuation
// (orig lines 176-177, 179-180) is elided in this paste.
175 void writeScalars(std::ostream& out)
const override {
178 for (
auto& scalar : scalar_args) {
181 printPyObject(out, scalar);
// initPythonIRBindings: register all pybind11 bindings for the JIT IR types
// (Graph, Value, Block, Node, Type hierarchy, Use) on the given module.
// NOTE(review): this function is heavily elided in this paste — the embedded
// integers are original line numbers and gaps mean dropped lines, including
// the function's closing brace.  Code left byte-identical; comments only.
191 void initPythonIRBindings(PyObject* module_) {
192 setAllocPythonOp(pythonAllocPythonOp);
194 auto m = py::handle(module_).cast<py::module>();
// ---- Graph bindings (GS shorthand binds a Graph member fn by name) -------
195 #define GS(name) def(#name, &Graph ::name) 196 py::class_<Graph, std::shared_ptr<Graph>>(m,
"Graph")
201 std::stringstream ss;
207 [](std::shared_ptr<Graph> g) {
// Shape propagation entry point: builds an ArgumentSpec from concrete input
// tensors and runs PropagateInputShapes over the graph.
213 [](std::shared_ptr<Graph> g,
214 std::vector<at::Tensor> inputs,
218 ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
219 PropagateInputShapes(g);
// _export_onnx: serialize the graph to ONNX, returning (proto_bytes,
// {tensor_name: raw_bytes}) — raw tensor data is copied into py::bytes so
// Python owns it independently of the at::Tensor storage.
223 [](
const std::shared_ptr<Graph> g,
224 const std::vector<at::Tensor>& initializers,
225 int64_t onnx_opset_version,
226 bool defer_weight_export,
227 ::torch::onnx::OperatorExportTypes operator_export_type) {
229 RawDataExportMap export_map;
230 std::tie(graph, export_map) = export_onnx(
235 operator_export_type);
236 std::unordered_map<std::string, py::bytes>
237 python_serialized_export_map;
238 for (
auto& kv : export_map) {
240 size_t copy_bytes = t.element_size() * t.numel();
244 python_serialized_export_map[kv.first] =
245 py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
247 return std::make_tuple(
248 py::bytes(graph), python_serialized_export_map);
250 py::arg(
"initializers"),
251 py::arg(
"onnx_opset_version") = 0,
252 py::arg(
"defer_weight_export") =
false,
253 py::arg(
"operator_export_type") =
254 ::torch::onnx::OperatorExportTypes::ONNX)
// _pretty_print_onnx: same parameters as _export_onnx plus google_printer,
// but returns a human-readable dump instead of serialized bytes.
256 "_pretty_print_onnx",
257 [](
const std::shared_ptr<Graph> g,
258 const std::vector<at::Tensor>& initializers,
259 int64_t onnx_opset_version,
260 bool defer_weight_export,
261 ::torch::onnx::OperatorExportTypes operator_export_type,
262 bool google_printer) {
263 return pretty_print_onnx(
268 operator_export_type,
271 py::arg(
"initializers"),
272 py::arg(
"onnx_opset_version") = 0,
273 py::arg(
"defer_weight_export") =
false,
274 py::arg(
"operator_export_type") =
275 ::torch::onnx::OperatorExportTypes::ONNX,
276 py::arg(
"google_printer") =
false)
// Iterator views over the graph's inputs/outputs/nodes.
280 return py::make_iterator(g.inputs().begin(), g.inputs().end());
285 return py::make_iterator(g.outputs().begin(), g.outputs().end());
291 return py::make_iterator(g.nodes().begin(), g.nodes().end());
// findNode / findAllNodes exposed on Graph, keyed by qualified kind string.
295 [](
Graph& g,
const std::string& kind,
bool recurse) {
296 return findNode(g.block(), Symbol::fromQualString(kind), recurse);
300 py::arg(
"recurse") =
true)
303 [](
Graph& g,
const std::string& kind,
bool recurse) {
305 g.block(), Symbol::fromQualString(kind), recurse);
309 py::arg(
"recurse") =
true)
310 .def(
"addInput", [](
Graph& g) {
return g.addInput(); })
311 .def(
"copy", [](
Graph& g) {
return g.copy(); })
// Graph.create overloads: (kind), (kind, noutputs), (kind, inputs),
// (kind, inputs, noutputs).
316 [](
Graph& g,
const char* str) {
317 return g.create(Symbol::fromQualString(str));
321 [](
Graph& g,
const char* str,
size_t noutputs) {
322 return g.create(Symbol::fromQualString(str), noutputs);
326 [](
Graph& g,
const char* str,
const std::vector<Value*>& inputs) {
327 return g.create(Symbol::fromQualString(str), inputs);
333 const std::vector<Value*>& inputs,
335 return g.create(Symbol::fromQualString(str), inputs, noutputs);
337 .def(
"param_node", [](
Graph& g) {
return g.block()->param_node(); })
338 .def(
"return_node", [](
Graph& g) {
return g.block()->return_node(); })
342 std::ostringstream oss;
346 .GS(createFusionGroup)
// createClone takes a Python callable `fn` used to remap each input Value.
350 return g.createClone(
351 n, [&](
Value* e) {
return fn(e).cast<
Value*>(); });
// ---- Value bindings (py::nodelete: Values are owned by their Graph) ------
359 #define VS(name) def(#name, &Value ::name) 360 py::class_<Value, std::unique_ptr<Value, py::nodelete>>(m,
"Value")
364 std::stringstream ss;
365 ss << n.uniqueName() <<
" defined in (" << *n.node() <<
")";
378 .VS(replaceAllUsesWith)
379 .def(
"node", [](
Value& v) {
return v.node(); })
383 node->setType(other->type());
388 .def(
"toIValue", [](
Value& n) {
return toIValue(&n); })
389 .def(
"type", [](
Value& v) {
return v.type(); });
// ---- Block bindings ------------------------------------------------------
392 py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m,
"Block")
396 return py::make_iterator(b.nodes().begin(), b.nodes().end());
400 [](
Block& b,
const std::string& kind,
bool recurse) {
401 return findNode(&b, Symbol::fromQualString(kind), recurse);
405 py::arg(
"recurse") =
true)
408 [](
Block& b,
const std::string& kind,
bool recurse) {
409 return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
413 py::arg(
"recurse") =
true)
417 return py::make_iterator(b.inputs().begin(), b.inputs().end());
422 return py::make_iterator(b.outputs().begin(), b.outputs().end());
427 return b.return_node();
432 return b.param_node();
// ---- Node bindings (NS shorthand binds a Node member fn by name) ---------
435 #define NS(name) def(#name, &Node ::name) 436 py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m,
"Node")
440 std::stringstream ss;
// Source-location accessor: returns a string when the node carries one.
446 [](
Node& n) -> py::object {
447 std::stringstream ss;
448 if (
auto sl = n.getSourceLocation()) {
450 return py::str(ss.str());
455 .def(
"hasMultipleOutputs", [](
Node& n) {
return n.outputs().size() > 1; })
456 .def(
"outputsSize", [](
Node& n) {
return n.outputs().size(); })
458 .def(
"inputsAt", [](
Node& n,
size_t i) {
return n.inputs().at(i); })
462 return py::make_iterator(n.inputs().begin(), n.inputs().end());
467 return py::make_iterator(n.outputs().begin(), n.outputs().end());
469 .def(
"outputsAt", [](
Node& n,
size_t i) {
return n.outputs().at(i); })
472 [](
Node& n,
const std::string& kind,
bool recurse) {
473 return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
477 py::arg(
"recurse") =
true)
480 [](
Node& n,
const std::string& kind,
bool recurse) {
482 n.blocks(), Symbol::fromQualString(kind), recurse);
486 py::arg(
"recurse") =
true)
487 .def(
"input", [](
Node& n) {
return n.input(); })
488 .def(
"output", [](
Node& n) {
return n.output(); })
491 .NS(replaceInputWith)
492 .NS(replaceAllUsesWith)
504 .NS(isNondeterministic)
508 return py::make_iterator(n.blocks().begin(), n.blocks().end());
// ---- Attribute accessors: CREATE_ACCESSOR stamps out paired setter
// (method##_) and getter (method) defs for each attribute kind --------------
513 #define AS(name) def(#name, &Node::name) 518 #define AS(name) def(#name, &Node::name##S) 527 #define CREATE_ACCESSOR(Kind, method) \ 529 [](Node& n, const char* name, Kind##Attr::ValueType v) { \ 530 return n.method##_(Symbol::attr(name), std::move(v)); \ 532 .def(#method, [](Node& n, const char* name) { \ 533 return n.method(Symbol::attr(name)); \ 535 .CREATE_ACCESSOR(Float, f)
536 .CREATE_ACCESSOR(Floats, fs)
537 .CREATE_ACCESSOR(String, s)
538 .CREATE_ACCESSOR(Strings, ss)
539 .CREATE_ACCESSOR(Int, i)
540 .CREATE_ACCESSOR(Ints, is)
541 .CREATE_ACCESSOR(
Graph, g)
542 .CREATE_ACCESSOR(Graphs, gs)
// Tensor attribute accessors are hand-written (not via the macro): stored
// tensors must not require grad (AT_ASSERT guards below).
543 #undef CREATE_ACCESSOR 548 AT_ASSERT(!v.requires_grad());
549 return n.t_(Symbol::attr(name), v);
553 [](
Node& n,
const char* name) {
return n.t(Symbol::attr(name)); })
559 std::vector<torch::autograd::Variable> vs) {
560 std::vector<at::Tensor> tensors;
561 tensors.reserve(vs.size());
562 for (
auto& variable : vs) {
563 AT_ASSERT(!variable.requires_grad());
564 tensors.push_back(variable);
566 return n.ts_(Symbol::attr(name), std::move(tensors));
570 [](
Node& n,
const char* name) {
571 auto tensors = n.ts(Symbol::attr(name));
572 std::vector<torch::autograd::Variable> variables;
573 variables.reserve(tensors.size());
574 for (
auto& tensor : tensors) {
575 variables.emplace_back(std::move(tensor));
588 [](
Node& n,
const char* name) {
return n.t(Symbol::attr(name)); })
591 [](
Node& n,
const char* name, TensorsAttr::ValueType v) {
595 return n.ts_(Symbol::attr(name), std::move(v));
599 [](
Node& n,
const char* name) {
return n.ts(Symbol::attr(name)); })
// PythonOp-specific accessors: pyobj / cconv / pyname / scalar_args.
603 return py::handle(n.expect<
PythonOp>()->pyobj.get())
606 .def(
"cconv", [](
Node& n) {
return n.expect<
PythonOp>()->cconv; })
607 .def(
"pyname", [](
Node& n) {
return n.expect<
PythonOp>()->name(); })
608 .def(
"scalar_args", [](
Node& n) {
610 auto scalars = py::list();
611 auto append = scalars.attr(
"append");
612 for (
auto& arg : op->scalar_args) {
613 append(py::handle(arg.get()));
// ---- Type hierarchy bindings ---------------------------------------------
619 py::class_<Type, std::shared_ptr<Type>>(m,
"Type")
620 .def(
"__repr__", [](
Type& t) {
return t.python_str(); })
624 std::ostringstream s;
628 .def(
"kind", [](
const Type& t) {
return typeKindToString(t.kind()); })
643 return std::static_pointer_cast<
Type>(
// Equality / subtyping compare the pointed-to Types, not the pointers.
653 [](std::shared_ptr<Type>&
self, std::shared_ptr<Type>& other) {
654 return *
self == *other;
658 [](std::shared_ptr<Type>&
self, std::shared_ptr<Type> other) {
659 return self->isSubtypeOf(other);
// Singleton leaf types: each exposes a static get().
662 py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m,
"NumberType")
663 .def_static(
"get", &NumberType::get);
664 py::class_<IntType, Type, std::shared_ptr<IntType>>(m,
"IntType")
665 .def_static(
"get", &IntType::get);
666 py::class_<FloatType, Type, std::shared_ptr<FloatType>>(m,
"FloatType")
667 .def_static(
"get", &FloatType::get);
668 py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m,
"TensorType")
669 .def_static(
"get", &TensorType::get);
670 py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m,
"BoolType")
671 .def_static(
"get", &BoolType::get);
672 py::class_<StringType, Type, std::shared_ptr<StringType>>(m,
"StringType")
673 .def_static(
"get", &StringType::get);
// Container types: constructed from element type(s).
675 py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m,
"TupleType")
677 py::init([](std::vector<TypePtr> a) {
return TupleType::create(a); }))
679 std::vector<TypePtr> types;
680 for (
const auto& type :
self.elements()) {
681 types.push_back(type);
685 py::class_<ListType, Type, std::shared_ptr<ListType>>(m,
"ListType")
686 .def(py::init([](TypePtr a) {
return ListType::create(a); }))
687 .def_static(
"ofInts", &ListType::ofInts)
688 .def_static(
"ofTensors", &ListType::ofTensors)
689 .def(
"getElementType", &ListType::getElementType);
690 py::class_<DictType, Type, std::shared_ptr<DictType>>(m,
"DictType")
691 .def(py::init([](TypePtr key, TypePtr value) {
692 return DictType::create(key, value);
694 py::class_<OptionalType, Type, std::shared_ptr<OptionalType>>(
696 .def(py::init([](TypePtr a) {
return OptionalType::create(a); }))
697 .def_static(
"ofTensor", &OptionalType::ofTensor)
698 .def(
"getElementType", &OptionalType::getElementType);
// ---- Use bindings: (user node, input offset) pair ------------------------
700 py::class_<Use>(m,
"Use")
701 .def_readonly(
"user", &Use::user)
702 .def_readonly(
"offset", &Use::offset);
Variable: A Variable augments a Tensor with the ability to interact with our autograd machinery...
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory)...