1 #include <torch/csrc/jit/passes/onnx.h> 2 #include <ATen/core/functional.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/autograd/function.h> 5 #include <torch/csrc/autograd/symbolic.h> 6 #include <torch/csrc/jit/passes/dead_code_elimination.h> 7 #include <torch/csrc/utils/pybind.h> 9 #include <unordered_map> 14 void removePrintOps(Block* block) {
15 for (
auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
17 for (
auto b : it->blocks()) {
20 if (it->kind() == prim::Print || it->kind() == aten::warn) {
21 for (
size_t i = 0; i < it->inputs().size();) {
22 auto input = it->inputs().at(i);
24 if (input->uses().size() == 1 &&
25 input->node()->kind() == prim::Constant) {
27 input->node()->destroy();
// Graph-level convenience overload: strips print/warn ops from the
// whole graph by delegating to the Block* overload on the root block.
void removePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
}
42 std::shared_ptr<Graph> ToONNX(
43 std::shared_ptr<Graph>& graph,
44 ::torch::onnx::OperatorExportTypes operator_export_type) {
45 auto new_graph = std::make_shared<Graph>(graph->current_scope());
46 std::unordered_map<Value*, Value*> env;
47 removePrintOps(graph);
48 BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env);
55 ::torch::onnx::OperatorExportTypes operator_export_type,
56 std::unordered_map<Value*, Value*> env) {
58 ctx.block = new_block;
60 py::object
onnx = py::module::import(
"torch.onnx");
61 py::object onnx_symbolic = py::module::import(
"torch.onnx.symbolic");
64 auto envFn = [&env](Value* n) -> Value* {
65 auto it = env.find(n);
66 AT_CHECK(it != env.end(),
"Dangling node reference");
67 AT_CHECK(it->second,
"Unused node was subsequently used");
72 for (
auto input : old_block->inputs()) {
73 auto n = ctx.block->addInput()->copyMetadata(input);
79 auto setOutputs = [&](
const std::string& op_name,
81 const value_list& outputs) {
82 auto old_outputs = node->outputs();
84 auto num_old_outputs = old_outputs.size();
85 if (outputs.size() != num_old_outputs) {
86 std::ostringstream ss;
87 ss <<
"symbolic for " << op_name
88 <<
" produced an incorrect number of outputs (expected ";
89 ss << num_old_outputs <<
", but got " << outputs.size() <<
")";
90 throw std::runtime_error(ss.str());
92 for (
size_t i = 0; i < num_old_outputs; ++i) {
93 auto old = old_outputs[i];
98 outputs[i]->setType(old->type());
101 outputs[i]->node()->setSourceLocation(node->getSourceLocation());
102 outputs[i]->node()->setScope(node->scope());
103 env[old] = outputs[i];
108 if (!old->uses().empty()) {
109 std::ostringstream ss;
110 ss <<
"symbolic for " << op_name <<
" returned None for the output " 112 ss <<
" (indicating conversion for that particular output is not supported), ";
113 ss <<
"but the network uses this output later";
115 throw std::runtime_error(ss.str());
122 auto cloneNode = [&](Node* node) {
123 auto n_ = ctx.block->appendNode(
124 ctx.block->owningGraph()->createClone(node, envFn));
125 for (
size_t i = 0; i < node->outputs().size(); i++) {
127 env[node->outputs()[i]] = n_->outputs()[i];
132 auto processSymbolicOutput = [&](
const std::string& op_name,
134 const py::object& raw_output) {
135 if (raw_output.ptr() == Py_None) {
140 std::vector<Value*> outputs;
142 if (py::isinstance<Value>(raw_output)) {
143 outputs = value_list{py::cast<Value*>(raw_output)};
145 outputs = py::cast<std::vector<Value*>>(raw_output);
147 }
catch (
const std::exception& ex) {
148 std::ostringstream ss;
149 ss <<
"Error casting results of symbolic for " << op_name
150 <<
": expected to return list of op nodes, instead received type ''" 151 << py::str(raw_output.get_type()) <<
"': " << py::str(raw_output);
152 throw std::runtime_error(ss.str());
155 setOutputs(op_name, n, outputs);
158 auto callPySymbolicFunction = [&](Node* n) {
162 py::tuple py_inputs(n->inputs().size());
163 Py_ssize_t input_nr = 0;
164 for (
auto* input : n->inputs()) {
165 py_inputs[input_nr++] = py::cast(envFn(input));
168 WithInsertPoint insert_point_guard(ctx.block);
169 WithCurrentScope scope_guard(*ctx.block->owningGraph(), n->scope());
170 py::object raw_output = onnx.attr(
"_run_symbolic_function")(
171 ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
175 processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
178 auto callPySymbolicMethod = [&](PythonOp* op) {
180 auto pyobj = py::handle(op->pyobj.get());
181 auto func = op->autogradFunction();
186 if (!py::hasattr(pyobj,
"symbolic")) {
193 Py_ssize_t input_nr = 0;
194 py::tuple py_symbolic_args(1 + op->cconv.size());
195 py_symbolic_args[input_nr++] = py::cast(ctx.block->owningGraph());
196 auto inputs = op->inputs();
197 auto node_it = inputs.begin();
198 auto scalar_it = op->scalar_args.begin();
199 for (
auto arg_type : op->cconv) {
201 if (arg_type ==
'c') {
203 scalar_it != op->scalar_args.end(),
204 "expected too many scalar args");
205 obj = py::reinterpret_borrow<py::object>(
206 py::handle((scalar_it++)->
get()));
207 }
else if (arg_type ==
'd') {
208 AT_CHECK(node_it != inputs.end(),
"expected too many inputs");
209 obj = py::cast(envFn(*node_it++));
211 throw std::runtime_error(
"unexpected calling convention");
213 py_symbolic_args[input_nr++] = obj;
216 WithInsertPoint insert_point_guard(ctx.block);
217 WithCurrentScope scope_guard(*ctx.block->owningGraph(), op->scope());
221 py::object raw_output = onnx.attr(
"_run_symbolic_method")(
222 op->name(), pyobj.attr(
"symbolic"), py_symbolic_args);
224 processSymbolicOutput(op->name(), op, raw_output);
228 for (
auto node : old_block->nodes()) {
229 if (node->kind() == prim::PythonOp) {
230 callPySymbolicMethod(static_cast<PythonOp*>(node));
232 callPySymbolicFunction(node);
235 for (
auto output : old_block->outputs()) {
236 ctx.block->registerOutput(env.at(output));
237 env.at(output)->setType(output->type());
240 EliminateDeadCode(ctx.block);