Caffe2 - C++ API
A deep learning, cross platform ML framework
utils.cpp
1 #include <torch/csrc/python_headers.h>
2 #include <stdarg.h>
3 #include <string>
4 #include <torch/csrc/cuda/THCP.h>
5 
6 #include <torch/csrc/cuda/override_macros.h>
7 
8 #define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
9 #include <THC/THCGenerateAllTypes.h>
10 
11 #define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
12 #include <THC/THCGenerateBoolType.h>
13 
14 #ifdef USE_CUDA
15 // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
16 // whatever the current stream of the device the input is associated with was.
17 std::vector<c10::optional<at::cuda::CUDAStream>> THPUtils_PySequence_to_CUDAStreamList(PyObject *obj) {
18  if (!PySequence_Check(obj)) {
19  throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
20  }
21  THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
22  if (seq.get() == nullptr) {
23  throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)));
24  }
25 
26  std::vector<c10::optional<at::cuda::CUDAStream>> streams;
27  Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
28  for (Py_ssize_t i = 0; i < length; i++) {
29  PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i);
30 
31  if (PyObject_IsInstance(stream, THCPStreamClass)) {
32  // Spicy hot reinterpret cast!!
33  streams.emplace_back( at::cuda::CUDAStream::unpack((reinterpret_cast<THCPStream*>(stream))->cdata) );
34  } else if (stream == Py_None) {
35  streams.emplace_back();
36  } else {
37  std::runtime_error("Unknown data type found in stream list. Need torch.cuda.Stream or None");
38  }
39  }
40  return streams;
41 }
42 
43 #endif