4 #include <torch/csrc/python_headers.h> 6 #include <torch/csrc/utils/python_numbers.h> 7 #include <torch/csrc/Exceptions.h> 9 namespace torch {
namespace utils {
11 inline void store_scalar(
void* data, at::ScalarType scalarType, PyObject* obj) {
13 case at::kByte: *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj);
break;
14 case at::kChar: *(
char*)data = (
char)THPUtils_unpackLong(obj);
break;
15 case at::kShort: *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj);
break;
16 case at::kInt: *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj);
break;
17 case at::kLong: *(int64_t*)data = THPUtils_unpackLong(obj);
break;
19 *(
at::Half*)data = at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
21 case at::kFloat: *(
float*)data = (
float)THPUtils_unpackDouble(obj);
break;
22 case at::kDouble: *(
double*)data = THPUtils_unpackDouble(obj);
break;
23 case at::kComplexFloat: *(std::complex<float>*)data = (std::complex<float>)THPUtils_unpackComplexDouble(obj);
break;
24 case at::kComplexDouble: *(std::complex<double>*)data = THPUtils_unpackComplexDouble(obj);
break;
25 case at::kBool: *(
bool*)data = (
bool)THPUtils_unpackLong(obj);
break;
26 default:
throw std::runtime_error(
"invalid type");
30 inline PyObject* load_scalar(
void* data, at::ScalarType scalarType) {
32 case at::kByte:
return THPUtils_packInt64(*(uint8_t*)data);
33 case at::kChar:
return THPUtils_packInt64(*(
char*)data);
34 case at::kShort:
return THPUtils_packInt64(*(int16_t*)data);
35 case at::kInt:
return THPUtils_packInt64(*(int32_t*)data);
36 case at::kLong:
return THPUtils_packInt64(*(int64_t*)data);
37 case at::kHalf:
return PyFloat_FromDouble(at::convert<double, at::Half>(*(
at::Half*)data));
38 case at::kFloat:
return PyFloat_FromDouble(*(
float*)data);
39 case at::kDouble:
return PyFloat_FromDouble(*(
double*)data);
40 case at::kComplexFloat:
return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<float>*)data));
41 case at::kComplexDouble:
return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<double>*)data));
42 case at::kBool:
return PyBool_FromLong(*(uint8_t*)data);
43 default:
throw std::runtime_error(
"invalid type");