Caffe2 - C++ API
A deep learning, cross-platform ML framework
init.cpp
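// Python bindings for the TorchScript JIT. initJITBindings() below is called
// once at extension-module initialization (the module is torch._C in
// practice) and registers graph passes, executor debug state, serialization,
// and operator lookup.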
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/batched/BatchTensor.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_arg_flatten.h>
#include <torch/csrc/jit/python_ir.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/init.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/jit/script/python_tree_views.h>
#include <torch/csrc/jit/tracer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/function_schema.h>

#include <pybind11/functional.h>

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

namespace torch {
namespace jit {

using ::c10::Argument;
using ::c10::FunctionSchema;

namespace {

using autograd::variable_list;

bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
  // THPUtils_assert(jit_module, "class loader couldn't access "
  //"torch.jit module");
  // PyObject *jit_dict = PyModule_GetDict(jit_module);

  return true;
}

} // anonymous namespace

#if defined(_WIN32)
void runJITCPPTests() {
  AT_ERROR("JIT tests not yet supported on Windows");
}
#else
void runJITCPPTests();
#endif

void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::register_exception<JITException>(m, "JITException");

  py::class_<python::IODescriptor> iodescriptor(
      m, "IODescriptor"); // NOLINT(bugprone-unused-raii)

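  // Graph-level passes, exposed with a leading underscore because they are
  // internal. From Python each pass is called on the bound module, e.g.
  // (assuming the module is torch._C):
  //   torch._C._jit_pass_dce(graph)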
  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_onnx", ToONNX)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(g->block()); // overload resolution
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
          "_jit_pass_remove_inplace_ops",
          [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
      .def("_jit_pass_constant_pooling", ConstantPooling)
      .def(
          "_jit_pass_peephole",
          [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
            return PeepholeOptimize(g, addmm_fusion_enabled);
          },
          py::arg("graph"),
          py::arg("addmm_fusion_enabled") = false)
      .def(
          "_jit_pass_canonicalize",
          [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
      .def("_jit_pass_lint", LintGraph)
      .def(
          "_jit_pass_shape_analysis",
          [](std::shared_ptr<Graph> graph,
             std::vector<at::Tensor> inputs,
             bool with_grad) {
            setInputTypes(
                *graph,
                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
            PropagateInputShapes(graph);
          })
      .def(
          "_jit_pass_complete_shape_analysis",
          [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
            CompleteArgumentSpec spec(
                with_grad,
                evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
            auto graph_inputs = graph->inputs();
            AT_ASSERT(spec.size() == graph_inputs.size());
            for (size_t i = 0; i < graph_inputs.size(); ++i) {
              graph_inputs[i]->setType(spec.at(i));
            }
            PropagateInputShapes(graph);
          })
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
      .def("_jit_pass_loop_unrolling", UnrollLoops)
      .def(
          "_jit_pass_constant_propagation",
          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
      .def(
          "_jit_pass_create_autodiff_subgraphs",
          [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
      .def(
          "_jit_run_cpp_tests",
          [] {
            // We have to release the GIL inside this method, because if we
            // happen to initialize the autograd engine in these tests, the
            // newly spawned worker threads will try to initialize their
            // PyThreadState*, and they need the GIL for this.
            AutoNoGIL _no_gil;
            return runJITCPPTests();
          })
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](autograd::variable_list vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
      .def("_jit_pass_onnx_block", BlockToONNX)
      .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
      .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def(
          "_jit_differentiate",
          [](Graph& g) {
            // The Python binding's semantics differ slightly from
            // jit::differentiate: it makes a copy of the input Graph and
            // works on that, whereas jit::differentiate mutates its input.
            auto g_clone = g.copy();
            return differentiate(g_clone);
          })
      .def(
          "_jit_check_alias_annotation",
          [](std::shared_ptr<Graph> g,
             py::tuple args,
             const std::string& unqualified_op_name) {
            auto stack = toStack(args);
            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
          });

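  // Read-only debug views: the classes below expose the executor's internal
  // state (argument specs, compiled code, gradient plans) to Python for
  // inspection and testing.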
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
      .def("__repr__", [](CompleteArgumentSpec& self) {
        std::ostringstream s;
        s << self;
        return s.str();
      });
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ArgumentSpec>(m, "ArgumentSpec");
  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
    return py::make_iterator(
        c.grad_executors().begin(), c.grad_executors().end());
  });

  py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
      .def_property_readonly(
          "graph", [](ExecutionPlanState& s) { return s.graph; })
      .def_property_readonly(
          "code", [](ExecutionPlanState& s) { return s.code; });

  py::class_<Gradient>(m, "Gradient")
      .def_property_readonly("f", [](Gradient& m) { return m.f; })
      .def_property_readonly("df", [](Gradient& m) { return m.df; })
      .def_property_readonly(
          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
      .def_property_readonly(
          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
      .def_property_readonly(
          "df_input_captured_inputs",
          [](Gradient& m) { return m.df_input_captured_inputs; })
      .def_property_readonly(
          "df_input_captured_outputs",
          [](Gradient& m) { return m.df_input_captured_outputs; })
      .def_property_readonly(
          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });

  py::class_<GraphExecutorState>(m, "GraphExecutorState")
      .def_property_readonly(
          "graph", [](GraphExecutorState& s) { return s.graph; })
      .def_property_readonly(
          "execution_plans",
          [](GraphExecutorState& s) { return s.execution_plans; })
      .def_property_readonly(
          "fallback", [](GraphExecutorState& s) { return s.fallback; });

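  // A GraphExecutor can be built either by tracing a Python callable or from
  // an existing Graph; __call__ releases the GIL while the interpreter runs
  // so other Python threads can make progress.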
  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
      .def(
          py::init([](py::function func,
                      py::tuple inputs,
                      py::function var_name_lookup_fn,
                      bool optimize,
                      bool _force_outplace) {
            auto graph = tracer::createGraphByTracing(
                func, toStack(inputs), var_name_lookup_fn, _force_outplace);
            return GraphExecutor(graph, optimize);
          }),
          py::arg("func"),
          py::arg("inputs"),
          py::arg("var_name_lookup_fn"),
          py::arg("optimize") = true,
          py::arg("_force_outplace") = false)
      .def(
          py::init([](std::shared_ptr<Graph> graph, bool optimize) {
            return GraphExecutor(std::move(graph), optimize);
          }),
          py::arg("graph"),
          py::arg("optimize") = true)
      .def(
          "graph_for",
          [](GraphExecutor& ge, py::args args) {
            return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
                args, ge.graph()->inputs()));
          })
      .def_property_readonly(
          "graph", [](GraphExecutor& ge) { return ge.graph(); })
      .def(
          "get_debug_state",
          [](GraphExecutor& ge) { return ge.getDebugState(); })
      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
        const auto& graph = ge.graph();
        auto stack =
            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
        {
          AutoNoGIL no_gil_guard;
          ge.run(stack);
        }
        return createPyObjectForStack(std::move(stack));
      });

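  // Serialization: PyTorchFileWriter/PyTorchFileReader wrap the record-based
  // container from caffe2/serialize/inline_container.h. A hypothetical
  // Python round-trip (module name assumed to be torch._C):
  //   w = torch._C.PyTorchFileWriter("archive.pt")
  //   w.write_record("key", data, len(data))
  //   w.write_end_of_file()
  //   blob = torch._C.PyTorchFileReader("archive.pt").get_record("key")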
  py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
      .def(py::init<std::string>())
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             const char* data,
             size_t size) { return self.writeRecord(name, data, size); })
      .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile);

  py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
      .def(py::init<std::string>())
      .def("get_record", [](PyTorchStreamReader& self, const std::string& key) {
        at::DataPtr data;
        size_t size;
        std::tie(data, size) = self.getRecord(key);
        return py::bytes(reinterpret_cast<const char*>(data.get()), size);
      });

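  // Look up a single operator by qualified name (e.g. "aten::relu") and bind
  // it as a Python callable. Ops with multiple overloads are rejected, since
  // overload resolution is not implemented for this entry point.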
  m.def(
      "_jit_get_operation",
      [](const std::string& qualified_name) {
        try {
          auto symbol = Symbol::fromQualString(qualified_name);
          auto operations = getAllOperatorsFor(symbol);
          AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
          AT_CHECK(
              operations.size() == 1,
              "Found ",
              operations.size(),
              " overloads for operator ",
              qualified_name,
              "! Overloads are not supported from Python.");
          std::shared_ptr<Operator> op = operations[0];
          AT_ASSERT(op != nullptr);
          std::ostringstream docstring;
          docstring << "Automatically bound operator '" << qualified_name
                    << "' with schema: " << op->schema();
          return py::cpp_function(
              [op](py::args args, py::kwargs kwargs) {
                return invokeOperatorFromPython(
                    *op, std::move(args), std::move(kwargs));
              },
              py::name(qualified_name.c_str()),
              py::doc(docstring.str().c_str()));
        } catch (const c10::Error& error) {
          throw std::runtime_error(error.what_without_backtrace());
        }
      },
      py::arg("qualified_name"));

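  // Schema introspection: FunctionSchema and Argument mirror the C++ structs
  // of the same name; _jit_get_schemas_for_operator returns every overload
  // registered under a qualified name, e.g. (module name assumed):
  //   torch._C._jit_get_schemas_for_operator("aten::add")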
  py::class_<FunctionSchema>(m, "FunctionSchema")
      .def_property_readonly(
          "name", [](FunctionSchema& self) { return self.name(); })
      .def_property_readonly(
          "overload_name",
          [](FunctionSchema& self) { return self.overload_name(); })
      .def_property_readonly(
          "arguments", [](FunctionSchema& self) { return self.arguments(); })
      .def_property_readonly(
          "returns", [](FunctionSchema& self) { return self.returns(); });
  py::class_<Argument>(m, "Argument")
      .def_property_readonly("name", [](Argument& self) { return self.name(); })
      .def_property_readonly("type", [](Argument& self) { return self.type(); })
      .def_property_readonly(
          "N",
          [](Argument& self) -> py::object {
            return (self.N()) ? py::cast(*self.N()) : py::none();
          })
      .def_property_readonly("default_value", [](Argument& self) -> py::object {
        if (!self.default_value())
          return py::none();
        IValue v = *self.default_value();
        return toPyObject(std::move(v));
      });
  m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
    auto symbol = Symbol::fromQualString(qualified_name);
    auto operations = getAllOperatorsFor(symbol);
    return fmap(operations, [](const std::shared_ptr<Operator>& op) {
      return op->schema();
    });
  });

  struct PythonFutureWrapper {
    explicit PythonFutureWrapper(c10::intrusive_ptr<c10::ivalue::Future> fut)
        : fut(std::move(fut)) {}

    c10::intrusive_ptr<c10::ivalue::Future> fut;
  };

  py::class_<PythonFutureWrapper>(m, "Future");

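  // fork(f, *args): under tracing, insert a prim::fork node, run f inline
  // inside the node's sub-block to record its trace, and return a Future that
  // already holds the result. Outside of tracing, f simply runs eagerly.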
391  m.def("fork", [](py::args args) {
392  AT_ASSERT(args.size() >= 1);
393 
394  py::function f = py::cast<py::function>(args[0]);
395  py::tuple args_tup(args.size() - 1);
396 
397  for (size_t i = 1; i < args.size(); ++i) {
398  args_tup[i - 1] = args[i];
399  }
400 
401  if (jit::tracer::isTracing()) {
402  auto graph = jit::tracer::getTracingState()->graph;
403  auto fork_node = graph->insertNode(graph->create(prim::fork, 1));
404  auto body_block = fork_node->addBlock();
405 
406  Value* node_output;
407  py::object py_func_output;
408  auto retval = c10::make_intrusive<c10::ivalue::Future>();
409  // Insert new trace ops into the fork op's sub-block
410  WithInsertPoint guard(body_block);
411  IValue output_ivalue;
412  {
413  tracer::WithNestedTracingFrame env_guard;
414 
415  // Run the user-supplied function
416  py_func_output = f(*args_tup);
417 
418  // Convert the output of the user-supplied funciton to IValue. The type
419  // information of this IValue is used both to record the correct type in
420  // the trace.
421  output_ivalue = toIValue(py_func_output);
422  Value* out_val = jit::tracer::getNestedValueTrace(output_ivalue);
423  body_block->registerOutput(out_val);
424  node_output =
425  fork_node->output()->setType(FutureType::create(out_val->type()));
426 
427  // Lambda lift into a Subgraph attribute
428  torch::jit::script::lambdaLiftFork(fork_node);
429  }
430 
431  // Record the ivalue in the tracer
432  jit::tracer::setValueTrace(retval, node_output);
433 
434  // stuff the ivalue output in the Future
435  retval->markCompleted(output_ivalue);
436 
437  return PythonFutureWrapper(retval);
438  } else {
439  auto retval = c10::make_intrusive<c10::ivalue::Future>();
440  retval->markCompleted(toIValue(f(*args_tup)));
441  return PythonFutureWrapper(retval);
442  }
443  });
444 
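  // wait(fut): under tracing, record an aten::wait node so the graph captures
  // the synchronization point; the value itself is already available because
  // fork above completes the Future eagerly.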
445  m.def("wait", [](PythonFutureWrapper& fut) {
446  if (jit::tracer::isTracing()) {
447  auto graph = jit::tracer::getTracingState()->graph;
448 
449  Value* fut_val = jit::tracer::getValueTrace(fut.fut);
450  auto output = graph->insert(aten::wait, {fut_val});
451  jit::tracer::setValueTrace(fut.fut->value(), output);
452  }
453  return fut.fut->value();
454  });
455 
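  // toIValue(obj, type) throws when obj cannot be converted to type, which
  // is what makes this an assertion.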
456  m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
457  toIValue(obj, type);
458  });
459 
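  // Register the remaining bindings, implemented in other translation units.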
  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  script::initTreeViewBindings(module);
  script::initJitScriptBindings(module);
  initBatchTensorBindings(module);
  initRegisterBatchOpsBindings(module);
}

} // namespace jit
} // namespace torch