Caffe2 - C++ API
A deep learning, cross platform ML framework
Storage.cpp
1 #ifndef TH_GENERIC_FILE
2 #define TH_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
3 #else
4 
// Python class object used by THPStorage_(New) as the type to allocate new
// wrapper objects from. Starts out null and is assigned in
// THPStorage_(postInit).
5 PyObject *THPStorageClass = nullptr;
6 
7 PyObject * THPStorage_(New)(THWStorage *ptr)
8 {
9  AT_ASSERT(ptr);
10  PyTypeObject *type = (PyTypeObject *)THPStorageClass;
11  PyObject *obj = type->tp_alloc(type, 0);
12  if (obj) {
13  ((THPStorage *)obj)->cdata = ptr;
14  } else {
15  THWStorage_(free)(LIBRARY_STATE ptr);
16  }
17  return obj;
18 }
19 
20 static void THPStorage_(dealloc)(THPStorage* self)
21 {
22  THWStorage_(free)(LIBRARY_STATE self->cdata);
23  Py_TYPE(self)->tp_free((PyObject*)self);
24 }
25 
26 static THWStorage* THPStorage_(newWithAllocator)(int64_t size, at::Allocator* allocator)
27 {
28 #if defined(THC_GENERIC_FILE) || defined(THD_GENERIC_FILE)
29  THPUtils_setError(THPStorageStr " does not support custom allocators");
30  return nullptr;
31 #else
32  return THWStorage_(newWithAllocator)(LIBRARY_STATE size, allocator);
33 #endif
34 }
35 
36 static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
37 {
38  HANDLE_TH_ERRORS
39  Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
40 
41  THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
42  THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
43  c10::Allocator* allocator = nullptr;
44 
45  // Internally we allow constructing with a keywoard only argument cdata
46  if (kwargs != nullptr) {
47  PyObject *allocator_ptr = PyDict_GetItemString(kwargs, "allocator");
48  if (allocator_ptr) {
49  THPUtils_assert(THPUtils_checkLong(allocator_ptr), "invalid allocator");
50  allocator = static_cast<c10::Allocator*>(PyLong_AsVoidPtr(allocator_ptr));
51  PyDict_DelItemString(kwargs, "allocator");
52  }
53 
54  Py_ssize_t num_kwargs = PyDict_Size(kwargs);
55  if (num_args == 0) {
56  PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
57  if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
58  THWStorage *ptr = (THWStorage*)PyLong_AsVoidPtr(cdata_ptr);
59  self->cdata = ptr;
60  return (PyObject*)self.release();
61  }
62  }
63  THPUtils_assert(num_kwargs == 0, THPStorageStr "(): invalid keyword arguments");
64  }
65 
66  // torch.Storage()
67  if (num_args == 0) {
68  if (allocator) {
69  self->cdata = THPStorage_(newWithAllocator)(0, allocator);
70  } else {
71  self->cdata = THWStorage_(new)(LIBRARY_STATE_NOARGS);
72  }
73  return (PyObject*)self.release();
74  }
75 
76  PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
77 
78  // torch.Storage(size)
79  if (num_args == 1 && THPUtils_checkLong(first_arg)) {
80  int64_t size = THPUtils_unpackLong(first_arg);
81  if (allocator) {
82  self->cdata = THPStorage_(newWithAllocator)(size, allocator);
83  } else {
84  self->cdata = THWStorage_(newWithSize)(LIBRARY_STATE size);
85  }
86  return (PyObject*)self.release();
87  }
88 
89  // torch.Storage(view_source, [offset, [size]])
90  if (num_args < 4 && THPStorage_(Check)(first_arg)) {
91  THPUtils_setError("storage views not supported");
92  return nullptr;
93  }
94 
95  // torch.Storage(sequence)
96  if (num_args == 1 && PySequence_Check(first_arg)) {
97 #ifdef THD_GENERIC_FILE
98  THPUtils_setError("distributed storages don't support construction from a sequence");
99 #else
100  Py_ssize_t length = PySequence_Length(first_arg);
101  THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
102  THPUtils_typename(first_arg));
103  self->cdata = THWStorage_(newWithSize)(LIBRARY_STATE length);
104  THPObjectPtr item;
105  try {
106  for (Py_ssize_t i = 0; i < length; i++) {
107  item = PySequence_GetItem(first_arg, i);
108  scalar_t value = THPUtils_(unpackReal)(item.get());
109 #if !defined(THC_GENERIC_FILE)
110  self->cdata->unsafe_data<scalar_t>()[i] = value;
111 #else
112  // TODO: this might be slow - consider batched updates?
113  THCStorage_(set)(LIBRARY_STATE self->cdata, i, value);
114 #endif
115  }
116  } catch (const std::exception &e) {
117  THPUtils_setError("tried to construct a storage from a sequence (%s), "
118  "but one of the items was of type %s instead of %s",
119  THPUtils_typename(first_arg),
120  THPUtils_typename(item.get()),
121  THPUtils_typeTraits<scalar_t>::python_type_str);
122  return nullptr;
123  }
124  return (PyObject*)self.release();
125 #endif
126  }
127 
128  THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
129  "no arguments",
130  "(int size)",
131  "(Sequence data)",
132  "(" THPStorageStr " view_source)",
133  "(" THPStorageStr " view_source, int offset)",
134  "(" THPStorageStr " view_source, int offset, int size)");
135  return nullptr;
136  END_HANDLE_TH_ERRORS
137 }
138 
139 static Py_ssize_t THPStorage_(length)(THPStorage *self)
140 {
141  HANDLE_TH_ERRORS
142  return THWStorage_(size)(LIBRARY_STATE self->cdata);
143  END_HANDLE_TH_ERRORS_RET(-1)
144 }
145 
146 static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
147 {
148  HANDLE_TH_ERRORS
149  /* Integer index */
150  if (THPUtils_checkLong(index)) {
151  int64_t nindex = THPUtils_unpackLong(index);
152  if (nindex < 0)
153  nindex += THWStorage_(size)(LIBRARY_STATE self->cdata);
154  if (nindex < 0 || nindex >= self->cdata->numel()) {
155  PyErr_Format(PyExc_IndexError, "index %" PRId64 " out of range for storage of "
156  "size %" PRId64, (int64_t) nindex, (int64_t) self->cdata->numel());
157  return nullptr;
158  }
159  scalar_t value = THWStorage_(get)(LIBRARY_STATE self->cdata, nindex);
160  return THPUtils_(newReal)(value);
161  /* Slice index */
162  } else if (PySlice_Check(index)) {
163  Py_ssize_t start, stop, slicelength, step;
164  int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
165  if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
166  return nullptr;
167  if (step != 1) {
168  THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
169  "1 is supported", (int64_t)step);
170  return nullptr;
171  }
172 
173  scalar_t *data = THWStorage_(data)(LIBRARY_STATE self->cdata);
174 
175  at::StorageImpl* old_storage = self->cdata;
176  c10::raw::intrusive_ptr::incref(old_storage);
177  at::Storage new_storage(c10::make_intrusive<at::StorageImpl>(
178  old_storage->dtype(),
179  slicelength,
180  at::DataPtr(static_cast<void*>(data + start),
181  old_storage,
182  [](void* s) { c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s)); },
183  old_storage->device()),
184  old_storage->allocator(),
185  /* resizable */ false));
186 
187  PyObject *_ret = THPStorage_(New)(new_storage.unsafeReleaseStorageImpl());
188  return _ret;
189  }
190  PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s",
191  THPUtils_typename(index));
192  return nullptr;
193  END_HANDLE_TH_ERRORS
194 }
195 
196 static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
197 {
198  HANDLE_TH_ERRORS
199  if (!THPUtils_(checkReal)(value)) {
200  THPUtils_setError("can only set storage content with a %s, but got "
201  "%s instead", THPUtils_typeTraits<scalar_t>::python_type_str,
202  THPUtils_typename(value));
203  return -1;
204  }
205 
206  scalar_t rvalue = THPUtils_(unpackReal)(value);
207  if (THPUtils_checkLong(index)) {
208  int64_t nindex = THPUtils_unpackLong(index);
209  THWStorage_(set)(LIBRARY_STATE self->cdata, nindex, rvalue);
210  return 0;
211  } else if (PySlice_Check(index)) {
212  Py_ssize_t start, stop, slicelength, step;
213  int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
214  if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
215  return -1;
216  if (step != 1) {
217  THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
218  "1 is supported", (int64_t)step);
219  return 0;
220  }
221  // TODO: check the bounds only once
222  // TODO: fill?
223  for (;start < stop; start++)
224  THWStorage_(set)(LIBRARY_STATE self->cdata, start, rvalue);
225  return 0;
226  }
227  THPUtils_setError("can't index a " THPStorageStr " with %s",
228  THPUtils_typename(index));
229  return -1;
230  END_HANDLE_TH_ERRORS_RET(-1)
231 }
232 
233 static PyMappingMethods THPStorage_(mappingmethods) = {
234  (lenfunc)THPStorage_(length),
235  (binaryfunc)THPStorage_(get),
236  (objobjargproc)THPStorage_(set)
237 };
238 
// Python type object for the storage base class.
//
// Instances are created through THPStorage_(pynew) and destroyed through
// THPStorage_(dealloc). len(), indexing and item assignment are provided by
// THPStorage_(mappingmethods). Py_TPFLAGS_BASETYPE allows the type to be
// subclassed. tp_methods and tp_members are left null here and are filled in
// by THPStorage_(init) before PyType_Ready is called.
//
// The initializer is positional: entries must stay in slot order.
239 // TODO: implement equality
240 PyTypeObject THPStorageType = {
241  PyVarObject_HEAD_INIT(nullptr, 0)
242  "torch._C." THPStorageBaseStr, /* tp_name */
243  sizeof(THPStorage), /* tp_basicsize */
244  0, /* tp_itemsize */
245  (destructor)THPStorage_(dealloc), /* tp_dealloc */
246  nullptr, /* tp_print */
247  nullptr, /* tp_getattr */
248  nullptr, /* tp_setattr */
249  nullptr, /* tp_reserved */
250  nullptr, /* tp_repr */
251  nullptr, /* tp_as_number */
252  nullptr, /* tp_as_sequence */
253  &THPStorage_(mappingmethods), /* tp_as_mapping */
254  nullptr, /* tp_hash */
255  nullptr, /* tp_call */
256  nullptr, /* tp_str */
257  nullptr, /* tp_getattro */
258  nullptr, /* tp_setattro */
259  nullptr, /* tp_as_buffer */
260  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
261  nullptr, /* tp_doc */
262  nullptr, /* tp_traverse */
263  nullptr, /* tp_clear */
264  nullptr, /* tp_richcompare */
265  0, /* tp_weaklistoffset */
266  nullptr, /* tp_iter */
267  nullptr, /* tp_iternext */
268  nullptr, /* will be assigned in init */ /* tp_methods */
269  nullptr, /* will be assigned in init */ /* tp_members */
270  nullptr, /* tp_getset */
271  nullptr, /* tp_base */
272  nullptr, /* tp_dict */
273  nullptr, /* tp_descr_get */
274  nullptr, /* tp_descr_set */
275  0, /* tp_dictoffset */
276  nullptr, /* tp_init */
277  nullptr, /* tp_alloc */
278  THPStorage_(pynew), /* tp_new */
279 };
280 
// Member table installed as tp_members by THPStorage_(init).
// Exposes the `cdata` field of THPStorage as the read-only attribute `_cdata`,
// read as an unsigned long long.
// NOTE(review): T_ULONGLONG is applied to a pointer field, which presumably
// assumes pointers are as wide as unsigned long long - confirm for all targets.
281 static struct PyMemberDef THPStorage_(members)[] = {
282  {(char*)"_cdata", T_ULONGLONG, offsetof(THPStorage, cdata), READONLY, nullptr},
// Sentinel entry terminating the table.
283  {nullptr}
284 };
285 
// List of copy functions for this storage type; populated by
// THPStorage_(initCopyMethods). Declared extern first and then defined, so the
// symbol has external linkage (the CUDA instantiation refers to the CPU
// type's list through an extern declaration of the same form).
286 extern THPCopyList THWStorage_(copy_functions);
287 THPCopyList THWStorage_(copy_functions);
288 
// Fills THWStorage_(copy_functions) with one entry per source storage type
// this storage type can be copied from. Called from THPStorage_(init).
//
//  * THD instantiations: does nothing.
//  * All other instantiations: registers a copy function for each CPU source
//    type (Byte, Char, Short, Int, Long, Half, Float, Double, Bool).
//  * THC instantiations additionally register a copy function for each CUDA
//    source type, and append CUDA-source copy functions to the list owned by
//    the CPU storage type with the same scalar type.
//
// NOTE(review): THPInsertStorageCopyFunction presumably associates the given
// Python type object with the copy function inside the list - confirm against
// its definition.
289 void THPStorage_(initCopyMethods)()
290 {
291 #ifndef THD_GENERIC_FILE
292  auto& h = THWStorage_(copy_functions);
293  // copy from CPU types
294  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPByteStorageType, h, &THWStorage_(copyByte));
295  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPCharStorageType, h, &THWStorage_(copyChar));
296  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPShortStorageType, h, &THWStorage_(copyShort));
297  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPIntStorageType, h, &THWStorage_(copyInt));
298  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPLongStorageType, h, &THWStorage_(copyLong));
299  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
300  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
301  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
302  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
303 #ifdef THC_GENERIC_FILE
304  // copy from GPU types
305  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
306  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, h, &THWStorage_(copyCudaChar));
307  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, h, &THWStorage_(copyCudaShort));
308  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, h, &THWStorage_(copyCudaInt));
309  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, h, &THWStorage_(copyCudaLong));
310  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
311  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
312  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
313  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool));
314  // add CPU <- GPU copies to base type
// THCpuStorage_(name) names the CPU (TH) counterpart of a storage function,
// e.g. the CPU type's copy_functions list and its copyCuda* functions.
316  #define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
317  extern THPCopyList THCpuStorage_(copy_functions);
318  auto& b = THCpuStorage_(copy_functions);
319  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, b, &THCpuStorage_(copyCudaByte));
320  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPCharStorageType, b, &THCpuStorage_(copyCudaChar));
321  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPShortStorageType, b, &THCpuStorage_(copyCudaShort));
322  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPIntStorageType, b, &THCpuStorage_(copyCudaInt));
323  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPLongStorageType, b, &THCpuStorage_(copyCudaLong));
324  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
325  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
326  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
327  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool));
// NOTE(review): no matching #define for THCpuStorage is visible in this
// function; the #undef below is harmless if the macro is not defined.
328  #undef THCpuStorage
329  #undef THCpuStorage_
330 #endif
331 #endif // !defined(THD_GENERIC_FILE)
332 }
333 
334 #include <torch/csrc/generic/StorageMethods.cpp>
335 #ifndef THD_GENERIC_FILE
336 #include <torch/csrc/generic/StorageSharing.cpp>
337 #endif
338 
339 bool THPStorage_(init)(PyObject *module)
340 {
341  static std::vector<PyMethodDef> methods;
342  THPUtils_addPyMethodDefs(methods, THPStorage_(methods));
343 #ifndef THD_GENERIC_FILE
344  THPUtils_addPyMethodDefs(methods, THPStorage_(sharingMethods));
345 #endif
346 
347  THPStorageType.tp_methods = methods.data();
348  THPStorageType.tp_members = THPStorage_(members);
349  if (PyType_Ready(&THPStorageType) < 0)
350  return false;
351  Py_INCREF(&THPStorageType);
352  PyModule_AddObject(module, THPStorageBaseStr, (PyObject *)&THPStorageType);
353  THPStorage_(initCopyMethods)();
354  return true;
355 }
356 
357 void THPStorage_(postInit)(PyObject *module)
358 {
359  THPStorageClass = PyObject_GetAttrString(module,(char*)TH_CONCAT_STRING_2(Real,Storage));
360  if (!THPStorageClass) throw python_error();
361 
362  bool is_cuda = false;
363 #ifdef THC_GENERIC_FILE
364  is_cuda = true;
365 #endif
366  const char *type_name = TH_CONCAT_STRING_2(Real,);
367  torch::registerStoragePyTypeObject((PyTypeObject*)THPStorageClass, type_name, is_cuda, false);
368 }
369 
370 #endif