3 #include <torch/csrc/python_headers.h> 6 #include <pybind11/pybind11.h> 7 #include <pybind11/stl.h> 9 #include <torch/csrc/DynamicTypes.h> 10 #include <torch/csrc/autograd/python_variable.h> 11 #include <torch/csrc/utils/python_tuples.h> 12 #include <torch/csrc/utils/python_numbers.h> 25 PYBIND11_TYPE_CASTER(
at::Tensor, _(
"at::Tensor"));
27 bool load(handle src,
bool) {
28 PyObject* obj = src.ptr();
29 if (THPVariable_Check(obj)) {
37 cast(
const at::Tensor& src, return_value_policy , handle ) {
39 throw std::runtime_error(
40 "Expected tensor's dynamic type to be Variable, not Tensor");
46 template<>
struct type_caster<
torch::autograd::Variable> {
49 bool load(handle src,
bool) {
50 PyObject *source = src.ptr();
51 if (THPVariable_Check(source)) {
59 return handle(THPVariable_Wrap(std::move(src)));
63 template<>
struct type_caster<
at::IntArrayRef> {
67 bool load(handle src,
bool) {
68 PyObject *source = src.ptr();
69 auto tuple = PyTuple_Check(source);
70 if (tuple || PyList_Check(source)) {
71 auto size = tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
73 for (
int idx = 0; idx < size; idx++) {
74 PyObject* obj = tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
75 if (THPVariable_Check(obj)) {
76 v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
77 }
else if (PyLong_Check(obj)) {
79 v_value[idx] = THPUtils_unpackLong(obj);
89 static handle cast(
at::IntArrayRef src, return_value_policy , handle ) {
90 return handle(THPUtils_packInt64Array(src.
size(), src.data()));
93 std::vector<int64_t> v_value;
99 struct type_caster<
c10::optional<T>> : optional_caster<c10::optional<T>> {};
constexpr size_t size() const
size - Get the array size.
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Flush-To-Zero and Denormals-Are-Zero mode.