// Caffe2 - C++ API: a deep learning, cross-platform ML framework.
// File: type_checks.h
1 #pragma once
2 
3 // Defines type checks and unpacking code for the legacy THNN/THCUNN bindings.
4 // These checks accept Tensors and Variables.
5 
6 #include <ATen/ATen.h>
7 
8 #include <torch/csrc/autograd/python_variable.h>
9 
10 namespace torch { namespace nn {
11 
12 inline bool check_type(PyObject* obj, at::TypeID typeID) {
13  if (THPVariable_Check(obj)) {
14  auto& tensor = ((THPVariable*)obj)->cdata;
15  return at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()).ID() == typeID;
16  }
17  return false;
18 }
19 
20 template<typename T>
21 inline T* unpack(PyObject* obj) {
22  return (T*) ((THPVariable*)obj)->cdata.data().unsafeGetTensorImpl();
23 }
24 
25 }} // namespace torch::nn
26 
27 static inline int get_device(PyObject* args) {
28  for (int i = 0, n = PyTuple_GET_SIZE(args); i != n; i++) {
29  PyObject* arg = PyTuple_GET_ITEM(args, i);
30  if (THPVariable_Check(arg)) {
31  auto& tensor = THPVariable_UnpackData(arg);
32  if (tensor.is_cuda()) {
33  return tensor.get_device();
34  }
35  }
36  }
37  return -1;
38 }
39 
40 static inline bool THNN_FloatTensor_Check(PyObject* obj) {
41  return torch::nn::check_type(obj, at::TypeID::CPUFloat);
42 }
43 
44 static inline bool THNN_DoubleTensor_Check(PyObject* obj) {
45  return torch::nn::check_type(obj, at::TypeID::CPUDouble);
46 }
47 
48 static inline bool THNN_LongTensor_Check(PyObject* obj) {
49  return torch::nn::check_type(obj, at::TypeID::CPULong);
50 }
51 
52 static inline bool THNN_IntTensor_Check(PyObject* obj) {
53  return torch::nn::check_type(obj, at::TypeID::CPUInt);
54 }
55 
56 static inline THFloatTensor* THNN_FloatTensor_Unpack(PyObject* obj) {
57  return torch::nn::unpack<THFloatTensor>(obj);
58 }
59 
60 static inline THDoubleTensor* THNN_DoubleTensor_Unpack(PyObject* obj) {
61  return torch::nn::unpack<THDoubleTensor>(obj);
62 }
63 
64 static inline THLongTensor* THNN_LongTensor_Unpack(PyObject* obj) {
65  return torch::nn::unpack<THLongTensor>(obj);
66 }
67 
68 static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) {
69  return torch::nn::unpack<THIntTensor>(obj);
70 }
71 
72 #ifdef USE_CUDA
73 
74 static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) {
75  return torch::nn::check_type(obj, at::TypeID::CUDAHalf);
76 }
77 
78 static inline bool THNN_CudaFloatTensor_Check(PyObject* obj) {
79  return torch::nn::check_type(obj, at::TypeID::CUDAFloat);
80 }
81 
82 static inline bool THNN_CudaDoubleTensor_Check(PyObject* obj) {
83  return torch::nn::check_type(obj, at::TypeID::CUDADouble);
84 }
85 
86 static inline bool THNN_CudaLongTensor_Check(PyObject* obj) {
87  return torch::nn::check_type(obj, at::TypeID::CUDALong);
88 }
89 
90 static inline THCudaHalfTensor* THNN_CudaHalfTensor_Unpack(PyObject* obj) {
91  return torch::nn::unpack<THCudaHalfTensor>(obj);
92 }
93 
94 static inline THCudaTensor* THNN_CudaFloatTensor_Unpack(PyObject* obj) {
95  return torch::nn::unpack<THCudaTensor>(obj);
96 }
97 
98 static inline THCudaDoubleTensor* THNN_CudaDoubleTensor_Unpack(PyObject* obj) {
99  return torch::nn::unpack<THCudaDoubleTensor>(obj);
100 }
101 
102 static inline THCudaLongTensor* THNN_CudaLongTensor_Unpack(PyObject* obj) {
103  return torch::nn::unpack<THCudaLongTensor>(obj);
104 }
105 
106 #endif // USE_CUDA
// (Doc-generator cross-reference residue removed: "Definition: jit_type.h:17")