Caffe2 - C++ API
A deep learning, cross-platform ML framework
copy_utils.h
1 #pragma once
2 
3 #include <functional>
4 #include <vector>
5 #include <torch/csrc/Types.h>
6 
7 typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction;
8 struct THPCopyInfo {
9  PyTypeObject* srcType; // Python type of src tensor/storage
10  THPCopyFunction copy; // copy function
11  bool non_blocking; // true if copy implements an 'non_blocking' copy
12  bool broadcast; // true if the copy implements a broadcast copy
13 };
14 typedef std::vector<THPCopyInfo> THPCopyList;
15 
16 inline bool tryTHPCopy(const THPCopyList& v, PyObject* dst, PyObject* src, bool non_blocking, bool broadcast)
17 {
18  for (auto& i : v) {
19  if (i.non_blocking == non_blocking && PyType_IsSubtype(Py_TYPE(src), i.srcType)) {
20  (i.copy)(dst, src, broadcast);
21  return true;
22  }
23  }
24  return false;
25 }
26 
27 inline bool THPCopy(const THPCopyList& v, PyObject* dst, PyObject* src, bool non_blocking, bool broadcast)
28 {
29  if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) {
30  return true;
31  } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) {
32  return true;
33  }
34  THPUtils_setError("copy from %s to %s isn't implemented",
35  THPUtils_typename(src), THPUtils_typename(dst));
36  return false;
37 }
38 
39 inline PyObject * THPStorageCopyMethod(const THPCopyList& v, PyObject *self, PyObject *args, PyObject *kwargs)
40 {
41  PyObject *src;
42  int non_blocking = 0;
43  static char *kwlist[] = {"source", "non_blocking", nullptr};
44  // use int as parse type because bool not available in python2.
45  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|i:copy_", kwlist, &src, &non_blocking)) {
46  return nullptr;
47  }
48 
49  if (!THPCopy(v, self, src, non_blocking, false)) {
50  return nullptr;
51  }
52 
53  Py_INCREF(self);
54  return self;
55 }
56 
57 template <typename THPStorageDst, typename THPStorageSrc, typename StorageDst, typename StorageSrc>
58 void THPInsertStorageCopyFunction(
59  PyTypeObject *srcType,
60  THPCopyList& copyList,
61  void (*copyFunc)(LIBRARY_STATE_TYPE StorageDst* x, StorageSrc* z),
62  bool non_blocking=false)
63 {
64  auto wrapper = [copyFunc](PyObject* dst_, PyObject* src_, bool broadcast) {
65  auto dst = ((THPStorageDst*)dst_)->cdata;
66  auto src = ((THPStorageSrc*)src_)->cdata;
67 
68  PyThreadState *_save = nullptr;
69  try {
70  Py_UNBLOCK_THREADS;
71  copyFunc(LIBRARY_STATE dst, src);
72  Py_BLOCK_THREADS;
73  } catch (...) {
74  if (_save) {
75  Py_BLOCK_THREADS;
76  }
77  throw;
78  }
79  };
80 
81  copyList.push_back({ srcType, wrapper, non_blocking, false });
82 }