Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_scalars.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/python_headers.h>
5 
6 #include <torch/csrc/utils/python_numbers.h>
7 #include <torch/csrc/Exceptions.h>
8 
9 namespace torch { namespace utils {
10 
11 inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
12  switch (scalarType) {
13  case at::kByte: *(uint8_t*)data = (uint8_t)THPUtils_unpackLong(obj); break;
14  case at::kChar: *(char*)data = (char)THPUtils_unpackLong(obj); break;
15  case at::kShort: *(int16_t*)data = (int16_t)THPUtils_unpackLong(obj); break;
16  case at::kInt: *(int32_t*)data = (int32_t)THPUtils_unpackLong(obj); break;
17  case at::kLong: *(int64_t*)data = THPUtils_unpackLong(obj); break;
18  case at::kHalf:
19  *(at::Half*)data = at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
20  break;
21  case at::kFloat: *(float*)data = (float)THPUtils_unpackDouble(obj); break;
22  case at::kDouble: *(double*)data = THPUtils_unpackDouble(obj); break;
23  case at::kComplexFloat: *(std::complex<float>*)data = (std::complex<float>)THPUtils_unpackComplexDouble(obj); break;
24  case at::kComplexDouble: *(std::complex<double>*)data = THPUtils_unpackComplexDouble(obj); break;
25  case at::kBool: *(bool*)data = (bool)THPUtils_unpackLong(obj); break;
26  default: throw std::runtime_error("invalid type");
27  }
28 }
29 
30 inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
31  switch (scalarType) {
32  case at::kByte: return THPUtils_packInt64(*(uint8_t*)data);
33  case at::kChar: return THPUtils_packInt64(*(char*)data);
34  case at::kShort: return THPUtils_packInt64(*(int16_t*)data);
35  case at::kInt: return THPUtils_packInt64(*(int32_t*)data);
36  case at::kLong: return THPUtils_packInt64(*(int64_t*)data);
37  case at::kHalf: return PyFloat_FromDouble(at::convert<double, at::Half>(*(at::Half*)data));
38  case at::kFloat: return PyFloat_FromDouble(*(float*)data);
39  case at::kDouble: return PyFloat_FromDouble(*(double*)data);
40  case at::kComplexFloat: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<float>*)data));
41  case at::kComplexDouble: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<double>*)data));
42  case at::kBool: return PyBool_FromLong(*(uint8_t*)data);
43  default: throw std::runtime_error("invalid type");
44  }
45 }
46 
47 }} // namespace torch::utils
Definition: jit_type.h:17