Caffe2 - C++ API
A deep learning, cross-platform ML framework
init.cpp
1 #include <torch/csrc/jit/script/init.h>
2 
3 #include <torch/csrc/Device.h>
4 #include <torch/csrc/Dtype.h>
5 #include <torch/csrc/Layout.h>
6 #include <torch/csrc/jit/import.h>
7 #include <torch/csrc/jit/script/compiler.h>
8 #include <torch/csrc/jit/script/module.h>
9 #include <torch/csrc/jit/script/schema_matching.h>
10 #include <torch/csrc/jit/script/sugared_value.h>
11 #include <torch/csrc/jit/testing/file_check.h>
12 
13 #include <torch/csrc/jit/constants.h>
14 #include <torch/csrc/jit/hooks_for_testing.h>
15 #include <torch/csrc/jit/import_source.h>
16 #include <torch/csrc/jit/passes/python_print.h>
17 #include <torch/csrc/jit/passes/to_batch.h>
18 #include <torch/csrc/jit/pybind_utils.h>
19 #include <torch/csrc/jit/python_tracer.h>
20 #include <torch/csrc/jit/script/parser.h>
21 
22 #include <torch/csrc/api/include/torch/ordered_dict.h>
23 
24 #include <ATen/ATen.h>
25 #include <ATen/core/function_schema.h>
26 
27 #include <pybind11/functional.h>
28 #include <pybind11/pybind11.h>
29 #include <pybind11/stl.h>
30 #include <pybind11/stl_bind.h>
31 #include <cstddef>
32 #include <memory>
33 #include <sstream>
34 #include <string>
35 #include <tuple>
36 #include <utility>
37 #include <vector>
38 
39 PYBIND11_MAKE_OPAQUE(torch::jit::script::ExtraFilesMap);
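// [Editor's note, not part of the original source] ExtraFilesMap is declared
// opaque here so pybind11 passes it by reference instead of converting it to a
// fresh Python dict at every boundary crossing. The matching
// py::bind_map<ExtraFilesMap> call in initJitScriptBindings() below then
// exposes it to Python as a mutable, dict-like "ExtraFilesMap" type, which is
// what lets callers round-trip extra files through ScriptModule::save and
// import_ir_module.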
40 
41 namespace torch {
42 namespace jit {
43 namespace script {
44 
45 using ::c10::Argument;
46 using ::c10::FunctionSchema;
47 
48 using ResolutionCallback = std::function<py::function(std::string)>;
49 using FunctionDefaults = std::unordered_map<std::string, py::object>;
50 
51 static std::string typeString(py::handle h) {
52  return py::str(h.get_type().attr("__name__"));
53 }
54 
55 inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
56  return std::make_shared<SimpleValue>(v);
57 }
58 
59 // NB: This should be the single entry-point for instantiating a SugaredValue
60 // from a Python object. If you are adding support for converting a new Python
61 // type, *add it in this function's implementation*.
62 std::shared_ptr<SugaredValue> toSugaredValue(
63  py::object obj,
64  Method& m,
65  SourceRange loc,
66  bool is_constant = false,
67  bool is_submodule = false);
68 
69 struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
70  PythonValue(py::object self) : self(std::move(self)) {}
71 
72  FunctionSchema getSchema(const size_t n_args, const size_t n_binders) {
73  auto annotations = py::module::import("torch.jit.annotations");
74  auto signature = annotations.attr("get_signature")(self);
75  std::vector<Argument> args, rets;
76  // We may mutate this if we can determine the number of args from Python
77  // introspection.
78  size_t actual_n_args = n_args;
79  if (!signature.is_none()) {
80  std::vector<TypePtr> arg_types;
81  TypePtr ret_type;
82  std::tie(arg_types, ret_type) =
83  py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
84  args.reserve(arg_types.size());
85  size_t idx = 0; // Fake argument names by putting in the index
86  for (auto& arg_type : arg_types) {
87  args.push_back(Argument(
88  std::to_string(idx++), std::move(arg_type), {}, {}, false));
89  }
90  rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
91  } else {
92  // Create a default signature using what information we have
93 
94  // First see if we can introspect the number of function parameters
95  // irrespective of the presence of explicit type annotations
96  auto num_params = annotations.attr("get_num_params")(self);
97  if (!num_params.is_none()) {
98  // Return a signature with the correct number of params according to the
99  // Python function. The error handling in call() will catch any mismatch
100  // later.
101  actual_n_args = py::cast<size_t>(num_params);
102  }
103  // Construct the default signature: all arguments and returns will be
104  // TensorType
105  args.reserve(actual_n_args);
106  for (size_t i = 0; i < actual_n_args; ++i) {
107  args.push_back(
108  Argument(std::to_string(i), TensorType::get(), {}, {}, false));
109  }
110  TypePtr ret_type = TensorType::get();
111  if (n_binders == 0) {
112  ret_type = NoneType::get();
113  } else if (n_binders > 1) {
114  std::vector<TypePtr> tuple_values(n_binders, ret_type);
115  ret_type = TupleType::create(std::move(tuple_values));
116  }
117  rets.push_back(Argument("0", ret_type, {}, {}, false));
118  }
119  return FunctionSchema("", "", std::move(args), std::move(rets));
120  }
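// [Editor's sketch, not part of the original source] For an un-annotated
// Python function, the fallback branch above produces an all-Tensor schema.
// For example, with actual_n_args == 2 and n_binders == 1 the result is
// roughly:
//   (Tensor 0, Tensor 1) -> Tensor
// and with n_binders == 2 the return type becomes a two-element tuple:
//   (Tensor 0, Tensor 1) -> (Tensor, Tensor)
// Argument names are just the stringified indices, as noted above; with
// n_binders == 0 the return type is None.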
121 
122  // call it like a function, e.g. `outputs = this(inputs)`
123  std::shared_ptr<SugaredValue> call(
124  const SourceRange& loc,
125  Method& m,
126  at::ArrayRef<NamedValue> inputs_,
127  at::ArrayRef<NamedValue> attributes,
128  size_t n_binders) override {
129  auto inputs = toValues(*m.graph(), inputs_);
130  auto schema = getSchema(inputs.size(), n_binders);
131 
132  std::stringstream failure_messages;
133  c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
134  schema,
135  loc,
136  *m.graph(),
137  c10::nullopt,
138  inputs_,
139  attributes,
140  failure_messages,
141  /*conv_tensor_to_num*/ true);
142  if (!matched_schema)
143  throw ErrorReport(loc) << failure_messages.str();
144 
145  // Release the function object so we can wrap it in a PythonOp
146  py::object func = self;
147  std::string cconv(inputs.size(), 'd');
148  Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
149  THPObjectPtr(func.release().ptr()), cconv, {}));
150 
151  // Mark if function is ignored on export
152  if (py::cast<bool>(py::module::import("torch.jit")
153  .attr("_try_get_ignored_op")(self))) {
154  auto python_op = static_cast<PythonOp*>(new_node);
155  python_op->ignore_on_export = true;
156  }
157  new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
158  for (auto& i : matched_schema->inputs)
159  new_node->addInput(i);
160 
161  Value* output =
162  new_node->addOutput()->setType(matched_schema->return_types.at(0));
163  return std::make_shared<SimpleValue>(output);
164  }
165 
166  std::string kind() const override {
167  std::stringstream ss;
168  ss << "python value of type '" << typeString(self) << "'";
169  return ss.str();
170  }
171 
172  std::vector<std::shared_ptr<SugaredValue>> asTuple(
173  const SourceRange& loc,
174  Method& m,
175  const c10::optional<size_t>& size_hint = {}) override {
176  const std::string type_str = typeString(self);
177  std::stringstream ss;
178  ss << kind() << " cannot be used as a tuple";
179  auto nn = py::module::import("torch.nn");
180  if (py::isinstance(self, nn.attr("ModuleList")) ||
181  py::isinstance(self, nn.attr("Sequential"))) {
182  ss << ". Did you forget to add it to __constants__? ";
183  }
184  throw ErrorReport(loc) << ss.str();
185  }
186 
187  protected:
188  py::object getattr(const SourceRange& loc, const std::string& name) {
189  try {
190  return py::getattr(self, name.c_str());
191  } catch (py::error_already_set& e) {
192  throw ErrorReport(loc) << "object has no attribute " << name;
193  }
194  }
195 
196  py::object self;
197 };
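// [Editor's sketch, not part of the original source] Calling a PythonValue
// from script does not inline the Python function; call() above records a
// PythonOp node holding the function object, with a calling-convention string
// cconv of one 'd' per positional input. A hypothetical scripted call such as
//   y = some_python_fn(x)
// therefore remains a Python call at runtime, and is skipped on export when
// torch.jit reports the function as ignored (ignore_on_export above).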
198 
199 struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
200  explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {}
201 
202  std::shared_ptr<SugaredValue> attr(
203  const SourceRange& loc,
204  Method& m,
205  const std::string& field) override {
206  py::object member = getattr(loc, field);
207  // note: is_constant = true because we consider global properties
208  // on modules like math.pi or torch.float to be constants,
209  // even though it is possible, though rare, for someone to mutate them
210  return toSugaredValue(member, m, loc, /*is_constant=*/true);
211  }
212 };
213 
214 struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
215  explicit ConstantPythonTupleValue(py::object tup)
216  : PythonValue(std::move(tup)) {}
217  std::vector<std::shared_ptr<SugaredValue>> asTuple(
218  const SourceRange& loc,
219  Method& m,
220  const c10::optional<size_t>& size_hint = {}) override {
221  py::tuple tup = self;
222  std::vector<std::shared_ptr<SugaredValue>> result;
223  result.reserve(tup.size());
224  for (py::handle t : tup) {
225  py::object obj = py::reinterpret_borrow<py::object>(t);
226  result.push_back(toSugaredValue(obj, m, loc, true));
227  }
228  return result;
229  }
230 
231  Value* asValue(const SourceRange& loc, Method& m) override {
232  std::vector<Value*> values;
233  for (const auto& sugared_item : asTuple(loc, m)) {
234  values.push_back(sugared_item->asValue(loc, m));
235  }
236  auto node = m.graph()->createTuple(values);
237  return m.graph()->insertNode(node)->output();
238  }
239 };
240 
241 // Represents all the parameters of a module as a List[Tensor]
242 struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
243  ConstantParameterList(std::shared_ptr<Module> module)
244  : module_(std::move(module)) {}
245 
246  std::string kind() const override {
247  return "constant parameter list";
248  }
249 
250  std::shared_ptr<SugaredValue> call(
251  const SourceRange& loc,
252  Method& caller,
253  at::ArrayRef<NamedValue> inputs,
254  at::ArrayRef<NamedValue> attributes,
255  size_t n_binders) override {
256  // Add all module parameters as inputs to the graph
257  std::vector<Value*> params;
258  const auto& param_list = module_->get_parameters().items();
259  for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
260  auto& param = *it;
261  params.push_back(caller.get_or_add_parameter(param->slot()));
262  }
263  auto list = caller.graph()->createList(TensorType::get(), params);
264  caller.graph()->insertNode(list);
265  return toSimple(list->output());
266  }
267 
268  private:
269  std::shared_ptr<Module> module_;
270 };
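// [Editor's sketch, not part of the original source] This is what backs
// parameter-list accessors on the module. When a method carrying the
// _is_parameter_list attribute is called from script (see ModuleValue::attr
// below), every parameter slot of the module is pulled into the calling graph
// via get_or_add_parameter and packed into a single List[Tensor] value with
// createList, so script code can iterate over it like an ordinary list.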
271 
272 // Defines how modules/methods behave inside the script subset.
273 // For now this does not have any interaction with Python.
274 // In the future, we will add the ability to resolve `self.foo` to Python
275 // {functions, modules, constants}, so this SugaredValue is defined here
276 // anticipating we will eventually need to replace Module with a py::object
277 // holding the actual nn.Module class.
278 
279 struct ModuleValue : public SugaredValue {
280  ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}
281 
282  std::string kind() const override {
283  return "module";
284  }
285 
286  // select an attribute on it, e.g. `this.field`
287  std::shared_ptr<SugaredValue> attr(
288  const SourceRange& loc,
289  Method& m,
290  const std::string& field) override {
291  // Workaround to make self.training work:
292  // it adds a buffer 'training' to the model if one doesn't exist
293  // and then loads that buffer, casting it to bool
294  if (field == "training") {
295  NamedIValue* v = module->find_buffer(field);
296  if (!v) {
297  py::object py_module = py::cast(module);
298  bool training = py::cast<bool>(py::getattr(py_module, "training"));
299  auto t =
300  autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
301  module->register_buffer("training", std::move(t));
302  v = module->find_buffer(field);
303  }
304  Value* the_tensor = m.get_or_add_parameter(v->slot());
305  Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
306  return std::make_shared<SimpleValue>(the_bool);
307  }
308 
309  if (NamedModule* v = module->find_module(field)) {
310  return std::make_shared<ModuleValue>(v->module);
311  } else if (Method* v = module->find_method(field)) {
312  return std::make_shared<MethodValue>(shared_from_this(), *v);
313  } else if (NamedIValue* v = module->find_parameter(field)) {
314  return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
315  } else if (NamedIValue* v = module->find_attribute(field)) {
316  return std::make_shared<SimpleValue>(
317  m.get_or_add_attribute(v->type, v->slot()));
318  }
319 
320  // This can also be a call to a non-script module, or a plain
321  // Python method. If so, return this as a Python value.
322  py::object py_module = py::cast(module);
323  if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
324  if (py::isinstance<py::function>(attr) &&
325  py::hasattr(attr, "_is_parameter_list") &&
326  py::cast<bool>(py::getattr(attr, "_is_parameter_list"))) {
327  return std::make_shared<ConstantParameterList>(module);
328  }
329  if (py::isinstance<py::function>(attr) ||
330  py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
331  py_module.attr("_constants_set").contains(field.c_str())) {
332  return toSugaredValue(attr, m, loc, true);
333  } else {
334  std::string hint = "did you forget to add it to __constants__?";
335  if (py::isinstance(attr, py::module::import("torch").attr("Tensor"))) {
336  hint = "Tensors must be added to a module as a buffer or parameter";
337  }
338  throw ErrorReport(loc)
339  << "attribute '" << field << "' of type '" << typeString(attr)
340  << "' is not usable in a script method (" << hint << ")";
341  }
342  }
343  throw ErrorReport(loc) << "module has no attribute '" << field << "'";
344  }
345 
346  // call module.forward
347  std::shared_ptr<SugaredValue> call(
348  const SourceRange& loc,
349  Method& caller,
350  at::ArrayRef<NamedValue> inputs,
351  at::ArrayRef<NamedValue> attributes,
352  size_t n_binders) override {
353  return attr(loc, caller, "forward")
354  ->call(loc, caller, inputs, attributes, n_binders);
355  }
356 
357  std::vector<std::shared_ptr<SugaredValue>> asTuple(
358  const SourceRange& loc,
359  Method& m,
360  const c10::optional<size_t>& size_hint = {}) override {
361  py::object py_module = py::cast(module);
362  if (!py::isinstance(
363  py_module,
364  py::module::import("torch.jit").attr("_ConstModuleList")))
365  return SugaredValue::asTuple(loc, m, size_hint);
366  std::vector<std::shared_ptr<SugaredValue>> result;
367  for (py::handle module : py_module) {
368  py::object obj = py::reinterpret_borrow<py::object>(module);
369  result.push_back(toSugaredValue(
370  obj,
371  m,
372  loc,
373  /*is_constant =*/false,
374  /*is_submodule =*/true));
375  }
376  return result;
377  }
378 
379  private:
380  std::shared_ptr<Module> module;
381 };
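// [Editor's sketch, not part of the original source] Attribute lookup on a
// ModuleValue resolves, in order: the special 'training' buffer, submodules,
// script methods, parameters, and attributes; anything else falls back to the
// underlying Python module object. From that fallback, only parameter-list
// accessors, plain functions, nn.Modules, and names listed in __constants__
// are accepted; other Python attributes produce the "not usable in a script
// method" error above.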
382 
383 struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
384  BooleanDispatchValue(py::dict dispatched_fn)
385  : dispatched_fn_(std::move(dispatched_fn)) {}
386 
387  std::string kind() const override {
388  return "boolean dispatch";
389  }
390 
391  std::shared_ptr<SugaredValue> call(
392  const SourceRange& loc,
393  Method& caller,
394  at::ArrayRef<NamedValue> inputs,
395  at::ArrayRef<NamedValue> attributes,
396  size_t n_binders) override {
397  c10::optional<bool> result;
398  Graph& graph = *(caller.graph());
399 
400  auto index = py::cast<size_t>(dispatched_fn_["index"]);
401  auto arg_name = py::str(dispatched_fn_["arg_name"]);
402 
403  if (index < inputs.size()) {
404  // Dispatch flag is in arg list
405  result = constant_as<bool>(inputs.at(index).value(graph));
406  } else if (auto i = findInputWithName(arg_name, attributes)) {
407  // Dispatch flag is in kwargs
408  result = constant_as<bool>(attributes[*i].value(graph));
409  } else {
410  // Didn't find dispatch flag, so use default value
411  result = py::cast<bool>(dispatched_fn_["default"]);
412  }
413 
414  if (!result) {
415  throw ErrorReport(loc) << "value for boolean dispatch was not constant";
416  }
417 
418  std::shared_ptr<SugaredValue> value;
419  if (*result) {
420  value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
421  } else {
422  value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
423  }
424  return value->call(loc, caller, inputs, attributes, n_binders);
425  }
426 
427  private:
428  py::dict dispatched_fn_;
429 };
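// [Editor's sketch, not part of the original source] Boolean dispatch is
// driven by a dict built on the Python side (torch.jit's boolean-dispatch
// helper) with keys "index", "arg_name", "default", "if_true", and "if_false",
// the keys read above. The compiler inspects the flag argument, which must be
// a compile-time constant, and emits a call to either the if_true or the
// if_false variant; the torch.nn.functional max_pool ops that change their
// return type based on return_indices are a typical user of this mechanism.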
430 
431 struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
432  OverloadedFunctionValue(py::list functions)
433  : possible_functions_(std::move(functions)) {}
434 
435  std::string kind() const override {
436  return "overloaded function";
437  }
438 
439  std::shared_ptr<SugaredValue> call(
440  const SourceRange& loc,
441  Method& caller,
442  at::ArrayRef<NamedValue> inputs,
443  at::ArrayRef<NamedValue> attributes,
444  size_t n_binders) override {
445  std::stringstream err;
446  auto possible_functions =
447  py::cast<std::vector<py::object>>(possible_functions_);
448 
449  for (const py::object& fn : possible_functions) {
450  auto& method = py::cast<Method&>(fn);
451  auto match = tryMatchSchema(
452  method.getSchema(),
453  loc,
454  *caller.graph().get(),
455  c10::nullopt,
456  inputs,
457  attributes,
458  err,
459  true);
460  if (match) {
461  return MethodValue(nullptr, method)
462  .call(loc, caller, inputs, attributes, n_binders);
463  }
464  }
465  throw ErrorReport(loc) << "Could not find any matching overloads\n"
466  << err.str();
467  }
468 
469  private:
470  py::list possible_functions_;
471 };
472 
473 std::shared_ptr<SugaredValue> toSugaredValue(
474  py::object obj,
475  Method& m,
476  SourceRange loc,
477  bool is_constant,
478  bool is_submodule) {
479  // directly create SimpleValues when possible, because they are first-class
480  // and can be re-assigned. Otherwise, this would be invalid:
481  // f = python_constant
482  // while ...
483  // f = f + 1
484  auto& g = *m.graph();
485  if (is_constant) {
486  if (py::isinstance<py::bool_>(obj)) {
487  return toSimple(g.insertConstant(py::cast<bool>(obj), nullptr, loc));
488  } else if (py::isinstance<py::int_>(obj)) {
489  return toSimple(g.insertConstant(py::cast<int64_t>(obj), nullptr, loc));
490  } else if (py::isinstance<py::float_>(obj)) {
491  return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
492  } else if (py::isinstance<py::str>(obj)) {
493  return toSimple(
494  g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
495  } else if (obj.is(py::none())) {
496  return toSimple(g.insertConstant(IValue(), nullptr, loc));
497  } else if (THPDevice_Check(obj.ptr())) {
498  auto device = reinterpret_cast<THPDevice*>(obj.ptr());
499  return toSimple(g.insertConstant(device->device));
500  } else if (THPLayout_Check(obj.ptr())) {
501  auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
502  const auto v = static_cast<int64_t>(layout->layout);
503  return toSimple(g.insertConstant(v, nullptr, loc));
504  } else if (THPDtype_Check(obj.ptr())) {
505  auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
506  const auto v = static_cast<int64_t>(dtype->scalar_type);
507  return toSimple(g.insertConstant(v, nullptr, loc));
508  } else if (py::isinstance<py::tuple>(obj)) {
509  return std::make_shared<ConstantPythonTupleValue>(obj);
510  }
511  }
512 
513  auto weak_obj =
514  py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
515  if (!weak_obj.is_none()) {
516  obj = weak_obj;
517  }
518  if (py::isinstance<Module>(obj)) {
519  auto mod = py::cast<std::shared_ptr<Module>>(obj);
520  // In the case that this Python object is not a submodule, inline *ONLY
521  // PURE* ScriptModules. This allows us to call arbitrary @script functions
522  // within a scripting context while still enforcing that parameters from
523  // stateful submodules are properly accounted for.
524  if (!is_submodule && mod->get_parameters().size() != 0) {
525  throw ErrorReport()
526  << "Attempted to inline a Module with parameters. "
527  "Stateful modules to be inlined must be submodules of the callee.";
528  }
529  const auto script_class_type =
530  py::module::import("torch.jit").attr("ScriptClass");
531  const bool is_class_type = py::isinstance(obj, script_class_type);
532  if (is_class_type) {
533  const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
534  auto classType = ClassType::get(classname);
535  AT_ASSERT(classType);
536  return std::make_shared<ClassValue>(std::move(classType));
537  }
538  return std::make_shared<ModuleValue>(mod);
539  } else if (py::isinstance<py::module>(obj)) {
540  return std::make_shared<PythonModuleValue>(obj);
541  } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
542  return std::make_shared<ForkValue>();
543  } else if (
544  obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
545  return std::make_shared<AnnotateValue>();
546  }
547 
548  py::object builtin_name =
549  py::module::import("torch.jit").attr("_find_builtin")(obj);
550  if (!builtin_name.is_none()) {
551  return std::make_shared<BuiltinFunction>(
552  Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
553  }
554 
555  if (py::isinstance<py::function>(obj)) {
556  auto compiled_fn =
557  py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
558  if (!compiled_fn.is(py::none())) {
559  auto mod = py::cast<std::shared_ptr<Module>>(compiled_fn);
560  return std::make_shared<ModuleValue>(mod);
561  }
562  }
563 
564  py::object dispatched_fn =
565  py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
566  if (!dispatched_fn.is_none()) {
567  return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
568  }
569 
570  py::object overloads =
571  py::module::import("torch.jit").attr("_try_get_overloaded_fn")(obj);
572  if (!overloads.is_none()) {
573  return std::make_shared<OverloadedFunctionValue>(std::move(overloads));
574  }
575 
576  return std::make_shared<PythonValue>(obj);
577 }
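// [Editor's summary, not part of the original source] The dispatch order in
// toSugaredValue above is: literal constants (bool/int/float/str/None/device/
// layout/dtype/tuple, only when is_constant), weak-module unwrapping, script
// Modules and ScriptClasses, Python modules, the torch.jit._fork and annotate
// specials, registered builtins, weak-script functions compiled on demand,
// boolean dispatch, overload sets, and finally the generic PythonValue
// fallback.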
578 
579 py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
580  // If we don't tell pybind11 these are Variables, it chokes on the
581  // conversion.
582  // TODO: fix conversions to be sane and make sure this works.
583  if (outputs.size() == 0) {
584  return py::none();
585  } else if (outputs.size() == 1) {
586  return py::cast(autograd::as_variable_ref(outputs[0]));
587  } else {
588  py::tuple tuple(outputs.size());
589  for (size_t i = 0; i < outputs.size(); i++) {
590  tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
591  }
592  return std::move(tuple);
593  }
594 }
595 
596 static void gatherParametersAndBuffers(
597  std::vector<IValue*>& values,
598  const Module& m) {
599  for (auto& param : m.get_parameters()) {
600  values.push_back(param->slot());
601  }
602  for (auto& param : m.get_attributes()) {
603  if (param->type->isSubtypeOf(TensorType::get())) {
604  values.push_back(param->slot());
605  }
606  }
607  for (const auto& sub : m.get_modules()) {
608  gatherParametersAndBuffers(values, *sub->module);
609  }
610 }
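// [Editor's note, not part of the original source] This helper flattens a
// module hierarchy into the list of IValue slots that tracing needs: every
// parameter, every Tensor-typed attribute (i.e. buffers), and the same for
// each submodule, recursively. _create_method_from_trace below appends these
// slots to the traced inputs and passes them as the initial values of the new
// method, so the traced graph closes over the module's state.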
611 
612 namespace {
613 
614 Resolver pythonResolver(const ResolutionCallback& rcb) {
615  return [rcb](const std::string& name, Method& m, const SourceRange& loc)
616  -> std::shared_ptr<SugaredValue> {
617  AutoGIL ag;
618  py::object obj = rcb(name);
619  if (obj.is(py::none())) {
620  return nullptr;
621  }
622  return toSugaredValue(obj, m, loc);
623  };
624 }
625 } // namespace
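// [Editor's note, not part of the original source] A ResolutionCallback is a
// Python callable supplied by torch.jit that maps a free identifier in the
// script source (say "torch" or the name of a helper function) to the object
// it refers to in the surrounding Python scope, or None if it is not defined.
// pythonResolver forwards that object through toSugaredValue while holding the
// GIL, so names the callback cannot resolve are reported by the compiler
// rather than failing here.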
626 
627 FunctionSchema getSchemaWithNameAndDefaults(
628  const SourceRange& range,
629  const FunctionSchema& schema,
630  const at::optional<std::string>& new_name,
631  const FunctionDefaults& default_args) {
632  std::vector<Argument> new_args;
633  for (auto& arg : schema.arguments()) {
634  auto it = default_args.find(arg.name());
635  if (it != default_args.end()) {
636  try {
637  IValue value;
638  auto n = arg.N();
639  auto list_type = arg.type()->cast<ListType>();
640  if (n && *n > 0 && list_type) {
641  // BroadcastingList, allow default values T for arg types List[T]
642  value = toIValue(it->second, list_type->getElementType());
643  } else {
644  value = toIValue(it->second, arg.type());
645  }
646  new_args.emplace_back(
647  arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
648  } catch (py::cast_error& e) {
649  throw ErrorReport(range)
650  << "Expected a default value of type " << arg.type()->str()
651  << " on parameter \"" << arg.name() << "\"";
652  }
653  } else {
654  new_args.push_back(arg);
655  }
656  }
657 
658  return FunctionSchema(
659  new_name.value_or(schema.name()),
660  schema.overload_name(),
661  new_args,
662  schema.returns(),
663  schema.is_vararg(),
664  schema.is_varret());
665 }
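// [Editor's sketch, not part of the original source] This is where Python
// default argument values get folded into the compiled schema. Roughly, for a
// hypothetical scripted function
//   def pad(x, value=0.0):
// the FunctionDefaults map {"value": 0.0} is converted with toIValue against
// the declared argument type and attached to the "value" Argument. For
// BroadcastingList-style parameters, a scalar default of type T is accepted
// for an argument of type List[T], as handled above.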
666 
667 void initJitScriptBindings(PyObject* module) {
668  auto m = py::handle(module).cast<py::module>();
669 
670  // STL containers are not mutable by default and hence we need to bind as
671  // follows.
672  py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");
673 
674  // torch.jit.ScriptModule is a subclass of this C++ object.
675  // Methods here are prefixed with _ since they should not be
676  // public.
677  py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
678  .def(py::init<>())
679  .def(
680  "save",
681  [](std::shared_ptr<Module> m,
682  const std::string& filename,
683  const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
684  m->save(filename, _extra_files);
685  },
686  py::arg("filename"),
687  py::arg("_extra_files") = ExtraFilesMap())
688  .def(
689  "save_to_buffer",
690  [](std::shared_ptr<Module> m,
691  const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
692  std::ostringstream buf;
693  m->save(buf, _extra_files);
694  return py::bytes(buf.str());
695  },
696  py::arg("_extra_files") = ExtraFilesMap())
697  .def("_set_optimized", &Module::set_optimized)
698  .def(
699  "_define",
700  [](std::shared_ptr<Module> m,
701  const std::string& script,
702  ResolutionCallback rcb,
703  bool has_self) {
704  c10::optional<Self> self;
705  if (has_self) {
706  self = Self(std::make_shared<ModuleValue>(m));
707  }
708  defineMethodsInModule(m, script, pythonResolver(rcb), self);
709  })
710  .def(
711  "_create_methods",
712  [](std::shared_ptr<Module> m,
713  const std::vector<Def>& defs,
714  const std::vector<ResolutionCallback>& rcbs,
715  const std::vector<FunctionDefaults>& defaults) {
716  std::vector<Resolver> resolvers;
717  resolvers.reserve(rcbs.size());
718  for (auto& callback : rcbs) {
719  resolvers.push_back(pythonResolver(callback));
720  }
721  defineMethodsInModule(
722  m, defs, resolvers, Self(std::make_shared<ModuleValue>(m)));
723 
724  // Stitch in default arguments for each Def if provided
725  auto defaults_it = defaults.begin();
726  auto defs_it = defs.begin();
727  while (defs_it != defs.end()) {
728  auto& method = m->get_method((*defs_it).name().name());
729  method.setSchema(getSchemaWithNameAndDefaults(
730  defs_it->range(),
731  method.getSchema(),
732  at::nullopt,
733  *defaults_it));
734  ++defs_it;
735  ++defaults_it;
736  }
737  didFinishEmitModule(m);
738  })
739  .def(
740  "_get_method",
741  [](Module& self, const std::string& name) -> const Method& {
742  return self.get_method(name);
743  },
744  py::return_value_policy::reference_internal)
745  .def("_register_parameter", &Module::register_parameter)
746  .def(
747  "_register_attribute",
748  [](Module& self, std::string name, TypePtr type, py::object value) {
749  self.register_attribute(name, type, toIValue(value, type));
750  })
751  .def("_register_module", &Module::register_module)
752  .def("_register_buffer", &Module::register_buffer)
753  .def("_set_parameter", &Module::set_parameter)
754  .def("_get_parameter", &Module::get_parameter)
755  .def("_get_buffer", &Module::get_buffer)
756  .def("_get_module", &Module::get_module)
757  .def(
758  "_get_modules",
759  [](Module& self) -> py::tuple {
760  auto& modules = self.get_modules();
761  py::tuple result(modules.size());
762  for (size_t i = 0; i < modules.size(); ++i) {
763  auto& item = modules[i];
764  result[i] = std::make_pair(item.key(), item.value().module);
765  }
766  return result;
767  })
768  .def(
769  "_get_parameters",
770  [](Module& self) -> py::tuple {
771  auto& parameters = self.get_parameters();
772  py::tuple result(parameters.size());
773  for (size_t i = 0; i < parameters.size(); ++i) {
774  auto& p = parameters[i];
775  py::tuple r(2);
776  result[i] = std::make_tuple(
777  p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
778  }
779  return result;
780  })
781  .def(
782  "_get_attributes",
783  [](Module& self) -> py::tuple {
784  auto& attributes = self.get_attributes();
785  py::tuple result(attributes.size());
786  for (size_t i = 0; i < attributes.size(); ++i) {
787  auto& buffer = attributes[i];
788  py::tuple r(3);
789  IValue v = *buffer->slot();
790  result[i] = std::make_tuple(
791  buffer.key(), buffer->type, toPyObject(std::move(v)));
792  }
793  return result;
794  })
795  .def(
796  "_has_parameter",
797  [](Module& self, const std::string& name) -> bool {
798  return self.find_parameter(name);
799  })
800  .def(
801  "_has_buffer",
802  [](Module& self, const std::string& name) -> bool {
803  return self.find_buffer(name);
804  })
805  .def(
806  "_has_module",
807  [](Module& self, const std::string& name) {
808  return bool(self.find_module(name));
809  })
810  .def(
811  "_has_method",
812  [](Module& self, const std::string& name) {
813  return bool(self.find_method(name));
814  })
815  .def(
816  "_method_names",
817  [](Module& self) {
818  using Item =
820  return fmap(self.get_methods(), [](const Item& item) {
821  return (*item)->name();
822  });
823  })
824  .def(
825  "_create_method_from_graph",
826  [](Module& self,
827  const std::string& name,
828  std::shared_ptr<Graph> graph) {
829  self.create_method(name, std::move(graph), {});
830  })
831  .def(
832  "_create_method_from_trace",
833  [](std::shared_ptr<Module> self,
834  const std::string& name,
835  py::function func,
836  py::tuple input_tuple,
837  py::function var_lookup_fn,
838  bool force_outplace) {
839  // prereq: Module's buffers and parameters are unique
840  // this was ensured in python before calling this function
841  std::vector<IValue*> parameters;
842  gatherParametersAndBuffers(parameters, *self);
843  Stack inputs = toStack(input_tuple);
844  for (IValue* param : parameters) {
845  inputs.emplace_back(*param);
846  }
847  auto graph = tracer::createGraphByTracing(
848  func,
849  inputs,
850  var_lookup_fn,
851  force_outplace,
852  input_tuple.size());
853  self->create_method(name, std::move(graph), std::move(parameters));
854  didFinishEmitModule(self);
855  })
856  .def(
857  "graph_for",
858  [](py::args args, py::kwargs kwargs) {
859  // [pybind11 varargs] note: old versions of pybind11 have a bug that
860  // leaks memory when py::args is mixed with positional arguments
861  // https://github.com/pybind/pybind11/pull/1216
862  // we work around this by not mixing positional arguments with
863  // varargs
864  Module& self = py::cast<Module&>(args[0]);
865  if (self.find_method("forward")) {
866  Method& m = self.get_method("forward");
867  return m.graph_for(createStackForSchema(
868  m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
869  }
870  throw std::runtime_error(
871  "Attempted to call graph_for on a Module without a compiled forward()");
872  })
873  .def(
874  "get_debug_state",
875  [](Module& self) {
876  if (self.find_method("forward")) {
877  Method& m = self.get_method("forward");
878  return m.getDebugState();
879  }
880  throw std::runtime_error(
881  "Attempted to call get_debug_state on a Module without a compiled forward()");
882  })
883  .def(
884  "debug_disable_autodiff_subgraph_inlining",
885  [](Module& self) {
886  if (self.find_method("forward")) {
887  Method& m = self.get_method("forward");
888  m.debugDisableAutodiffSubgraphInlining();
889  }
890  })
891  .def(
892  "forward",
893  [](py::args args, py::kwargs kwargs) {
894  // We implement this in C++ to avoid incurring the pybind11 dispatch
895  // overhead twice: once to call into the method lookup for "forward"
896  // and once to actually invoke the method.
897  //
898  // There is a thin wrapper on top of this method in the C++ version
899  // of ScriptModule.
900 
901  // see: [pybind11 varargs]
902  Module& self = py::cast<Module&>(args[0]);
903  return invokeScriptMethodFromPython(
904  self.get_method("forward"),
905  tuple_slice(std::move(args), 1),
906  std::move(kwargs));
907  })
908  .def(
909  "_python_print",
910  [](Module& self) {
911  std::ostringstream ss;
912  std::vector<at::Tensor> tensors;
913  std::vector<ClassTypePtr> classes;
914  PythonPrint(ss, self, tensors, classes, true);
915  return std::make_pair(ss.str(), tensors);
916  })
917  .def_property_readonly(
918  "code",
919  [](Module& self) {
920  std::ostringstream ss;
921  std::vector<at::Tensor> tensors;
922  std::vector<ClassTypePtr> classes;
923  PythonPrint(ss, self, tensors, classes, false);
924  return ss.str();
925  })
926  .def("apply", &Module::apply)
927  .def("_copy_into", &Module::copy_into)
928  .def(
929  "_copy_method",
930  [](std::shared_ptr<Module> m,
931  std::string name,
932  std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
933  params,
934  std::shared_ptr<Module> orig) {
935  std::vector<IValue*> member_inputs;
936  for (auto& p : params) {
937  NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
938  if (np == nullptr) {
939  np = std::get<0>(p)->find_buffer(std::get<1>(p));
940  }
941  AT_ASSERT(np != nullptr);
942  member_inputs.push_back(np->slot());
943  }
944 
945  Method* orig_method = orig->find_method(name);
946  m->create_method(name, orig_method->graph()->copy(), member_inputs);
947  });
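// [Editor's note, not part of the original source] Every binding above that
// takes (py::args args, py::kwargs kwargs) -- graph_for and forward here, and
// __call__ / graph_for on ScriptMethod below -- follows the [pybind11 varargs]
// workaround: the bound object is recovered from args[0] instead of being a
// separate positional parameter, to avoid the memory leak in older pybind11
// versions referenced in the comment above.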
948 
949  py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
950  .def("graph", [&](Method& self) { return self.graph(); })
951  .def(
952  "__call__",
953  [](py::args args, py::kwargs kwargs) {
954  // see: [pybind11 varargs]
955  Method& method = py::cast<Method&>(args[0]);
956  return invokeScriptMethodFromPython(
957  method, tuple_slice(std::move(args), 1), std::move(kwargs));
958  })
959  .def_property_readonly("graph", [](Method& m) { return m.graph(); })
960  .def("propagate_shapes", &Method::propagate_shapes)
961  .def(
962  "propagate_and_assign_input_and_output_shapes",
963  &Method::propagate_and_assign_input_and_output_shapes)
964  .def("initial_ivalues", &Method::initial_ivalues)
965  .def(
966  "graph_for",
967  [](py::args args, py::kwargs kwargs) {
968  // see: [pybind11 varargs]
969  Method& self = py::cast<Method&>(args[0]);
970  return self.graph_for(createStackForSchema(
971  self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
972  })
973  .def(
974  "debug_disable_autodiff_subgraph_inlining",
975  &Method::debugDisableAutodiffSubgraphInlining)
976  .def("schema", &Method::getSchema)
977  .def("pretty_print_schema", &Method::pretty_print_schema)
978  .def("python_print", [](Method& m) {
979  std::ostringstream oss;
980  std::vector<at::Tensor> constants;
981  std::vector<ClassTypePtr> classes;
982  PythonPrint(oss, m, constants, classes, true);
983  return std::make_pair(oss.str(), std::move(constants));
984  });
985 
986  m.def(
987  "_jit_script_compile",
988  [](std::shared_ptr<Module> mod,
989  const Def& def,
990  ResolutionCallback rcb,
991  FunctionDefaults defaults) {
992  auto def_f = def.withName("forward");
993  defineMethodsInModule(
994  mod, {def_f}, {pythonResolver(rcb)}, c10::nullopt);
995  auto& method = mod->get_method("forward");
996  method.setSchema(getSchemaWithNameAndDefaults(
997  def.range(), method.getSchema(), def.name().name(), defaults));
998  didFinishEmitModule(mod);
999  return mod;
1000  });
1001 
1002  m.def(
1003  "_jit_script_class_compile",
1004  [](std::shared_ptr<Module> module,
1005  const ClassDef& classDef,
1006  ResolutionCallback rcb) {
1007  auto classType = ClassType::create(classDef.name().name(), module);
1008  std::vector<Resolver> rcbs;
1009  std::vector<Def> methodDefs;
1010  for (const auto& def : classDef.defs()) {
1011  methodDefs.push_back(def);
1012  rcbs.push_back(pythonResolver(rcb));
1013  }
1014  defineMethodsInModule(module, methodDefs, rcbs, Self(classType));
1015  return module;
1016  });
1017 
1018  m.def("parse_type_comment", [](const std::string& comment) {
1019  Parser p(comment);
1020  return Decl(p.parseTypeComment());
1021  });
1022 
1023  m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
1024  m.def(
1025  "import_ir_module",
1026  [](ModuleLookup module_lookup,
1027  const std::string& filename,
1028  py::object map_location,
1029  ExtraFilesMap& extra_files) {
1030  c10::optional<at::Device> optional_device;
1031  if (!map_location.is(py::none())) {
1032  AT_ASSERT(THPDevice_Check(map_location.ptr()));
1033  optional_device =
1034  reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1035  }
1036  import_ir_module(module_lookup, filename, optional_device, extra_files);
1037  });
1038  m.def(
1039  "import_ir_module_from_buffer",
1040  [](ModuleLookup module_lookup,
1041  const std::string& buffer,
1042  py::object map_location,
1043  ExtraFilesMap& extra_files) {
1044  std::istringstream in(buffer);
1045  c10::optional<at::Device> optional_device;
1046  if (!map_location.is(py::none())) {
1047  AT_ASSERT(THPDevice_Check(map_location.ptr()));
1048  optional_device =
1049  reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1050  }
1051  import_ir_module(module_lookup, in, optional_device, extra_files);
1052  });
1053  m.def("_jit_import_methods", import_methods);
1054  m.def("_jit_set_emit_module_hook", setEmitModuleHook);
1055  m.def("_jit_clear_class_registry", ClassType::clearRegistry);
1056 
1057  py::class_<testing::FileCheck>(m, "FileCheck")
1058  .def(py::init<>())
1059  .def("check", &testing::FileCheck::check)
1060  .def("check_not", &testing::FileCheck::check_not)
1061  .def("check_same", &testing::FileCheck::check_same)
1062  .def("check_next", &testing::FileCheck::check_next)
1063  .def("check_count", &testing::FileCheck::check_count)
1064  .def("check_dag", &testing::FileCheck::check_dag)
1065  .def("check_count", &testing::FileCheck::check_count)
1066  .def(
1067  "check_count",
1068  [](testing::FileCheck& f,
1069  const std::string& str,
1070  size_t count,
1071  bool exactly) { return f.check_count(str, count, exactly); },
1072  "Check Count",
1073  py::arg("str"),
1074  py::arg("count"),
1075  py::arg("exactly") = false)
1076  .def(
1077  "run",
1078  [](testing::FileCheck& f, const std::string& str) {
1079  return f.run(str);
1080  })
1081  .def("run", [](testing::FileCheck& f, const Graph& g) {
1082  return f.run(g);
1083  });
1084 }
1085 } // namespace script
1086 } // namespace jit
1087 } // namespace torch