#include <torch/csrc/jit/script/init.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/script/sugared_value.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/parser.h>

#include <torch/csrc/api/include/torch/ordered_dict.h>

#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

PYBIND11_MAKE_OPAQUE(torch::jit::script::ExtraFilesMap);
namespace torch {
namespace jit {
namespace script {

using ::c10::Argument;
using ::c10::FunctionSchema;

using ResolutionCallback = std::function<py::function(std::string)>;
using FunctionDefaults = std::unordered_map<std::string, py::object>;
static std::string typeString(py::handle h) {
  return py::str(h.get_type().attr("__name__"));
}

inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
  return std::make_shared<SimpleValue>(v);
}
// Forward declaration; defined after the SugaredValue subclasses below.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Method& m,
    SourceRange loc,
    bool is_constant = false,
    bool is_submodule = false);
// Wraps an arbitrary Python object so it can be called and
// attribute-accessed from inside a script method.
struct PythonValue : public SugaredValue {
  PythonValue(py::object self) : self(std::move(self)) {}
  FunctionSchema getSchema(const size_t n_args, const size_t n_binders) {
    auto annotations = py::module::import("torch.jit.annotations");
    auto signature = annotations.attr("get_signature")(self);
    std::vector<Argument> args, rets;
    // We may mutate this if we can determine the number of args from Python
    // introspection.
    size_t actual_n_args = n_args;
    if (!signature.is_none()) {
      std::vector<TypePtr> arg_types;
      TypePtr ret_type;
      std::tie(arg_types, ret_type) =
          py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
      args.reserve(arg_types.size());
      size_t idx = 0; // Fake argument names by using the argument index
      for (auto& arg_type : arg_types) {
        args.push_back(Argument(
            std::to_string(idx++), std::move(arg_type), {}, {}, false));
      }
      rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
    } else {
      // Create a default signature from what we can introspect.
      auto num_params = annotations.attr("get_num_params")(self);
      if (!num_params.is_none()) {
        // Use the correct number of params according to the Python function.
        actual_n_args = py::cast<size_t>(num_params);
      }
      // Default signature: all arguments and returns are inferred as Tensors.
      args.reserve(actual_n_args);
      for (size_t i = 0; i < actual_n_args; ++i) {
        args.push_back(
            Argument(std::to_string(i), TensorType::get(), {}, {}, false));
      }
      TypePtr ret_type = TensorType::get();
      if (n_binders == 0) {
        ret_type = NoneType::get();
      } else if (n_binders > 1) {
        std::vector<TypePtr> tuple_values(n_binders, ret_type);
        ret_type = TupleType::create(std::move(tuple_values));
      }
      rets.push_back(Argument("0", ret_type, {}, {}, false));
    }
    return FunctionSchema("", "", std::move(args), std::move(rets));
  }
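  // Example (hypothetical, unannotated function): binding the call
  // `y = f(a, b)` gives n_args == 2 and n_binders == 1, so the default
  // schema above is (Tensor, Tensor) -> Tensor; binding `y, z = f(a, b)`
  // instead produces a Tuple[Tensor, Tensor] return type.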
  // Call the Python object like a function, e.g. `outputs = this(inputs)`.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Method& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    auto inputs = toValues(*m.graph(), inputs_);
    auto schema = getSchema(inputs.size(), n_binders);

    std::stringstream failure_messages;
    c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
        schema,
        loc,
        *m.graph(),
        c10::nullopt,
        inputs_,
        attributes,
        failure_messages,
        /*conv_tensors_to_nums=*/true);
    if (!matched_schema)
      throw ErrorReport(loc) << failure_messages.str();

    // Release the function object so we can wrap it in a PythonOp.
    py::object func = self;
    std::string cconv(inputs.size(), 'd');
    Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
        THPObjectPtr(func.release().ptr()), cconv, {}));

    // Mark if function is ignored on export.
    if (py::cast<bool>(py::module::import("torch.jit")
                           .attr("_try_get_ignored_op")(self))) {
      auto python_op = static_cast<PythonOp*>(new_node);
      python_op->ignore_on_export = true;
    }
    new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
    for (auto& i : matched_schema->inputs)
      new_node->addInput(i);

    Value* output =
        new_node->addOutput()->setType(matched_schema->return_types.at(0));
    return std::make_shared<SimpleValue>(output);
  }
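  // The node emitted above is a PythonOp: at runtime the interpreter calls
  // back into the Python function itself, so these calls require a live
  // Python interpreter (hence the ignore_on_export escape hatch above).
  // The cconv string records, per input, how each argument is passed; 'd'
  // marks an ordinary dynamic input.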
  std::string kind() const override {
    std::stringstream ss;
    ss << "python value of type '" << typeString(self) << "'";
    return ss.str();
  }
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Method& m,
      const c10::optional<size_t>& size_hint = {}) override {
    const std::string type_str = typeString(self);
    std::stringstream ss;
    ss << kind() << " cannot be used as a tuple";
    auto nn = py::module::import("torch.nn");
    if (py::isinstance(self, nn.attr("ModuleList")) ||
        py::isinstance(self, nn.attr("Sequential"))) {
      ss << ". Did you forget to add it to __constants__? ";
    }
    throw ErrorReport(loc) << ss.str();
  }
 protected:
  py::object getattr(const SourceRange& loc, const std::string& name) {
    try {
      return py::getattr(self, name.c_str());
    } catch (py::error_already_set& e) {
      throw ErrorReport(loc) << "object has no attribute " << name;
    }
  }

  py::object self;
};

// A Python module (e.g. `math` or `torch`) referenced inside a script
// method; attribute lookups are resolved against the live module object.
struct PythonModuleValue : public PythonValue {
  explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Method& m,
      const std::string& field) override {
    py::object member = getattr(loc, field);
    // Note: is_constant = true because global properties on modules, like
    // math.pi or torch.float, are treated as constants even though it is
    // possible (if rare) to mutate them.
    return toSugaredValue(member, m, loc, /*is_constant=*/true);
  }
};
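// Example: inside a script method, `math.pi` resolves through
// PythonModuleValue::attr and is inlined into the graph as a float
// constant; the is_constant=true flag above is what permits that.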
// A Python tuple captured as a constant; its elements are recursively
// converted to sugared values.
struct ConstantPythonTupleValue : public PythonValue {
  explicit ConstantPythonTupleValue(py::object tup)
      : PythonValue(std::move(tup)) {}

  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Method& m,
      const c10::optional<size_t>& size_hint = {}) override {
    py::tuple tup = self;
    std::vector<std::shared_ptr<SugaredValue>> result;
    result.reserve(tup.size());
    for (py::handle t : tup) {
      py::object obj = py::reinterpret_borrow<py::object>(t);
      result.push_back(toSugaredValue(obj, m, loc, /*is_constant=*/true));
    }
    return result;
  }

  Value* asValue(const SourceRange& loc, Method& m) override {
    std::vector<Value*> values;
    for (const auto& sugared_item : asTuple(loc, m)) {
      values.push_back(sugared_item->asValue(loc, m));
    }
    auto node = m.graph()->createTuple(values);
    return m.graph()->insertNode(node)->output();
  }
};
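// Example: a constant tuple like `(1, 2)` becomes two int constants via
// asTuple(); when the tuple is consumed as a single value, asValue() packs
// them back together with a prim::TupleConstruct node.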
// Represents all the parameters of a module as a List[Tensor].
struct ConstantParameterList : public SugaredValue {
  ConstantParameterList(std::shared_ptr<Module> module)
      : module_(std::move(module)) {}

  std::string kind() const override {
    return "constant parameter list";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    // Add all module parameters as inputs to the graph.
    std::vector<Value*> params;
    const auto& param_list = module_->get_parameters().items();
    for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
      auto& param = *it;
      params.push_back(caller.get_or_add_parameter(param->slot()));
    }
    auto list = caller.graph()->createList(TensorType::get(), params);
    caller.graph()->insertNode(list);
    return toSimple(list->output());
  }

 private:
  std::shared_ptr<Module> module_;
};
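// This backs module methods tagged with `_is_parameter_list` (see
// ModuleValue::attr below): calling one materializes the module's
// parameters, iterated in reverse here, as a List[Tensor] in the graph.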
// Defines how modules/methods behave inside the script subset: method
// lookups, submodule access, parameters, attributes, and the `training`
// flag all resolve through here.
struct ModuleValue : public SugaredValue {
  ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}

  std::string kind() const override {
    return "module";
  }

  // Select an attribute on the module, e.g. `this.field`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Method& m,
      const std::string& field) override {
    // Workaround to make self.training work: add a buffer 'training' to the
    // model if one doesn't exist, then load that buffer and cast it to bool.
    if (field == "training") {
      NamedIValue* v = module->find_buffer(field);
      if (!v) {
        py::object py_module = py::cast(module);
        bool training = py::cast<bool>(py::getattr(py_module, "training"));
        auto t =
            autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
        module->register_buffer("training", std::move(t));
        v = module->find_buffer(field);
      }
      Value* the_tensor = m.get_or_add_parameter(v->slot());
      Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
      return std::make_shared<SimpleValue>(the_bool);
    }

    if (NamedModule* v = module->find_module(field)) {
      return std::make_shared<ModuleValue>(v->module);
    } else if (Method* v = module->find_method(field)) {
      return std::make_shared<MethodValue>(shared_from_this(), *v);
    } else if (NamedIValue* v = module->find_parameter(field)) {
      return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
    } else if (NamedIValue* v = module->find_attribute(field)) {
      return std::make_shared<SimpleValue>(
          m.get_or_add_attribute(v->type, v->slot()));
    }
    // This can also be a call to a non-script module, or a plain python
    // method. If so, return this as a python value.
    py::object py_module = py::cast(module);
    if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
      if (py::isinstance<py::function>(attr) &&
          py::hasattr(attr, "_is_parameter_list") &&
          py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
        return std::make_shared<ConstantParameterList>(module);
      }
      if (py::isinstance<py::function>(attr) ||
          py::isinstance(
              attr, py::module::import("torch.nn").attr("Module")) ||
          py_module.attr("_constants_set").contains(field.c_str())) {
        return toSugaredValue(attr, m, loc, true);
      } else {
        std::string hint = "did you forget to add it to __constants__?";
        if (py::isinstance(attr, py::module::import("torch").attr("Tensor"))) {
          hint = "Tensors must be added to a module as a buffer or parameter";
        }
        throw ErrorReport(loc)
            << "attribute '" << field << "' of type '" << typeString(attr)
            << "' is not usable in a script method (" << hint << ")";
      }
    }
    throw ErrorReport(loc) << "module has no attribute '" << field << "'";
  }
  // Call module.forward(...).
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    return attr(loc, caller, "forward")
        ->call(loc, caller, inputs, attributes, n_binders);
  }
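  // i.e. `self.submodule(x)` inside a script method desugars to
  // `self.submodule.forward(x)`.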
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Method& m,
      const c10::optional<size_t>& size_hint = {}) override {
    py::object py_module = py::cast(module);
    if (!py::isinstance(
            py_module,
            py::module::import("torch.jit").attr("_ConstModuleList")))
      return SugaredValue::asTuple(loc, m, size_hint);
    std::vector<std::shared_ptr<SugaredValue>> result;
    for (py::handle module : py_module) {
      py::object obj = py::reinterpret_borrow<py::object>(module);
      result.push_back(toSugaredValue(
          obj,
          m,
          loc,
          /*is_constant=*/false,
          /*is_submodule=*/true));
    }
    return result;
  }

  std::shared_ptr<Module> module;
};
// Dispatches to one of two script functions based on a boolean flag that
// must be constant at compile time.
struct BooleanDispatchValue : public SugaredValue {
  BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    c10::optional<bool> result;
    Graph& graph = *(caller.graph());

    auto index = py::cast<size_t>(dispatched_fn_["index"]);
    auto arg_name = py::str(dispatched_fn_["arg_name"]);

    if (index < inputs.size()) {
      // Dispatch flag is in the arg list.
      result = constant_as<bool>(inputs.at(index).value(graph));
    } else if (auto i = findInputWithName(arg_name, attributes)) {
      // Dispatch flag is in kwargs.
      result = constant_as<bool>(attributes[*i].value(graph));
    } else {
      // Didn't find the dispatch flag, so use the default value.
      result = py::cast<bool>(dispatched_fn_["default"]);
    }

    if (!result) {
      throw ErrorReport(loc) << "value for boolean dispatch was not constant";
    }

    std::shared_ptr<SugaredValue> value;
    if (*result) {
      value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
    } else {
      value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
    }
    return value->call(loc, caller, inputs, attributes, n_binders);
  }

 private:
  py::dict dispatched_fn_;
};
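// Boolean dispatch lets one Python name stand for two script functions
// chosen by a bool argument that must be a compile-time constant (e.g.
// functions built with torch._jit_internal.boolean_dispatch, such as the
// max_pool variants toggled by `return_indices`); a non-constant flag
// raises the error above.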
// A function with multiple overloads; the first overload whose schema
// matches the provided arguments wins.
struct OverloadedFunctionValue : public SugaredValue {
  OverloadedFunctionValue(py::list functions)
      : possible_functions_(std::move(functions)) {}

  std::string kind() const override {
    return "overloaded function";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    std::stringstream err;
    auto possible_functions =
        py::cast<std::vector<py::object>>(possible_functions_);

    for (const py::object& fn : possible_functions) {
      auto& method = py::cast<Method&>(fn);
      auto match = tryMatchSchema(
          method.getSchema(),
          loc,
          *caller.graph().get(),
          c10::nullopt,
          inputs,
          attributes,
          err,
          /*conv_tensors_to_nums=*/true);
      if (match) {
        return MethodValue(nullptr, method)
            .call(loc, caller, inputs, attributes, n_binders);
      }
    }
    throw ErrorReport(loc) << "Could not find any matching overloads\n"
                           << err.str();
  }

 private:
  py::list possible_functions_;
};
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Method& m,
    SourceRange loc,
    bool is_constant,
    bool is_submodule) {
  // Directly create SimpleValues when possible, because they are first-class
  // and can be re-assigned. Otherwise, this would be invalid:
  //   f = python_constant
  //   while ...
  //     f = f + 1
  auto& g = *m.graph();
  if (is_constant) {
    if (py::isinstance<py::bool_>(obj)) {
      return toSimple(g.insertConstant(py::cast<bool>(obj), nullptr, loc));
    } else if (py::isinstance<py::int_>(obj)) {
      return toSimple(g.insertConstant(py::cast<int64_t>(obj), nullptr, loc));
    } else if (py::isinstance<py::float_>(obj)) {
      return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
    } else if (py::isinstance<py::str>(obj)) {
      return toSimple(
          g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
    } else if (obj.is(py::none())) {
      return toSimple(g.insertConstant(IValue(), nullptr, loc));
    } else if (THPDevice_Check(obj.ptr())) {
      auto device = reinterpret_cast<THPDevice*>(obj.ptr());
      return toSimple(g.insertConstant(device->device));
    } else if (THPLayout_Check(obj.ptr())) {
      auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
      const auto v = static_cast<int64_t>(layout->layout);
      return toSimple(g.insertConstant(v, nullptr, loc));
    } else if (THPDtype_Check(obj.ptr())) {
      auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
      const auto v = static_cast<int64_t>(dtype->scalar_type);
      return toSimple(g.insertConstant(v, nullptr, loc));
    } else if (py::isinstance<py::tuple>(obj)) {
      return std::make_shared<ConstantPythonTupleValue>(obj);
    }
  }
  py::object weak_obj =
      py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
  if (!weak_obj.is_none()) {
    obj = weak_obj;
  }
  if (py::isinstance<Module>(obj)) {
    auto mod = py::cast<std::shared_ptr<Module>>(obj);
    // In the case that this Python object is not a submodule, inline *all*
    // properties so we do not need to maintain a mapping in the future.
    if (!is_submodule && mod->get_parameters().size() != 0) {
      throw ErrorReport()
          << "Attempted to inline a Module with parameters. "
             "Stateful modules to be inlined must be submodules of the callee.";
    }
    const auto script_class_type =
        py::module::import("torch.jit").attr("ScriptClass");
    const bool is_class_type = py::isinstance(obj, script_class_type);
    if (is_class_type) {
      const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
      auto classType = ClassType::get(classname);
      AT_ASSERT(classType);
      return std::make_shared<ClassValue>(std::move(classType));
    }
    return std::make_shared<ModuleValue>(mod);
  } else if (py::isinstance<py::module>(obj)) {
    return std::make_shared<PythonModuleValue>(obj);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
    return std::make_shared<ForkValue>();
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
    return std::make_shared<AnnotateValue>();
  }

  py::object builtin_name =
      py::module::import("torch.jit").attr("_find_builtin")(obj);
  if (!builtin_name.is_none()) {
    return std::make_shared<BuiltinFunction>(
        Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
  }

  if (py::isinstance<py::function>(obj)) {
    auto compiled_fn =
        py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
    if (!compiled_fn.is(py::none())) {
      auto mod = py::cast<std::shared_ptr<Module>>(compiled_fn);
      return std::make_shared<ModuleValue>(mod);
    }
  }

  py::object dispatched_fn =
      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
  if (!dispatched_fn.is_none()) {
    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
  }

  py::object overloads =
      py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj);
  if (!overloads.is_none()) {
    return std::make_shared<OverloadedFunctionValue>(std::move(overloads));
  }

  return std::make_shared<PythonValue>(obj);
}
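// Resolution order above, in brief: Python constants (scalars, strings,
// device/layout/dtype objects, tuples) first, then weak modules, script
// modules and script classes, plain Python modules, the fork/annotate
// specials, builtins, weak script functions, boolean dispatch, overloads,
// and finally an opaque PythonValue as the fallback.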
py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
  // If we don't tell pybind these are variables it chokes on the conversion.
  // TODO: fix conversions to be sane and make sure this works.
  if (outputs.size() == 0) {
    return py::none();
  } else if (outputs.size() == 1) {
    return py::cast(autograd::as_variable_ref(outputs[0]));
  } else {
    py::tuple tuple(outputs.size());
    for (size_t i = 0; i < outputs.size(); i++) {
      tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
    }
    return std::move(tuple);
  }
}
static void gatherParametersAndBuffers(
    std::vector<IValue*>& values,
    const Module& m) {
  for (auto& param : m.get_parameters()) {
    values.push_back(param->slot());
  }
  for (auto& param : m.get_attributes()) {
    if (param->type->isSubtypeOf(TensorType::get())) {
      values.push_back(param->slot());
    }
  }
  for (const auto& sub : m.get_modules()) {
    gatherParametersAndBuffers(values, *sub->module);
  }
}
Resolver pythonResolver(const ResolutionCallback& rcb) {
  return [rcb](const std::string& name, Method& m, const SourceRange& loc)
             -> std::shared_ptr<SugaredValue> {
    AutoGIL ag;
    py::object obj = rcb(name);
    if (obj.is(py::none())) {
      return nullptr;
    }
    return toSugaredValue(obj, m, loc);
  };
}
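// The ResolutionCallback comes from the Python side: given the name of a
// free variable it returns the corresponding Python object (in practice by
// looking it up in the defining frame's scope), which is then converted
// through toSugaredValue; returning None tells the compiler the name could
// not be resolved here.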
static FunctionSchema getSchemaWithNameAndDefaults(
    const SourceRange& range,
    const FunctionSchema& schema,
    const at::optional<std::string>& new_name,
    const FunctionDefaults& default_args) {
  std::vector<Argument> new_args;
  for (auto& arg : schema.arguments()) {
    auto it = default_args.find(arg.name());
    if (it != default_args.end()) {
      try {
        IValue value;
        auto n = arg.N();
        auto list_type = arg.type()->cast<ListType>();
        if (n && *n > 0 && list_type) {
          // BroadcastingList: allow a default value T for arg type List[T].
          value = toIValue(it->second, list_type->getElementType());
        } else {
          value = toIValue(it->second, arg.type());
        }
        new_args.emplace_back(
            arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
      } catch (py::cast_error& e) {
        throw ErrorReport(range)
            << "Expected a default value of type " << arg.type()->str()
            << " on parameter \"" << arg.name() << "\"";
      }
    } else {
      new_args.push_back(arg);
    }
  }

  return FunctionSchema(
      new_name.value_or(schema.name()),
      schema.overload_name(),
      new_args,
      schema.returns());
}
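// Example (hypothetical signature): for `def forward(self, x, y=2)` the
// Python side passes {"y": 2}; the schema argument `y` is rebuilt with
// IValue(2) as its default, so calling `m(x)` without supplying y works.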
void initJitScriptBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // STL containers are not mutable by default, so bind ExtraFilesMap as an
  // opaque, mutable map type.
  py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");

  // torch.jit.ScriptModule is a subclass of this C++ object. Methods here
  // are prefixed with _ since they should not be public.
  py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
      .def(py::init<>())
      .def(
          "save",
          [](std::shared_ptr<Module> m,
             const std::string& filename,
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            m->save(filename, _extra_files);
          },
          py::arg("filename"),
          py::arg("_extra_files") = ExtraFilesMap())
      .def(
          "save_to_buffer",
          [](std::shared_ptr<Module> m,
             const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
            std::ostringstream buf;
            m->save(buf, _extra_files);
            return py::bytes(buf.str());
          },
          py::arg("_extra_files") = ExtraFilesMap())
      .def("_set_optimized", &Module::set_optimized)
      .def(
          "_define",
          [](std::shared_ptr<Module> m,
             const std::string& script,
             ResolutionCallback rcb,
             bool has_self) {
            c10::optional<Self> self;
            if (has_self) {
              self = Self(std::make_shared<ModuleValue>(m));
            }
            defineMethodsInModule(m, script, pythonResolver(rcb), self);
          })
      .def(
          "_create_methods",
          [](std::shared_ptr<Module> m,
             const std::vector<Def>& defs,
             const std::vector<ResolutionCallback>& rcbs,
             const std::vector<FunctionDefaults>& defaults) {
            std::vector<Resolver> resolvers;
            resolvers.reserve(rcbs.size());
            for (auto& callback : rcbs) {
              resolvers.push_back(pythonResolver(callback));
            }
            defineMethodsInModule(
                m, defs, resolvers, Self(std::make_shared<ModuleValue>(m)));

            // Stitch in default arguments for each Def if provided.
            auto defaults_it = defaults.begin();
            auto defs_it = defs.begin();
            while (defs_it != defs.end()) {
              auto& method = m->get_method((*defs_it).name().name());
              method.setSchema(getSchemaWithNameAndDefaults(
                  defs_it->range(),
                  method.getSchema(),
                  at::nullopt,
                  *defaults_it));
              ++defs_it;
              ++defaults_it;
            }
            didFinishEmitModule(m);
          })
      .def(
          "_get_method",
          [](Module& self, const std::string& name) -> const Method& {
            return self.get_method(name);
          },
          py::return_value_policy::reference_internal)
      .def("_register_parameter", &Module::register_parameter)
      .def(
          "_register_attribute",
          [](Module& self, std::string name, TypePtr type, py::object value) {
            self.register_attribute(name, type, toIValue(value, type));
          })
      .def("_register_module", &Module::register_module)
      .def("_register_buffer", &Module::register_buffer)
      .def("_set_parameter", &Module::set_parameter)
      .def("_get_parameter", &Module::get_parameter)
      .def("_get_buffer", &Module::get_buffer)
      .def("_get_module", &Module::get_module)
      .def(
          "_get_modules",
          [](Module& self) -> py::tuple {
            auto& modules = self.get_modules();
            py::tuple result(modules.size());
            for (size_t i = 0; i < modules.size(); ++i) {
              auto& item = modules[i];
              result[i] = std::make_pair(item.key(), item.value().module);
            }
            return result;
          })
      .def(
          "_get_parameters",
          [](Module& self) -> py::tuple {
            auto& parameters = self.get_parameters();
            py::tuple result(parameters.size());
            for (size_t i = 0; i < parameters.size(); ++i) {
              auto& p = parameters[i];
              result[i] = std::make_tuple(
                  p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
            }
            return result;
          })
      .def(
          "_get_attributes",
          [](Module& self) -> py::tuple {
            auto& attributes = self.get_attributes();
            py::tuple result(attributes.size());
            for (size_t i = 0; i < attributes.size(); ++i) {
              auto& buffer = attributes[i];
              IValue v = *buffer->slot();
              result[i] = std::make_tuple(
                  buffer.key(), buffer->type, toPyObject(std::move(v)));
            }
            return result;
          })
      .def(
          "_has_parameter",
          [](Module& self, const std::string& name) -> bool {
            return self.find_parameter(name);
          })
      .def(
          "_has_buffer",
          [](Module& self, const std::string& name) -> bool {
            return self.find_buffer(name);
          })
      .def(
          "_has_module",
          [](Module& self, const std::string& name) {
            return bool(self.find_module(name));
          })
      .def(
          "_has_method",
          [](Module& self, const std::string& name) {
            return bool(self.find_method(name));
          })
      .def(
          "_method_names",
          [](Module& self) {
            using Item =
                torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
            return fmap(self.get_methods(), [](const Item& item) {
              return (*item)->name();
            });
          })
      .def(
          "_create_method_from_graph",
          [](Module& self,
             const std::string& name,
             std::shared_ptr<Graph> graph) {
            self.create_method(name, std::move(graph), {});
          })
832 "_create_method_from_trace",
833 [](std::shared_ptr<Module>
self,
834 const std::string& name,
836 py::tuple input_tuple,
837 py::function var_lookup_fn,
838 bool force_outplace) {
841 std::vector<IValue*> parameters;
842 gatherParametersAndBuffers(parameters, *
self);
843 Stack inputs = toStack(input_tuple);
844 for (
IValue* param : parameters) {
845 inputs.emplace_back(*param);
847 auto graph = tracer::createGraphByTracing(
853 self->create_method(name, std::move(graph), std::move(parameters));
854 didFinishEmitModule(
self);
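      // Tracing flow, in short: parameters and buffers are appended to the
      // user inputs so they become explicit graph inputs, the Python
      // function runs once under the tracer, and the recorded graph is
      // installed as a new method whose initial values are those parameter
      // slots.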
      .def(
          "graph_for",
          [](py::args args, py::kwargs kwargs) {
            // see: [pybind11 varargs]
            Module& self = py::cast<Module&>(args[0]);
            if (self.find_method("forward")) {
              Method& m = self.get_method("forward");
              return m.graph_for(createStackForSchema(
                  m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
            }
            throw std::runtime_error(
                "Attempted to call graph_for on a Module without a compiled forward()");
          })
      .def(
          "get_debug_state",
          [](Module& self) {
            if (self.find_method("forward")) {
              Method& m = self.get_method("forward");
              return m.getDebugState();
            }
            throw std::runtime_error(
                "Attempted to call get_debug_state on a Module without a compiled forward()");
          })
884 "debug_disable_autodiff_subgraph_inlining",
886 if (
self.find_method(
"forward")) {
887 Method& m =
self.get_method(
"forward");
888 m.debugDisableAutodiffSubgraphInlining();
      .def(
          "forward",
          [](py::args args, py::kwargs kwargs) {
            // We implement this in C++ to avoid incurring the pybind11
            // dispatch overhead twice: once to call `forward` from Python
            // and once to actually invoke the method.
            // see: [pybind11 varargs]
            Module& self = py::cast<Module&>(args[0]);
            return invokeScriptMethodFromPython(
                self.get_method("forward"),
                tuple_slice(std::move(args), 1),
                std::move(kwargs));
          })
      .def(
          "_python_print",
          [](Module& self) {
            std::ostringstream ss;
            std::vector<at::Tensor> tensors;
            std::vector<ClassTypePtr> classes;
            PythonPrint(ss, self, tensors, classes, true);
            return std::make_pair(ss.str(), tensors);
          })
      .def_property_readonly(
          "code",
          [](Module& self) {
            std::ostringstream ss;
            std::vector<at::Tensor> tensors;
            std::vector<ClassTypePtr> classes;
            PythonPrint(ss, self, tensors, classes, false);
            return ss.str();
          })
      .def("apply", &Module::apply)
      .def("_copy_into", &Module::copy_into)
      .def(
          "_copy_method",
          [](std::shared_ptr<Module> m,
             std::string name,
             std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
                 params,
             std::shared_ptr<Module> orig) {
            std::vector<IValue*> member_inputs;
            for (auto& p : params) {
              NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
              if (np == nullptr) {
                np = std::get<0>(p)->find_buffer(std::get<1>(p));
              }
              AT_ASSERT(np != nullptr);
              member_inputs.push_back(np->slot());
            }

            Method* orig_method = orig->find_method(name);
            m->create_method(name, orig_method->graph()->copy(), member_inputs);
          });
  py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
      .def("graph", [&](Method& self) { return self.graph(); })
      .def(
          "__call__",
          [](py::args args, py::kwargs kwargs) {
            // see: [pybind11 varargs]
            Method& method = py::cast<Method&>(args[0]);
            return invokeScriptMethodFromPython(
                method, tuple_slice(std::move(args), 1), std::move(kwargs));
          })
      .def_property_readonly("graph", [](Method& m) { return m.graph(); })
      .def("propagate_shapes", &Method::propagate_shapes)
      .def(
          "propagate_and_assign_input_and_output_shapes",
          &Method::propagate_and_assign_input_and_output_shapes)
      .def("initial_ivalues", &Method::initial_ivalues)
      .def(
          "graph_for",
          [](py::args args, py::kwargs kwargs) {
            // see: [pybind11 varargs]
            Method& self = py::cast<Method&>(args[0]);
            return self.graph_for(createStackForSchema(
                self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
          })
      .def(
          "debug_disable_autodiff_subgraph_inlining",
          &Method::debugDisableAutodiffSubgraphInlining)
      .def("schema", &Method::getSchema)
      .def("pretty_print_schema", &Method::pretty_print_schema)
      .def("python_print", [](Method& m) {
        std::ostringstream oss;
        std::vector<at::Tensor> constants;
        std::vector<ClassTypePtr> classes;
        PythonPrint(oss, m, constants, classes, true);
        return std::make_pair(oss.str(), std::move(constants));
      });
987 "_jit_script_compile",
988 [](std::shared_ptr<Module> mod,
990 ResolutionCallback rcb,
991 FunctionDefaults defaults) {
992 auto def_f = def.withName(
"forward");
993 defineMethodsInModule(
994 mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt);
995 auto& method = mod->get_method(
"forward");
996 method.setSchema(getSchemaWithNameAndDefaults(
997 def.range(), method.getSchema(), def.name().name(), defaults));
998 didFinishEmitModule(mod);
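  // _jit_script_compile is the entry point behind @torch.jit.script for
  // free functions: the parsed Def is compiled as a `forward` method on a
  // fresh module, and the function's Python default arguments are stitched
  // into the schema afterwards.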
1003 "_jit_script_class_compile",
1004 [](std::shared_ptr<Module> module,
1006 ResolutionCallback rcb) {
1007 auto classType = ClassType::create(classDef.name().name(), module);
1008 std::vector<Resolver> rcbs;
1009 std::vector<Def> methodDefs;
1010 for (
const auto& def : classDef.defs()) {
1011 methodDefs.push_back(def);
1012 rcbs.push_back(pythonResolver(rcb));
1014 defineMethodsInModule(module, methodDefs, rcbs,
Self(classType));
  m.def("parse_type_comment", [](const std::string& comment) {
    Parser p(comment);
    return Decl(p.parseTypeComment());
  });

  m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
  m.def(
      "import_ir_module",
      [](ModuleLookup module_lookup,
         const std::string& filename,
         py::object map_location,
         ExtraFilesMap& extra_files) {
        c10::optional<at::Device> optional_device;
        if (!map_location.is(py::none())) {
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
          optional_device =
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
        }
        import_ir_module(module_lookup, filename, optional_device, extra_files);
      });
  m.def(
      "import_ir_module_from_buffer",
      [](ModuleLookup module_lookup,
         const std::string& buffer,
         py::object map_location,
         ExtraFilesMap& extra_files) {
        std::istringstream in(buffer);
        c10::optional<at::Device> optional_device;
        if (!map_location.is(py::none())) {
          AT_ASSERT(THPDevice_Check(map_location.ptr()));
          optional_device =
              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
        }
        import_ir_module(module_lookup, in, optional_device, extra_files);
      });
  m.def("_jit_import_methods", import_methods);
  m.def("_jit_set_emit_module_hook", setEmitModuleHook);
  m.def("_jit_clear_class_registry", ClassType::clearRegistry);
  py::class_<testing::FileCheck>(m, "FileCheck")
      .def(py::init<>())
      .def("check", &testing::FileCheck::check)
      .def("check_not", &testing::FileCheck::check_not)
      .def("check_same", &testing::FileCheck::check_same)
      .def("check_next", &testing::FileCheck::check_next)
      .def("check_count", &testing::FileCheck::check_count)
      .def("check_dag", &testing::FileCheck::check_dag)
      .def("check_count", &testing::FileCheck::check_count)
      .def(
          "check_count",
          [](testing::FileCheck& f,
             const std::string& str,
             size_t count,
             bool exactly) { return f.check_count(str, count, exactly); },
          py::arg("str"),
          py::arg("count"),
          py::arg("exactly") = false)
      .def("run", [](testing::FileCheck& f, const std::string& str) {
        return f.run(str);
      });
}

} // namespace script
} // namespace jit
} // namespace torch