Caffe2 - C++ API
A deep learning, cross-platform ML framework
module.cpp
1 #include <torch/csrc/jit/script/module.h>
2 #include <c10/util/Exception.h>
3 #include <torch/csrc/jit/export.h>
4 #include <torch/csrc/jit/operator.h>
5 #include <torch/csrc/jit/script/compiler.h>
6 #include <torch/csrc/jit/script/error_report.h>
7 #include <torch/csrc/jit/script/schema_matching.h>
8 
9 namespace torch {
10 namespace jit {
11 namespace script {
12 
13 struct RecursiveMethodCallError : public std::exception {};
14 void placeholderCreator(Method&) {
16 }
17 
18 Value* try_emit_call_to(
19  Graph& graph,
20  const SourceRange& loc,
21  Method& callee,
24  ArrayRef<NamedValue> kwargs,
25  std::stringstream& failure_messages,
26  Method* caller,
27  bool conv_tensors_to_nums) {
28  try {
29  callee.ensure_defined();
30  } catch (RecursiveMethodCallError&) {
31  throw ErrorReport(loc)
32  << " method '" << callee.name()
33  << "' is called recursively involving this call site. "
34  << "Recursive calls are not supported";
35  }
36  auto fn = callee.graph();
37 
38  auto matched_schema = tryMatchSchema(
39  callee.getSchema(),
40  loc,
41  graph,
42  std::move(self),
43  args,
44  kwargs,
45  failure_messages,
46  conv_tensors_to_nums);
47  if (!matched_schema)
48  return nullptr;
49 
50  // parameters to callee method (which become parameters to _this_ method
51  // if they were not already)
52  for (auto member : callee.initial_ivalues()) {
53  if (!caller) {
54  throw ErrorReport(loc)
55  << " attempting to call a method with parameters/attributes"
56  " from a raw graph. File a bug report";
57  }
58  // TODO: preserve the type information so we don't have to infer it here
59  auto type = incompleteInferTypeFrom(*member);
60  matched_schema->inputs.push_back(
61  caller->get_or_add_attribute(type, member));
62  }
63  callee.check_single_output();
64  return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
65 }
66 
67 Value* Method::emit_call_to(
68  const SourceRange& loc,
69  Method& callee,
71  ArrayRef<NamedValue> kwargs) {
72  AT_ASSERT(!executor);
73  std::stringstream failure_messages;
74  if (auto result = try_emit_call_to(
75  *graph(),
76  loc,
77  callee,
78  c10::nullopt,
79  args,
80  kwargs,
81  failure_messages,
82  this,
83  /*conv_tensors_to_nums=*/true)) {
84  return result;
85  }
86  throw ErrorReport(loc) << failure_messages.str();
87 }
88 
89 void Method::ensure_defined() {
90  if (method_creator) {
91  auto creator = method_creator;
92  method_creator = placeholderCreator;
93  creator(*this);
94  method_creator = nullptr;
95  }
96 }
97 
98 void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
99  to_impl(device, dtype, non_blocking);
100 }
101 
102 void Module::to(at::ScalarType dtype, bool non_blocking) {
103  to_impl(/*device=*/c10::nullopt, dtype, non_blocking);
104 }
105 
106 void Module::to(at::Device device, bool non_blocking) {
107  to_impl(device, /*dtype=*/c10::nullopt, non_blocking);
108 }
109 
110 void Module::save(std::ostream& out, const ExtraFilesMap& extra_files) {
111  ExportModule(*this, out, extra_files);
112 }
113 
114 void Module::save(
115  const std::string& filename,
116  const ExtraFilesMap& extra_files) {
117  ExportModule(*this, filename, extra_files);
118 }
119 
120 void Module::to_impl(
121  const c10::optional<at::Device>& device,
122  const c10::optional<at::ScalarType>& dtype,
123  bool non_blocking) {
124  // First call `to()` on every child module.
125  for (auto& child : modules) {
126  child->module->to_impl(device, dtype, non_blocking);
127  }
128  // Then convert every of our parameters.
129  for (auto& parameter : parameters) {
130  // Need to access the `at::Tensor` as a `Variable` here.
131  autograd::Variable variable = parameter.value().slot()->toTensor();
132  at::Tensor data = variable.data();
133  // Use the data's original device or dtype if not supplied here.
134  auto new_data = data.to(
135  device.value_or(data.device()),
136  dtype.value_or(data.scalar_type()),
137  non_blocking);
138  variable.set_data(new_data);
139  }
140 }
141 
142 } // namespace script
143 } // namespace jit
144 } // namespace torch
TORCH_API void to(at::Device device, at::ScalarType dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Definition: module.cpp:98
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Device device() const
Returns a Tensor's device.
Variable — augments a Tensor with the ability to interact with our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
void set_data(const at::Tensor &new_data)
Sets the Tensor held by this Variable to the one supplied.
Definition: variable.h:678
ArrayRef - Represents a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
Flush-To-Zero and Denormals-Are-Zero mode.