Caffe2 - C++ API
A deep learning, cross platform ML framework
serialization.cpp
1 #include <torch/csrc/python_headers.h>
2 #include <system_error>
3 
4 #include <torch/csrc/THP.h>
5 #include <torch/csrc/serialization.h>
6 
7 template <class io>
8 ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);
9 
10 template <class io>
11 ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);
12 
13 static ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
14 static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
15 static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
16 
17 template <>
18 ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
19  return read(fildes, buf, nbytes);
20 }
21 
22 template <>
23 ssize_t doPartialRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
24  // Try to use fildes.readinto() instead of fildes.read()
25  // because it is more memory efficient.
26  // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop
27  auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
28  if (has_readinto) {
29  return doPartialPythonReadInto(fildes, buf, nbytes);
30  }
31  return doPartialPythonReadBuffered(fildes, buf, nbytes);
32 }
33 
34 template <>
35 ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
36  return write(fildes, buf, nbytes);
37 }
38 
39 template <>
40 ssize_t doPartialWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
41  return doPartialPythonWrite(fildes, buf, nbytes);
42 }
43 
44 static inline bool isUnsupportedOperation() {
45  THPObjectPtr io(PyImport_ImportModule("io"));
46  if (!io) throw python_error();
47  THPObjectPtr exception(PyObject_GetAttrString(io, "UnsupportedOperation"));
48  if (!exception) throw python_error();
49  return PyErr_ExceptionMatches(exception.get());
50 }
51 
52 // Call Python fildes.read(nbytes) and copy it to buf.
53 static inline ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t raw_nbytes) {
54  // If we request a large amount of data, f.read() will internally try to
55  // allocate a buffer of that size. This is counterproductive, because
56  // it's not the buffer we ultimately want to write the data into. Read
57  // less than that and avoid allocating too much extra memory.
58  // TODO: Maybe 260 KB is a bit small...
59  const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (~260 KB)
60 
61  THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes));
62  if (!r) throw python_error();
63 
64  // read output is String (Python 2) / Bytes (Python 3)
65 #if PY_MAJOR_VERSION >= 3
66  auto size = PyBytes_GET_SIZE(r.get());
67  const void* py_buf = PyBytes_AsString(r.get());
68 #else
69  auto size = PyString_GET_SIZE(r.get());
70  const void* py_buf = PyString_AsString(r.get());
71 #endif
72 
73  // we read EOF
74  if (size == 0) {
75  return 0;
76  }
77 
78  // Slurp it into the buffer we actually want
79  memcpy(buf, py_buf, size);
80 
81  return size;
82 }
83 
84 // Either does fildes.readinto(buf) or fildes.write(buf)
85 static inline ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
86 #if PY_MAJOR_VERSION >= 3
87  auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
88  THPObjectPtr memview(PyMemoryView_FromMemory(
89  reinterpret_cast<char*>(buf), nbytes, rw_flag));
90 #else
91  THPObjectPtr memview(PyBuffer_FromReadWriteMemory(buf, nbytes));
92 #endif
93  if (!memview) throw python_error();
94 
95  char* method = "write";
96  if (is_read) {
97  method = "readinto";
98  }
99  THPObjectPtr r(PyObject_CallMethod(fildes, method, "O", memview.get()));
100  if (r) {
101  return PyLong_AsSsize_t(r.get());
102  }
103 
104  // fildes.readinto can return UnsupportedOperation so fall back to fildes.read.
105  if (is_read && isUnsupportedOperation()) {
106  PyErr_Clear();
107  return doPartialPythonReadBuffered(fildes, buf, nbytes);
108  }
109  throw python_error();
110 }
111 
112 // Call Python fildes.readinto(buf)
113 static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
114  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
115 }
116 
117 // Call Python fildes.write(buf)
118 static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
119  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
120 }
121 
122 // Requires that we read EXACTLY nbytes; fails if we don't.
123 template <typename io>
124 void doRead(io fildes, void* raw_buf, size_t nbytes) {
125  char* buf = static_cast<char*>(raw_buf);
126  while (nbytes > 0) {
127  errno = 0; // doPartialRead may not set errno
128  // we read in 1GB blocks to avoid bugs on Mac OS X Lion
129  // see https://github.com/pytorch/pytorch/issues/1031 for more details
130  ssize_t r = doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
131  if (r < 0) {
132  int err = errno;
133  AT_ASSERTM(err != 0, "read(): impossible! r < 0, but no errno was set");
134  AT_ASSERTM(err != EAGAIN, "read(): non-blocking fd ", fildes,
135  " read EAGAIN; cowardly refusing to spin-wait");
136  if (err == EINTR) {
137  continue;
138  } else {
139  AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
140  }
141  } else if (r == 0) {
142  break;
143  }
144  buf += r;
145  // This is guaranteed by POSIX, but I just want to be double-sure
146  // to not underflow a signed integer.
147  AT_ASSERT(static_cast<size_t>(r) <= nbytes);
148  nbytes -= r;
149  }
150  if (nbytes != 0) {
151  AT_ERROR("unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted.");
152  }
153 }
154 
155 template <typename io>
156 void doWrite(io fildes, void* raw_buf, size_t nbytes) {
157  char* buf = static_cast<char*>(raw_buf);
158  while (nbytes > 0) {
159  errno = 0; // doPartialWrite may not set errno
160  // we write in 1GB blocks to avoid bugs on Mac OS X Lion
161  // see https://github.com/pytorch/pytorch/issues/1031 for more details
162  ssize_t r = doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
163  if (r < 0) {
164  int err = errno;
165  AT_ASSERTM(err != 0, "write(): impossible! r < 0, but no errno was set");
166  AT_ASSERTM(err != EAGAIN, "write(): non-blocking fd ", fildes,
167  " read EAGAIN; cowardly refusing to spin-wait");
168  if (err == EINTR) {
169  continue;
170  } else {
171  AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
172  }
173  }
174  buf += r;
175  AT_ASSERT(static_cast<size_t>(r) <= nbytes);
176  nbytes -= r;
177  }
178 }
179 
180 #include <torch/csrc/generic/serialization.cpp>
181 #include <TH/THGenerateAllTypes.h>
182 
183 #include <torch/csrc/generic/serialization.cpp>
184 #include <TH/THGenerateHalfType.h>
185 
186 #include <torch/csrc/generic/serialization.cpp>
187 #include <TH/THGenerateBoolType.h>