1 #include <torch/csrc/python_headers.h> 2 #include <system_error> 4 #include <torch/csrc/THP.h> 5 #include <torch/csrc/serialization.h> 8 ssize_t doPartialRead(io fildes,
void* buf,
size_t nbytes);
11 ssize_t doPartialWrite(io fildes,
void* buf,
size_t nbytes);
13 static ssize_t doPartialPythonReadBuffered(PyObject* fildes,
void* buf,
size_t nbytes);
14 static ssize_t doPartialPythonReadInto(PyObject* fildes,
void* buf,
size_t nbytes);
15 static ssize_t doPartialPythonWrite(PyObject* fildes,
void* buf,
size_t nbytes);
18 ssize_t doPartialRead<int>(
int fildes,
void* buf,
size_t nbytes) {
19 return read(fildes, buf, nbytes);
23 ssize_t doPartialRead<PyObject*>(PyObject* fildes,
void* buf,
size_t nbytes) {
27 auto has_readinto = PyObject_HasAttrString(fildes,
"readinto") == 1;
29 return doPartialPythonReadInto(fildes, buf, nbytes);
31 return doPartialPythonReadBuffered(fildes, buf, nbytes);
35 ssize_t doPartialWrite<int>(
int fildes,
void* buf,
size_t nbytes) {
36 return write(fildes, buf, nbytes);
40 ssize_t doPartialWrite<PyObject*>(PyObject* fildes,
void* buf,
size_t nbytes) {
41 return doPartialPythonWrite(fildes, buf, nbytes);
44 static inline bool isUnsupportedOperation() {
47 THPObjectPtr exception(PyObject_GetAttrString(io,
"UnsupportedOperation"));
49 return PyErr_ExceptionMatches(exception.get());
53 static inline ssize_t doPartialPythonReadBuffered(PyObject* fildes,
void* buf,
size_t raw_nbytes) {
59 const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u);
61 THPObjectPtr r(PyObject_CallMethod(fildes,
"read",
"i", nbytes));
65 #if PY_MAJOR_VERSION >= 3 66 auto size = PyBytes_GET_SIZE(r.get());
67 const void* py_buf = PyBytes_AsString(r.get());
69 auto size = PyString_GET_SIZE(r.get());
70 const void* py_buf = PyString_AsString(r.get());
79 memcpy(buf, py_buf, size);
85 static inline ssize_t doPartialPythonIO(PyObject* fildes,
void* buf,
size_t nbytes,
bool is_read) {
86 #if PY_MAJOR_VERSION >= 3 87 auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
89 reinterpret_cast<char*>(buf), nbytes, rw_flag));
91 THPObjectPtr memview(PyBuffer_FromReadWriteMemory(buf, nbytes));
95 char* method =
"write";
99 THPObjectPtr r(PyObject_CallMethod(fildes, method,
"O", memview.get()));
101 return PyLong_AsSsize_t(r.get());
105 if (is_read && isUnsupportedOperation()) {
107 return doPartialPythonReadBuffered(fildes, buf, nbytes);
113 static ssize_t doPartialPythonReadInto(PyObject* fildes,
void* buf,
size_t nbytes) {
114 return doPartialPythonIO(fildes, buf, nbytes,
true);
118 static ssize_t doPartialPythonWrite(PyObject* fildes,
void* buf,
size_t nbytes) {
119 return doPartialPythonIO(fildes, buf, nbytes,
false);
123 template <
typename io>
124 void doRead(io fildes,
void* raw_buf,
size_t nbytes) {
125 char* buf =
static_cast<char*
>(raw_buf);
130 ssize_t r = doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
133 AT_ASSERTM(err != 0,
"read(): impossible! r < 0, but no errno was set");
134 AT_ASSERTM(err != EAGAIN,
"read(): non-blocking fd ", fildes,
135 " read EAGAIN; cowardly refusing to spin-wait");
139 AT_ERROR(
"read(): fd ", fildes,
" failed with ", strerror(err));
147 AT_ASSERT(static_cast<size_t>(r) <= nbytes);
151 AT_ERROR(
"unexpected EOF, expected ", nbytes,
" more bytes. The file might be corrupted.");
155 template <
typename io>
156 void doWrite(io fildes,
void* raw_buf,
size_t nbytes) {
157 char* buf =
static_cast<char*
>(raw_buf);
162 ssize_t r = doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
165 AT_ASSERTM(err != 0,
"write(): impossible! r < 0, but no errno was set");
166 AT_ASSERTM(err != EAGAIN,
"write(): non-blocking fd ", fildes,
167 " read EAGAIN; cowardly refusing to spin-wait");
171 AT_ERROR(
"write(): fd ", fildes,
" failed with ", strerror(err));
175 AT_ASSERT(static_cast<size_t>(r) <= nbytes);
180 #include <torch/csrc/generic/serialization.cpp> 181 #include <TH/THGenerateAllTypes.h> 183 #include <torch/csrc/generic/serialization.cpp> 184 #include <TH/THGenerateHalfType.h> 186 #include <torch/csrc/generic/serialization.cpp> 187 #include <TH/THGenerateBoolType.h>