8 #include <torch/csrc/autograd/python_variable.h> 10 namespace torch {
namespace nn {
12 inline bool check_type(PyObject* obj, at::TypeID typeID) {
13 if (THPVariable_Check(obj)) {
15 return at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()).ID() == typeID;
21 inline T* unpack(PyObject* obj) {
22 return (
T*) ((
THPVariable*)obj)->cdata.data().unsafeGetTensorImpl();
27 static inline int get_device(PyObject* args) {
28 for (
int i = 0, n = PyTuple_GET_SIZE(args); i != n; i++) {
29 PyObject* arg = PyTuple_GET_ITEM(args, i);
30 if (THPVariable_Check(arg)) {
31 auto& tensor = THPVariable_UnpackData(arg);
32 if (tensor.is_cuda()) {
33 return tensor.get_device();
40 static inline bool THNN_FloatTensor_Check(PyObject* obj) {
41 return torch::nn::check_type(obj, at::TypeID::CPUFloat);
44 static inline bool THNN_DoubleTensor_Check(PyObject* obj) {
45 return torch::nn::check_type(obj, at::TypeID::CPUDouble);
48 static inline bool THNN_LongTensor_Check(PyObject* obj) {
49 return torch::nn::check_type(obj, at::TypeID::CPULong);
52 static inline bool THNN_IntTensor_Check(PyObject* obj) {
53 return torch::nn::check_type(obj, at::TypeID::CPUInt);
56 static inline THFloatTensor* THNN_FloatTensor_Unpack(PyObject* obj) {
57 return torch::nn::unpack<THFloatTensor>(obj);
60 static inline THDoubleTensor* THNN_DoubleTensor_Unpack(PyObject* obj) {
61 return torch::nn::unpack<THDoubleTensor>(obj);
64 static inline THLongTensor* THNN_LongTensor_Unpack(PyObject* obj) {
65 return torch::nn::unpack<THLongTensor>(obj);
68 static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) {
69 return torch::nn::unpack<THIntTensor>(obj);
74 static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) {
75 return torch::nn::check_type(obj, at::TypeID::CUDAHalf);
78 static inline bool THNN_CudaFloatTensor_Check(PyObject* obj) {
79 return torch::nn::check_type(obj, at::TypeID::CUDAFloat);
82 static inline bool THNN_CudaDoubleTensor_Check(PyObject* obj) {
83 return torch::nn::check_type(obj, at::TypeID::CUDADouble);
86 static inline bool THNN_CudaLongTensor_Check(PyObject* obj) {
87 return torch::nn::check_type(obj, at::TypeID::CUDALong);
90 static inline THCudaHalfTensor* THNN_CudaHalfTensor_Unpack(PyObject* obj) {
91 return torch::nn::unpack<THCudaHalfTensor>(obj);
94 static inline THCudaTensor* THNN_CudaFloatTensor_Unpack(PyObject* obj) {
95 return torch::nn::unpack<THCudaTensor>(obj);
98 static inline THCudaDoubleTensor* THNN_CudaDoubleTensor_Unpack(PyObject* obj) {
99 return torch::nn::unpack<THCudaDoubleTensor>(obj);
102 static inline THCudaLongTensor* THNN_CudaLongTensor_Unpack(PyObject* obj) {
103 return torch::nn::unpack<THCudaLongTensor>(obj);