Caffe2 - C++ API
A deep learning, cross-platform ML framework
onnx.cpp
#include <torch/csrc/jit/passes/onnx.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/utils/pybind.h>
#include <sstream>
#include <unordered_map>

namespace torch {
namespace jit {

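// Remove prim::Print and aten::warn nodes, which have no ONNX equivalent.
// A constant input whose only use is the removed node is destroyed as well;
// other inputs are left alone, since their producers may have side effects.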
void removePrintOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      removePrintOps(b);
    }
    if (it->kind() == prim::Print || it->kind() == aten::warn) {
      for (size_t i = 0; i < it->inputs().size();) {
        auto input = it->inputs().at(i);
        // Only remove constant inputs; other producers may have side effects.
        if (input->uses().size() == 1 &&
            input->node()->kind() == prim::Constant) {
          it->removeInput(i);
          input->node()->destroy();
        } else {
          ++i;
        }
      }
      it.destroyCurrent();
    }
  }
}

void removePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
}

// Transform PythonOps into Nodes that match ONNX semantics.
std::shared_ptr<Graph> ToONNX(
    std::shared_ptr<Graph>& graph,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  auto new_graph = std::make_shared<Graph>(graph->current_scope());
  std::unordered_map<Value*, Value*> env;
  removePrintOps(graph);
  BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env);
  return new_graph;
}

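// Example call (an illustrative sketch, not from this file; `g` stands for a
// std::shared_ptr<Graph> produced by tracing):
//   auto onnx_graph = ToONNX(g, ::torch::onnx::OperatorExportTypes::ONNX);
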
void BlockToONNX(
    Block* old_block,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    std::unordered_map<Value*, Value*> env) {
  torch::autograd::SymbolicContext ctx{};
  ctx.block = new_block;

  py::object onnx = py::module::import("torch.onnx");
  py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");

  // Returns the Value in the new graph that n maps to.
  auto envFn = [&env](Value* n) -> Value* {
    auto it = env.find(n);
    AT_CHECK(it != env.end(), "Dangling node reference");
    AT_CHECK(it->second, "Unused node was subsequently used");
    return it->second;
  };

  // Initialize context and environment
  for (auto input : old_block->inputs()) {
    auto n = ctx.block->addInput()->copyMetadata(input);
    env[input] = n;
  }
  // Put the new outputs in our environment map, and copy the type from the
  // input graph if they were not set by the symbolic. This is called only
  // with the results of a symbolic call (not for nodes that are just cloned).
  auto setOutputs = [&](const std::string& op_name,
                        Node* node,
                        const value_list& outputs) {
    auto old_outputs = node->outputs();
    // Count all outputs
    auto num_old_outputs = old_outputs.size();
    if (outputs.size() != num_old_outputs) {
      std::ostringstream ss;
      ss << "symbolic for " << op_name
         << " produced an incorrect number of outputs (expected ";
      ss << num_old_outputs << ", but got " << outputs.size() << ")";
      throw std::runtime_error(ss.str());
    }
    for (size_t i = 0; i < num_old_outputs; ++i) {
      auto old = old_outputs[i];
      if (outputs[i]) {
        // Allow symbolic() to skip specifying the type of the return node;
        // we copy it from the old output here. (Symbolics are still on the
        // hook for the types of internal nodes, though in practice those
        // types are not computed.)
        outputs[i]->setType(old->type());
        // Copy over source location and scope information to all nodes
        // created by the symbolic
        outputs[i]->node()->setSourceLocation(node->getSourceLocation());
        outputs[i]->node()->setScope(node->scope());
        env[old] = outputs[i];
      } else {
        // A null output means that the ONNX op doesn't have an output
        // corresponding to this PyTorch output.
        env[old] = nullptr;
        if (!old->uses().empty()) {
          std::ostringstream ss;
          ss << "symbolic for " << op_name << " returned None for the output "
             << i;
          ss << " (indicating conversion for that particular output is not supported), ";
          ss << "but the network uses this output later";
          // TODO: Say what actually used it
          throw std::runtime_error(ss.str());
        }
      }
    }
  };

  // Clone the node and add it to the new graph
  auto cloneNode = [&](Node* node) {
    auto n_ = ctx.block->appendNode(
        ctx.block->owningGraph()->createClone(node, envFn));
    for (size_t i = 0; i < node->outputs().size(); i++) {
      // n_->outputs()[i]->setType(node->outputs()[i]->type());
      env[node->outputs()[i]] = n_->outputs()[i];
    }
  };

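  // By convention, a symbolic that declines to convert a node returns None
  // (Py_None here); the original node is then cloned through unchanged.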
  // Cast the output of the symbolic() Python implementation
  auto processSymbolicOutput = [&](const std::string& op_name,
                                   Node* n,
                                   const py::object& raw_output) {
    if (raw_output.ptr() == Py_None) {
      cloneNode(n);
      return;
    }
    // Cast the outputs back to C++ and put them in the new graph
    std::vector<Value*> outputs;
    try {
      if (py::isinstance<Value>(raw_output)) {
        outputs = value_list{py::cast<Value*>(raw_output)};
      } else {
        outputs = py::cast<std::vector<Value*>>(raw_output);
      }
    } catch (const std::exception& ex) {
      std::ostringstream ss;
      ss << "Error casting results of symbolic for " << op_name
         << ": expected to return list of op nodes, instead received type '"
         << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
      throw std::runtime_error(ss.str());
    }

    setOutputs(op_name, n, outputs);
  };

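  // torch.onnx._run_symbolic_function (invoked below) dispatches to a per-op
  // symbolic in torch.onnx.symbolic. A typical symbolic builds the ONNX node
  // directly on the graph, e.g. (an illustrative sketch, not from this file):
  //
  //   def relu(g, input):
  //       return g.op("Relu", input)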
  auto callPySymbolicFunction = [&](Node* n) {
    // The idea is to delegate as much of the actual argument massaging to
    // Python as possible

    py::tuple py_inputs(n->inputs().size());
    Py_ssize_t input_nr = 0;
    for (auto* input : n->inputs()) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    WithInsertPoint insert_point_guard(ctx.block);
    WithCurrentScope scope_guard(*ctx.block->owningGraph(), n->scope());
    py::object raw_output = onnx.attr("_run_symbolic_function")(
        ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);

    // TODO: Assert it's an ATen identifier???
    // (Sometimes it's not...)
    processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
  };

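  // For PythonOps backed by an autograd Function, conversion looks for a
  // `symbolic` attribute on the Function itself. Such a method might look
  // like this (an illustrative sketch, not part of this file):
  //
  //   class MyRelu(torch.autograd.Function):
  //       @staticmethod
  //       def symbolic(g, input):
  //           return g.op("Relu", input)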
  auto callPySymbolicMethod = [&](PythonOp* op) {
    // Test if there is a symbolic function; bail if there is not
    auto pyobj = py::handle(op->pyobj.get());
    auto func = op->autogradFunction();
    if (func) {
      pyobj = func->get();
    }

    if (!py::hasattr(pyobj, "symbolic")) {
      cloneNode(op);
      return;
    }

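    // op->cconv encodes the Python op's calling convention as one character
    // per argument: 'c' marks a scalar (constant) argument consumed from
    // scalar_args, and 'd' marks a dynamic tensor input (see the loop below).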
    // Prepare args for Python. The first one is the graph; it is followed by
    // the regular args, with Variables replaced by their corresponding nodes.
    Py_ssize_t input_nr = 0;
    py::tuple py_symbolic_args(1 + op->cconv.size());
    py_symbolic_args[input_nr++] = py::cast(ctx.block->owningGraph());
    auto inputs = op->inputs();
    auto node_it = inputs.begin();
    auto scalar_it = op->scalar_args.begin();
    for (auto arg_type : op->cconv) {
      py::object obj;
      if (arg_type == 'c') {
        AT_CHECK(
            scalar_it != op->scalar_args.end(),
            "ran out of scalar args");
        obj = py::reinterpret_borrow<py::object>(
            py::handle((scalar_it++)->get()));
      } else if (arg_type == 'd') {
        AT_CHECK(node_it != inputs.end(), "ran out of inputs");
        obj = py::cast(envFn(*node_it++));
      } else {
        throw std::runtime_error("unexpected calling convention");
      }
      py_symbolic_args[input_nr++] = obj;
    }

    WithInsertPoint insert_point_guard(ctx.block);
    WithCurrentScope scope_guard(*ctx.block->owningGraph(), op->scope());
    // Call the symbolic function.
    // Use a little trampoline function so we can give good error messages
    // upon argument mismatch.
    py::object raw_output = onnx.attr("_run_symbolic_method")(
        op->name(), pyobj.attr("symbolic"), py_symbolic_args);

    processSymbolicOutput(op->name(), op, raw_output);
  };

  // Finally, visit all nodes in the graph
  for (auto node : old_block->nodes()) {
    if (node->kind() == prim::PythonOp) {
      callPySymbolicMethod(static_cast<PythonOp*>(node));
    } else {
      callPySymbolicFunction(node);
    }
  }
  for (auto output : old_block->outputs()) {
    ctx.block->registerOutput(env.at(output));
    env.at(output)->setType(output->type());
  }

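  // Symbolics are free to emit helper nodes whose results go unused; run
  // dead code elimination so they don't linger in the exported graph.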
  EliminateDeadCode(ctx.block);
}

} // namespace jit
} // namespace torch