1 #include <torch/csrc/python_headers.h> 7 #include <unordered_map> 8 #include <torch/csrc/THP.h> 9 #include <torch/csrc/utils/python_strings.h> 10 #include <torch/csrc/utils/invalid_arguments.h> 11 #include <torch/csrc/autograd/variable.h> 12 #include <torch/csrc/DynamicTypes.h> 14 #include <torch/csrc/generic/utils.cpp> 15 #include <TH/THGenerateAllTypes.h> 17 #include <torch/csrc/generic/utils.cpp> 18 #include <TH/THGenerateHalfType.h> 20 #include <torch/csrc/generic/utils.cpp> 21 #include <TH/THGenerateBoolType.h> 23 int THPUtils_getCallable(PyObject *arg, PyObject **result) {
// THPUtils_getCallable: the visible part of the body is a guard that tests
// whether `arg` is callable (PyCallable_Check). What happens on either branch,
// and how `result` is written, is not visible in this excerpt.
24 if (!PyCallable_Check(arg))
// Unpacks `arg` into a THLongStoragePtr by delegating to
// THPUtils_tryUnpackLongs(arg, result).
// Throws std::runtime_error when that helper reports failure; the message
// starts with a fixed prefix and appends the Python type name of `arg`.
// The success-path return is not visible in this excerpt.
30 THLongStoragePtr THPUtils_unpackSize(PyObject *arg) {
31 THLongStoragePtr result;
32 if (!THPUtils_tryUnpackLongs(arg, result)) {
// Build the diagnostic: fixed prefix + Py_TYPE(arg)->tp_name.
33 std::string msg =
"THPUtils_unpackSize() expects a torch.Size (got '";
34 msg += Py_TYPE(arg)->tp_name;
36 throw std::runtime_error(msg);
// Tries to copy the elements of a Python tuple or list into a newly allocated
// THLongStorage of matching length and hand that storage to `result`.
// Each element is tested with THPUtils_checkLong before being converted with
// THPUtils_unpackLong. `result` is only assigned after the loop, by moving the
// local storage into it.
// The branch on `tuple`/`list`, the handling of a non-integer element, and the
// return statements are not visible in this excerpt.
41 bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result) {
42 bool tuple = PyTuple_Check(arg);
43 bool list = PyList_Check(arg);
// NOTE(review): the *_GET_SIZE macros are documented to yield Py_ssize_t;
// storing into `int` narrows. Confirm sequence lengths always fit.
45 int nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
46 THLongStoragePtr storage(THLongStorage_newWithSize(nDim));
47 for (
int i = 0; i != nDim; ++i) {
// Pick the element accessor that matches the container kind.
48 PyObject* item = tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
49 if (!THPUtils_checkLong(item)) {
52 THLongStorage_set(storage, i, THPUtils_unpackLong(item));
// Publish the filled storage to the caller only once every element was copied.
54 result = std::move(storage);
// Converts a Python tuple or list of integers into a std::vector<int64_t>.
// Throws std::runtime_error naming the position and type of the first element
// that fails THPUtils_checkLong. A second std::runtime_error
// ("Expected tuple or list") is also thrown; the condition guarding it and the
// return statement are not visible in this excerpt.
60 std::vector<int64_t> THPUtils_unpackLongs(PyObject *arg) {
61 bool tuple = PyTuple_Check(arg);
62 bool list = PyList_Check(arg);
// NOTE(review): the *_GET_SIZE macros are documented to yield Py_ssize_t;
// storing into `int` narrows. Confirm sequence lengths always fit.
64 int nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
65 std::vector<int64_t> sizes(nDim);
66 for (
int i = 0; i != nDim; ++i) {
// Pick the element accessor that matches the container kind.
67 PyObject* item = tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
68 if (!THPUtils_checkLong(item)) {
// Report which index held the offending value and what its type was.
69 std::ostringstream oss;
70 oss <<
"expected int at position " << i <<
", but got: " << THPUtils_typename(item);
71 throw std::runtime_error(oss.str());
73 sizes[i] = THPUtils_unpackLong(item);
77 throw std::runtime_error(
"Expected tuple or list");
// Tries to read a list of integers from the positional-argument tuple `args`,
// skipping the first `ignore_first` entries.
// Two shapes are handled by the visible code:
//   * exactly one remaining argument that THPUtils_tryUnpackLongs accepts
//     (i.e. the integers were passed as a single sequence), or
//   * each remaining argument is itself an integer, copied one by one into a
//     storage of size `length`.
// The guard between computing `length` and reading `first_arg`, the handling
// of a non-integer argument, and the return statements are not visible in
// this excerpt.
80 bool THPUtils_tryUnpackLongVarArgs(PyObject *args,
int ignore_first, THLongStoragePtr& result) {
// Number of arguments left after the ignored prefix.
81 Py_ssize_t length = PyTuple_Size(args) - ignore_first;
86 PyObject *first_arg = PyTuple_GET_ITEM(args, ignore_first);
87 if (length == 1 && THPUtils_tryUnpackLongs(first_arg, result)) {
92 result = THLongStorage_newWithSize(length);
93 for (Py_ssize_t i = 0; i < length; ++i) {
// Offset by ignore_first so index i addresses the i-th non-ignored argument.
94 PyObject *arg = PyTuple_GET_ITEM(args, i + ignore_first);
95 if (!THPUtils_checkLong(arg)) {
98 THLongStorage_set(result, i, THPUtils_unpackLong(arg));
// Predicate over `arg`: the visible checks are that `arg` is a Python tuple
// and that each of its elements passes THPUtils_checkLong.
// The values returned on each path are not visible in this excerpt.
103 bool THPUtils_checkIntTuple(PyObject *arg)
105 if (!PyTuple_Check(arg)) {
108 for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
109 if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) {
// Converts a Python tuple of integers into a std::vector<int>.
// Throws std::runtime_error("Couldn't unpack int tuple") when
// THPUtils_checkIntTuple(arg) fails. The return statement is not visible in
// this excerpt.
116 std::vector<int> THPUtils_unpackIntTuple(PyObject *arg)
118 if (!THPUtils_checkIntTuple(arg)) {
119 throw std::runtime_error(
"Couldn't unpack int tuple");
// One slot per tuple element.
121 std::vector<int> values(PyTuple_GET_SIZE(arg));
122 for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
// NOTE(review): the unpacked value is narrowed to int with a C-style cast;
// elsewhere in this file the same helper feeds int64_t slots, so values
// outside the int range would presumably be truncated here. Confirm.
123 values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
// Sets a Python RuntimeError whose message is produced from a printf-style
// `format` string and its variadic arguments.
// The message is rendered with vsnprintf into a fixed stack buffer of
// ERROR_BUFFER_SIZE (1000) bytes, so the size passed to vsnprintf bounds how
// much text is written.
// The declaration of `fmt_args` and any va_end call are not visible in this
// excerpt.
128 void THPUtils_setError(
const char *format, ...)
130 static const size_t ERROR_BUFFER_SIZE = 1000;
131 char buffer[ERROR_BUFFER_SIZE];
134 va_start(fmt_args, format);
135 vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args);
137 PyErr_SetString(PyExc_RuntimeError, buffer);
// Appends PyMethodDef entries read through `methods` to `vector`.
// Visible steps: a special case when `vector` is already non-empty (its body
// is not visible here), a push_back of the entry `methods` points at, and a
// test for an entry whose ml_name is null.
// NOTE(review): the null-ml_name test presumably detects the sentinel entry
// that terminates a PyMethodDef array; the loop structure is not visible in
// this excerpt. Confirm against the full source.
140 void THPUtils_addPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* methods)
142 if (!vector.empty()) {
147 vector.push_back(*methods);
148 if (!methods->ml_name) {
// Returns a type name for `obj` for use in messages:
//   * if `obj` is itself a type object, that type's own tp_name;
//   * otherwise the tp_name of the type of `obj`.
// The returned pointer is the tp_name field itself; nothing is copied here.
155 static const char* classOrTypename(PyObject* obj) {
156 if (PyType_Check(obj)) {
157 return ((PyTypeObject*)obj)->tp_name;
159 return Py_TYPE(obj)->tp_name;
// Calls a "stateless" method by name on behalf of `tensor`.
// Steps visible in this excerpt:
//   1. fetch the attribute named THP_STATELESS_ATTRIBUTE_NAME from `tensor`;
//   2. fetch the attribute `name` from that object;
//   3. call it with `args` / `kwargs` and return the result of PyObject_Call.
// Two error messages exist for the lookups in steps 1 and 2; the conditions
// and the function that raises them are not visible here.
162 PyObject * THPUtils_dispatchStateless(
163 PyObject *tensor,
const char *name, PyObject *args, PyObject *kwargs)
165 THPObjectPtr methods(PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME));
169 "Type %s doesn't implement stateless methods",
170 classOrTypename(tensor));
172 THPObjectPtr method(PyObject_GetAttrString(methods, name));
176 "Type %s doesn't implement stateless method %s",
177 classOrTypename(tensor),
180 return PyObject_Call(method.get(), args, kwargs);
// Sets a Python TypeError describing an invalid call to `function_name`.
// The variadic tail is read as `num_options` values of type const char*; each
// is copied into a std::string. The message text comes from
// torch::format_invalid_args(given_args, given_kwargs, function_name, options).
// The declaration of `option_list` and any va_end call are not visible in
// this excerpt.
183 void THPUtils_invalidArguments(PyObject *given_args, PyObject *given_kwargs,
184 const char *function_name,
size_t num_options, ...) {
185 std::vector<std::string> option_strings;
187 va_start(option_list, num_options);
188 for (
size_t i = 0; i < num_options; i++)
// Each variadic argument is consumed as a C string.
189 option_strings.emplace_back(va_arg(option_list,
const char*));
192 PyErr_SetString(PyExc_TypeError, torch::format_invalid_args(
193 given_args, given_kwargs, function_name, option_strings).c_str());
// File-local flag, initially false, with a setter and getter below.
// The code that consults this flag is not visible in this excerpt.
204 static bool backCompatBroadcastWarn =
false;
// Stores `warn` into the flag.
206 void setBackCompatBroadcastWarn(
bool warn) {
207 backCompatBroadcastWarn = warn;
// Returns the current value of the flag.
210 bool getBackCompatBroadcastWarn() {
211 return backCompatBroadcastWarn;
// File-local flag, initially false, with a setter and getter below.
// maybeThrowBackCompatKeepdimWarn reads it through the getter.
214 static bool backCompatKeepdimWarn =
false;
// Stores `warn` into the flag.
216 void setBackCompatKeepdimWarn(
bool warn) {
217 backCompatKeepdimWarn = warn;
// Returns the current value of the flag.
220 bool getBackCompatKeepdimWarn() {
221 return backCompatKeepdimWarn;
// When getBackCompatKeepdimWarn() is true, issues a Python UserWarning
// (PyErr_WarnEx with stack level 1) whose text names `func` and explains that
// the default for keepdim changed to False.
// The value returned by this function is not visible in this excerpt.
224 bool maybeThrowBackCompatKeepdimWarn(
char *func) {
225 if(getBackCompatKeepdimWarn()) {
226 std::ostringstream ss;
// NOTE(review): the streaming expression below ends with ',' rather than ';',
// so the PyErr_WarnEx call is joined to it by the comma operator. This looks
// like a typo for ';'. The result of PyErr_WarnEx is also not captured in the
// visible statement. Confirm both against the full source.
227 ss <<
"backwards compatibility: call to \"" << func
228 <<
"\" uses default value for keepdim which has changed default to False. Consider passing as kwarg.",
229 PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1);
237 THTensor_free(LIBRARY_STATE ptr);