// THPCopyFunction: type-erased copy callback, invoked as
// copy(dst, src, broadcast) by tryTHPCopy and produced by the wrapper lambda
// in THPInsertStorageCopyFunction.
// NOTE(review): this line merges an #include with a typedef and carries stray
// numeric prefixes (5, 7). It looks like extraction residue from the original
// header; verify it before building.
5 #include <torch/csrc/Types.h> 7 typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction;
// THPCopyList: the vector of registered copy handlers. tryTHPCopy and THPCopy
// consult it; THPInsertStorageCopyFunction appends to it.
// NOTE(review): THPCopyInfo is not defined in this excerpt (original lines
// 8-13 appear to be missing). From its uses it has at least `srcType`, `copy`
// and `non_blocking` members and is brace-initialized with four values.
// Confirm against its definition.
14 typedef std::vector<THPCopyInfo> THPCopyList;
// Tries to copy `src` into `dst` using a handler registered in `v`.
// A handler `i` is used when both of these hold:
//   - its `non_blocking` flag equals the requested mode, and
//   - the Python type of `src` is `i.srcType` or a subtype of it
//     (PyType_IsSubtype).
// The handler is then invoked as copy(dst, src, broadcast).
// Only the type of `src` is checked here; the type of `dst` is not.
// NOTE(review): original lines 17-18 and 21-26 are missing from this excerpt
// (the numeric prefixes jump 16 -> 19 -> 20). The loop that binds `i` over
// `v`, the braces and the return statements are not shown. Presumably it
// returns true after a matching handler runs and false when none matches.
// Confirm against the full source.
16 inline bool tryTHPCopy(
const THPCopyList& v, PyObject* dst, PyObject* src,
bool non_blocking,
bool broadcast)
19 if (i.non_blocking == non_blocking && PyType_IsSubtype(Py_TYPE(src), i.srcType)) {
20 (i.copy)(dst, src, broadcast);
// Copies `src` into `dst` using the handlers in `v`. Sets a Python error when
// no handler applies.
// Visible control flow:
//   1. try tryTHPCopy with the requested `non_blocking` mode;
//   2. otherwise, if a non-blocking copy was requested, retry with
//      non_blocking=false;
//   3. otherwise call THPUtils_setError("copy from %s to %s isn't
//      implemented", <src type name>, <dst type name>).
// NOTE(review): original lines 28, 30, 32-33 and 36-38 are missing here, so
// the return statements are not shown. Presumably it returns true on success
// and false after setting the error. THPStorageCopyMethod treats a false
// result as failure (`if (!THPCopy(...))`). Confirm.
27 inline bool THPCopy(
const THPCopyList& v, PyObject* dst, PyObject* src,
bool non_blocking,
bool broadcast)
29 if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) {
31 }
// Fallback: a non-blocking request may be served by a handler registered
// with non_blocking == false.
else if (non_blocking && tryTHPCopy(v, dst, src,
false, broadcast)) {
// Intended for the case where no handler matched in either mode. This
// assumes the missing lines return early on success.
34 THPUtils_setError(
"copy from %s to %s isn't implemented",
35 THPUtils_typename(src), THPUtils_typename(dst));
// Implements a Python-level `copy_(source, non_blocking=...)` method.
// "copy_" is the name given in the PyArg format string; per the function name,
// it is presumably exposed on storage objects. `self` is the destination and
// `source` is the source. The copy is delegated to THPCopy with
// broadcast=false.
// NOTE(review): original lines 40-42, 44, 46-48 and 50-56 are missing from
// this excerpt. The declarations of `src` and `non_blocking`, the error
// returns and the successful return value are not shown. Per the CPython
// C-API convention, presumably nullptr is returned on error. Confirm.
39 inline PyObject * THPStorageCopyMethod(
const THPCopyList& v, PyObject *
self, PyObject *args, PyObject *kwargs)
// Keyword names for PyArg_ParseTupleAndKeywords. The list must be
// nullptr-terminated.
43 static char *kwlist[] = {
"source",
"non_blocking",
nullptr};
// "O|i:copy_" means one required object (`source`) followed by an optional
// C int (`non_blocking`). The text after ':' names the function in error
// messages.
45 if (!PyArg_ParseTupleAndKeywords(args, kwargs,
"O|i:copy_", kwlist, &src, &non_blocking)) {
// `non_blocking` is passed to THPCopy's bool parameter, so any non-zero value
// means true. Broadcast is always false for this method.
49 if (!THPCopy(v,
self, src, non_blocking,
false)) {
// Registers a storage copy routine in `copyList`.
// `copyFunc` is a C-level function called as copyFunc(LIBRARY_STATE dst, src).
// LIBRARY_STATE_TYPE and LIBRARY_STATE are macros defined outside this
// excerpt.
// The function is wrapped in a THPCopyFunction lambda. Before calling it, the
// lambda extracts the `cdata` member from the destination (cast to
// THPStorageDst*) and the source (cast to THPStorageSrc*) Python objects.
// The wrapper is appended to `copyList` together with:
//   - `srcType`,
//   - the `non_blocking` flag (default false),
//   - a trailing `false` (presumably THPCopyInfo's broadcast flag; confirm).
// NOTE(review): original lines 63, 67, 69-70 and 72-80 are missing from this
// excerpt, including the end of the lambda.
57 template <
typename THPStorageDst,
typename THPStorageSrc,
typename StorageDst,
typename StorageSrc>
58 void THPInsertStorageCopyFunction(
59 PyTypeObject *srcType,
60 THPCopyList& copyList,
61 void (*copyFunc)(LIBRARY_STATE_TYPE StorageDst* x, StorageSrc* z),
62 bool non_blocking=
false)
// The visible lines of the wrapper never use `broadcast`.
64 auto wrapper = [copyFunc](PyObject* dst_, PyObject* src_,
bool broadcast) {
// These C-style casts are unchecked: dst_ and src_ must really be
// THPStorageDst and THPStorageSrc objects. tryTHPCopy checks only src's type
// (against srcType); nothing visible here checks dst_.
65 auto dst = ((THPStorageDst*)dst_)->cdata;
66 auto src = ((THPStorageSrc*)src_)->cdata;
// `_save` is the local name that CPython's Py_UNBLOCK_THREADS and
// Py_BLOCK_THREADS macros use.
// NOTE(review): presumably the missing lines release the GIL around copyFunc
// and reacquire it afterwards, including on exceptions. Confirm.
68 PyThreadState *_save =
nullptr;
71 copyFunc(LIBRARY_STATE dst, src);
81 copyList.push_back({ srcType, wrapper, non_blocking,
false });