1 #include <torch/csrc/jit/script/module.h> 2 #include <c10/util/Exception.h> 3 #include <torch/csrc/jit/export.h> 4 #include <torch/csrc/jit/operator.h> 5 #include <torch/csrc/jit/script/compiler.h> 6 #include <torch/csrc/jit/script/error_report.h> 7 #include <torch/csrc/jit/script/schema_matching.h> 14 void placeholderCreator(
Method&) {
18 Value* try_emit_call_to(
25 std::stringstream& failure_messages,
27 bool conv_tensors_to_nums) {
29 callee.ensure_defined();
32 <<
" method '" << callee.name()
33 <<
"' is called recursively involving this call site. " 34 <<
"Recursive calls are not supported";
36 auto fn = callee.graph();
38 auto matched_schema = tryMatchSchema(
46 conv_tensors_to_nums);
52 for (
auto member : callee.initial_ivalues()) {
55 <<
" attempting to call a method with parameters/attributes" 56 " from a raw graph. File a bug report";
59 auto type = incompleteInferTypeFrom(*member);
60 matched_schema->inputs.push_back(
61 caller->get_or_add_attribute(type, member));
63 callee.check_single_output();
64 return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).
at(0);
67 Value* Method::emit_call_to(
73 std::stringstream failure_messages;
74 if (
auto result = try_emit_call_to(
89 void Method::ensure_defined() {
91 auto creator = method_creator;
92 method_creator = placeholderCreator;
94 method_creator =
nullptr;
99 to_impl(device, dtype, non_blocking);
103 to_impl(c10::nullopt, dtype, non_blocking);
107 to_impl(device, c10::nullopt, non_blocking);
110 void Module::save(std::ostream& out,
const ExtraFilesMap& extra_files) {
111 ExportModule(*
this, out, extra_files);
115 const std::string& filename,
116 const ExtraFilesMap& extra_files) {
117 ExportModule(*
this, filename, extra_files);
120 void Module::to_impl(
125 for (
auto& child : modules) {
126 child->module->to_impl(device, dtype, non_blocking);
129 for (
auto& parameter : parameters) {
134 auto new_data = data.to(
135 device.value_or(data.
device()),
136 dtype.value_or(data.scalar_type()),
TORCH_API void to(at::Device device, at::ScalarType dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Represents a compute device on which a tensor is located.
Device device() const
Returns a Tensor's device.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
void set_data(const at::Tensor &new_data)
Sets the Tensor held by this Variable to the one supplied.
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory)...
Flush-To-Zero and Denormals-Are-Zero mode.