Caffe2 - C++ API
A deep learning, cross platform ML framework
python_arg_parsing.h
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 #include <ATen/ATen.h>
5 
6 #include <torch/csrc/utils/python_arg_parser.h>
7 
8 namespace torch { namespace autograd { namespace utils {
9 
10 // The parameter allow_copy is to accept copy for Tensor.to (and by proxy
11 // PackedSequences.to) but not nn.Module.to.
12 inline std::tuple<c10::optional<at::Device>, c10::optional<at::ScalarType>, bool, bool>
13  parse_to_conversion(PyObject *args, PyObject *kwargs, bool allow_copy) {
14  static PythonArgParser parser({
15  "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False)",
16  "to(ScalarType dtype, bool non_blocking=False, bool copy=False)",
17  "to(Tensor tensor, bool non_blocking=False, bool copy=False)",
18  });
19  ParsedArgs<4> parsed_args;
20  auto r = parser.parse(args, kwargs, parsed_args);
21  if (r.idx == 0) {
22  if (!allow_copy && !r.isNone(3))
23  throw std::runtime_error(".to() does not accept copy argument");
24  return std::make_tuple(r.deviceOptional(0), r.scalartypeOptional(1), r.toBool(2), r.toBool(3));
25  } else if (r.idx == 1) {
26  if (!allow_copy && !r.isNone(2))
27  throw std::runtime_error(".to() does not accept copy argument");
28  return std::make_tuple(c10::nullopt, r.scalartype(0), r.toBool(1), r.toBool(2));
29  } else {
30  auto tensor = r.tensor(0);
31  if (!allow_copy && !r.isNone(2))
32  throw std::runtime_error(".to() does not accept copy argument");
33  return std::make_tuple(
34  tensor.device(),
35  tensor.scalar_type(),
36  r.toBool(1),
37  r.toBool(2)
38  );
39  }
40 }
41 }}} // namespace torch::autograd::utils
Definition: jit_type.h:17