#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/batched/BatchTensor.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_arg_flatten.h>
#include <torch/csrc/jit/python_ir.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/init.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/jit/script/python_tree_views.h>
#include <torch/csrc/jit/tracer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/function_schema.h>

#include <pybind11/functional.h>

namespace torch {
namespace jit {

using ::c10::Argument;
using ::c10::FunctionSchema;

using autograd::variable_list;
bool loadPythonClasses() {
  // Currently a no-op; kept as a hook in case Python classes need to be
  // eagerly imported when the JIT bindings are initialized.
  return true;
}
#if defined(_WIN32)
void runJITCPPTests() {
  AT_ERROR("JIT tests not yet supported on Windows");
}
#else
void runJITCPPTests();
#endif
void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::register_exception<JITException>(m, "JITException");

  py::class_<python::IODescriptor> iodescriptor(m, "IODescriptor");
  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_onnx", ToONNX)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            // Dead code elimination on the graph's top-level block.
            return EliminateDeadCode(g->block());
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g);
          })
117 "_jit_pass_remove_inplace_ops",
118 [](std::shared_ptr<Graph> g) {
return RemoveInplaceOps(g); })
119 .def(
"_jit_pass_constant_pooling", ConstantPooling)
121 "_jit_pass_peephole",
122 [](
const std::shared_ptr<Graph>& g,
bool addmm_fusion_enabled) {
123 return PeepholeOptimize(g, addmm_fusion_enabled);
126 py::arg(
"addmm_fusion_enabled") =
false)
128 "_jit_pass_canonicalize",
129 [](
const std::shared_ptr<Graph>& g) {
return Canonicalize(g); })
130 .def(
"_jit_pass_lint", LintGraph)
132 "_jit_pass_shape_analysis",
133 [](std::shared_ptr<Graph> graph,
134 std::vector<at::Tensor> inputs,
138 ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
139 PropagateInputShapes(graph);
142 "_jit_pass_complete_shape_analysis",
143 [](std::shared_ptr<Graph> graph, py::tuple inputs,
bool with_grad) {
144 CompleteArgumentSpec spec(
146 evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
147 auto graph_inputs = graph->inputs();
148 AT_ASSERT(spec.size() == graph_inputs.size());
149 for (
size_t i = 0; i < graph_inputs.size(); ++i) {
150 graph_inputs[i]->setType(spec.at(i));
152 PropagateInputShapes(graph);
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
      .def("_jit_pass_loop_unrolling", UnrollLoops)
160 "_jit_pass_constant_propagation",
161 [](std::shared_ptr<Graph>& g) {
return ConstantPropagation(g); })
162 .def(
"_jit_pass_erase_shape_information", EraseShapeInformation)
164 "_jit_pass_create_autodiff_subgraphs",
165 [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
167 "_jit_run_cpp_tests",
174 return runJITCPPTests();
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](autograd::variable_list vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
      .def("_jit_pass_onnx_block", BlockToONNX)
      .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
      .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
194 "_jit_differentiate",
199 auto g_clone = g.copy();
200 return differentiate(g_clone);
203 "_jit_check_alias_annotation",
204 [](std::shared_ptr<Graph> g,
206 const std::string& unqualified_op_name) {
207 auto stack = toStack(args);
208 checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
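  // Illustrative sketch (not part of the bindings themselves) of how the
  // `_jit_pass_*` entry points above are typically driven from Python.
  // `fn` and `x` are placeholder names, and the exact way to obtain a
  // torch._C.Graph can differ between versions:
  //
  //   traced = torch.jit.trace(fn, (x,))
  //   graph = traced.graph
  //   torch._C._jit_pass_constant_propagation(graph)
  //   torch._C._jit_pass_lint(graph)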
  py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
      .def("__repr__", [](CompleteArgumentSpec& self) {
        std::ostringstream s;
        s << self;
        return s.str();
      });
  py::class_<ArgumentSpec>(m, "ArgumentSpec");
  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
    return py::make_iterator(
        c.grad_executors().begin(), c.grad_executors().end());
  });
  py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
      .def_property_readonly(
          "graph", [](ExecutionPlanState& s) { return s.graph; })
      .def_property_readonly(
          "code", [](ExecutionPlanState& s) { return s.code; });
  py::class_<Gradient>(m, "Gradient")
      .def_property_readonly("f", [](Gradient& m) { return m.f; })
      .def_property_readonly("df", [](Gradient& m) { return m.df; })
      .def_property_readonly(
          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
      .def_property_readonly(
          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
      .def_property_readonly(
          "df_input_captured_inputs",
          [](Gradient& m) { return m.df_input_captured_inputs; })
      .def_property_readonly(
          "df_input_captured_outputs",
          [](Gradient& m) { return m.df_input_captured_outputs; })
      .def_property_readonly(
          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
  py::class_<GraphExecutorState>(m, "GraphExecutorState")
      .def_property_readonly(
          "graph", [](GraphExecutorState& s) { return s.graph; })
      .def_property_readonly(
          "execution_plans",
          [](GraphExecutorState& s) { return s.execution_plans; })
      .def_property_readonly(
          "fallback", [](GraphExecutorState& s) { return s.fallback; });
  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
      .def(
          py::init([](py::function func,
                      py::tuple inputs,
                      py::function var_name_lookup_fn,
                      bool optimize,
                      bool _force_outplace) {
            // Trace `func` on the example inputs and wrap the resulting graph
            // in an executor.
            auto graph = tracer::createGraphByTracing(
                func, toStack(inputs), var_name_lookup_fn, _force_outplace);
            return GraphExecutor(graph, optimize);
          }),
          py::arg("func"),
          py::arg("inputs"),
          py::arg("var_name_lookup_fn"),
          py::arg("optimize") = true,
          py::arg("_force_outplace") = false)
      .def(
          py::init([](std::shared_ptr<Graph> graph, bool optimize) {
            return GraphExecutor(std::move(graph), optimize);
          }),
          py::arg("graph"),
          py::arg("optimize") = true)
      .def(
          "graph_for",
          [](GraphExecutor& ge, py::args args) {
            return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
                args, ge.graph()->inputs()));
          })
      .def_property_readonly(
          "graph", [](GraphExecutor& ge) { return ge.graph(); })
      .def(
          "get_debug_state",
          [](GraphExecutor& ge) { return ge.getDebugState(); })
      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
        const auto& graph = ge.graph();
        auto stack =
            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
        {
          // Release the GIL while the executor runs the graph.
          AutoNoGIL no_gil_guard;
          ge.run(stack);
        }
        return createPyObjectForStack(std::move(stack));
      });
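  // Hedged sketch of the Python-side use of the GraphExecutor binding above
  // (`graph` and `args` are placeholders; the constructor arguments mirror
  // the py::init overloads registered here):
  //
  //   ge = torch._C.GraphExecutor(graph, optimize=True)
  //   out = ge(*args)                  # runs via __call__
  //   specialized = ge.graph_for(*args)
  //   state = ge.get_debug_state()     # a GraphExecutorState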
  py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
      .def(py::init<std::string>())
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             const char* data,
             size_t size) { return self.writeRecord(name, data, size); })
      .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile);
  py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
      .def(py::init<std::string>())
      .def(
          "get_record",
          [](PyTorchStreamReader& self, const std::string& key) {
            at::DataPtr data;
            size_t size;
            std::tie(data, size) = self.getRecord(key);
            return py::bytes(reinterpret_cast<const char*>(data.get()), size);
          });
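  // Illustrative use of the archive bindings above from Python, as a sketch
  // only (the file name, record key, and `data` payload are placeholders):
  //
  //   writer = torch._C.PyTorchFileWriter("archive.pt")
  //   writer.write_record("key", data, len(data))
  //   writer.write_end_of_file()
  //
  //   reader = torch._C.PyTorchFileReader("archive.pt")
  //   payload = reader.get_record("key")   # returns bytes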
320 "_jit_get_operation",
321 [](
const std::string& qualified_name) {
323 auto symbol = Symbol::fromQualString(qualified_name);
324 auto operations = getAllOperatorsFor(symbol);
325 AT_CHECK(!operations.empty(),
"No such operator ", qualified_name);
327 operations.size() == 1,
330 " overloads for operator ",
332 "! Overloads are not supported from Python.");
333 std::shared_ptr<Operator> op = operations[0];
334 AT_ASSERT(op !=
nullptr);
335 std::ostringstream docstring;
336 docstring <<
"Automatically bound operator '" << qualified_name
337 <<
"' with schema: " << op->schema();
338 return py::cpp_function(
339 [op](py::args args, py::kwargs kwargs) {
340 return invokeOperatorFromPython(
341 *op, std::move(args), std::move(kwargs));
343 py::name(qualified_name.c_str()),
344 py::doc(docstring.str().c_str()));
349 py::arg(
"qualified_name"));
  py::class_<FunctionSchema>(m, "FunctionSchema")
      .def_property_readonly(
          "name", [](FunctionSchema& self) { return self.name(); })
      .def_property_readonly(
          "overload_name",
          [](FunctionSchema& self) { return self.overload_name(); })
      .def_property_readonly(
          "arguments", [](FunctionSchema& self) { return self.arguments(); })
      .def_property_readonly(
          "returns", [](FunctionSchema& self) { return self.returns(); });
  py::class_<Argument>(m, "Argument")
      .def_property_readonly(
          "name", [](Argument& self) { return self.name(); })
      .def_property_readonly(
          "type", [](Argument& self) { return self.type(); })
      .def_property_readonly(
          "N",
          [](Argument& self) -> py::object {
            return (self.N()) ? py::cast(*self.N()) : py::none();
          })
      .def_property_readonly(
          "default_value", [](Argument& self) -> py::object {
            if (!self.default_value())
              return py::none();
            IValue v = *self.default_value();
            return toPyObject(std::move(v));
          });
  m.def(
      "_jit_get_schemas_for_operator",
      [](const std::string& qualified_name) {
        auto symbol = Symbol::fromQualString(qualified_name);
        auto operations = getAllOperatorsFor(symbol);
        return fmap(operations, [](const std::shared_ptr<Operator>& op) {
          return op->schema();
        });
      });
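  // The schema introspection bindings above can be combined from Python,
  // e.g. (a sketch; "aten::add" is only an illustrative operator name):
  //
  //   schemas = torch._C._jit_get_schemas_for_operator("aten::add")
  //   for schema in schemas:
  //       print(schema.name, [arg.name for arg in schema.arguments])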
  struct PythonFutureWrapper {
    explicit PythonFutureWrapper(c10::intrusive_ptr<c10::ivalue::Future> fut)
        : fut(std::move(fut)) {}

    c10::intrusive_ptr<c10::ivalue::Future> fut;
  };

  py::class_<PythonFutureWrapper>(m, "Future");
  m.def("fork", [](py::args args) {
    AT_ASSERT(args.size() >= 1);

    py::function f = py::cast<py::function>(args[0]);
    py::tuple args_tup(args.size() - 1);

    for (size_t i = 1; i < args.size(); ++i) {
      args_tup[i - 1] = args[i];
    }

    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;
      auto fork_node = graph->insertNode(graph->create(prim::fork, 1));
      auto body_block = fork_node->addBlock();

      Value* node_output = nullptr;
      py::object py_func_output;
      auto retval = c10::make_intrusive<c10::ivalue::Future>();
      IValue output_ivalue;
      {
        WithInsertPoint guard(body_block);
        {
          tracer::WithNestedTracingFrame env_guard;

          // Run the user-supplied function inside a nested tracing frame so
          // its operations are recorded into the fork body block.
          py_func_output = f(*args_tup);

          // Convert the Python output to an IValue and register its traced
          // value as the output of the body block.
          output_ivalue = toIValue(py_func_output);
          Value* out_val = jit::tracer::getNestedValueTrace(output_ivalue);
          body_block->registerOutput(out_val);
          node_output = fork_node->output()->setType(
              FutureType::create(out_val->type()));
        }

        // Lambda-lift the body block into a subgraph attribute on the node.
        torch::jit::script::lambdaLiftFork(fork_node);
      }

      // Associate the Future returned to Python with the fork output and mark
      // it completed with the eagerly computed result.
      jit::tracer::setValueTrace(retval, node_output);
      retval->markCompleted(output_ivalue);
      return PythonFutureWrapper(retval);
    } else {
      auto retval = c10::make_intrusive<c10::ivalue::Future>();
      retval->markCompleted(toIValue(f(*args_tup)));
      return PythonFutureWrapper(retval);
    }
  });
  m.def("wait", [](PythonFutureWrapper& fut) {
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;

      // Record an aten::wait node on the traced future and associate the
      // completed value with the wait output.
      Value* fut_val = jit::tracer::getValueTrace(fut.fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut.fut->value(), output);
    }
    return fut.fut->value();
  });
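  // Hedged sketch of how fork/wait pair up from Python (`fn` and `x` are
  // placeholders; outside of tracing, fork simply runs the function eagerly
  // and returns an already-completed Future):
  //
  //   fut = torch._C.fork(fn, x)
  //   result = torch._C.wait(fut)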
  m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
    // The conversion throws if `obj` does not match `type`.
    toIValue(std::move(obj), type);
  });
  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  script::initTreeViewBindings(module);
  script::initJitScriptBindings(module);
  initBatchTensorBindings(module);
  initRegisterBatchOpsBindings(module);
}

} // namespace jit
} // namespace torch