1 #include <torch/csrc/autograd/python_hook.h> 5 #include <torch/csrc/THP.h> 6 #include <torch/csrc/autograd/python_variable.h> 7 #include <torch/csrc/utils/auto_gil.h> 8 #include <torch/csrc/utils/object_ptr.h> 9 #include <torch/csrc/utils/python_strings.h> 10 #include <torch/csrc/Exceptions.h> 12 using torch::autograd::variable_list;
// Forward declarations of the file-local (static) helpers that are used by
// the hook implementations and defined later in this file.
15 static PyObject* wrap_variables(
const variable_list& c_variables);
16 static variable_list unwrap_variables(PyObject* py_variables);
17 static std::string hook_name(PyObject* hook);
18 static void check_result(PyObject* original, PyObject* result, PyObject* hook);
19 static void check_single_result(PyObject* original, PyObject* result, PyObject* hook);
22 namespace torch {
namespace autograd {
// Pre-hook constructor. `dict` holds the Python hook callables that
// operator() iterates over; `value_idx` selects which entry of the incoming
// variable_list those hooks receive.
// NOTE(review): this listing omits original line 25 (presumably the `dict`
// member initializer) and the constructor body — confirm against the full
// file before relying on reference-counting behavior here.
24 PyFunctionPreHook::PyFunctionPreHook(PyObject* dict,
int value_idx)
26 , value_idx(value_idx)
// Pre-hook destructor.
// NOTE(review): the body is not shown in this listing; presumably it releases
// the reference to the hook dict — confirm.
31 PyFunctionPreHook::~PyFunctionPreHook() {
// Runs every hook stored in `dict` on the variable at `value_idx` and returns
// a copy of `values` in which that entry may have been replaced.
//
// Hooks are called in dict iteration order, and each receives the current
// (possibly already replaced) value:
//   - a hook returning None leaves the value unchanged;
//   - any other return is validated by check_single_result and becomes the
//     value passed to the next hook.
// NOTE(review): this listing omits original lines 37-39, 41-44, 47, 51-52 and
// 55-56 of this definition (including whatever handles a failed hook call and
// the final return) — confirm against the full file.
36 auto PyFunctionPreHook::operator()(
const variable_list& values) -> variable_list
// Wrap the selected variable as a Python object. at() presumably
// bounds-checks value_idx (assuming variable_list is a std::vector).
40 THPObjectPtr value(THPVariable_Wrap(values.at(value_idx)));
45 while (PyDict_Next(dict, &pos, &key, &hook)) {
46 THPObjectPtr res(PyObject_CallFunctionObjArgs(hook, value.get(),
nullptr));
// None means "keep the current value".
48 if (res == Py_None)
continue;
49 check_single_result(value.get(), res.get(), hook);
// The replacement is what the next hook sees.
50 value = std::move(res);
// Copy the inputs and substitute the final value, unless it is None.
53 variable_list results(values);
54 if (value != Py_None) results[value_idx] = ((
THPVariable*)value.get())->cdata;
// Post-hook constructor: stores the dict of Python hook callables that
// operator() iterates over.
// NOTE(review): the constructor body is not shown in this listing.
58 PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) {
// Post-hook destructor.
// NOTE(review): the body is not shown in this listing; presumably it releases
// the reference to the hook dict — confirm.
62 PyFunctionPostHook::~PyFunctionPostHook() {
// Runs every hook stored in `dict` and returns the (possibly replaced)
// `outputs` converted back into a variable_list.
//
// The visible call arguments show each hook receiving (outputs, inputs):
//   - a hook returning None leaves `outputs` unchanged;
//   - any other return must pass check_result (a tuple of matching length
//     whose elements pass check_single_result) and then replaces `outputs`
//     for subsequent hooks.
// NOTE(review): this listing omits original lines 70-77, which presumably
// build the `outputs`/`inputs` Python tuples from `_outputs`/`_inputs`
// (e.g. via wrap_variables), and line 79, which opens the
// `THPObjectPtr res(...)` call whose arguments appear on line 80. It also
// omits line 81 — confirm against the full file.
67 auto PyFunctionPostHook::operator()(
68 const variable_list& _outputs,
69 const variable_list& _inputs ) -> variable_list
78 while (PyDict_Next(dict, &pos, &key, &hook)) {
80 hook, outputs.get(), inputs.get(),
nullptr));
// None means "keep the current outputs".
82 if (res == Py_None)
continue;
83 check_result(outputs, res, hook);
84 outputs = std::move(res);
// Convert the final outputs (expected to be a tuple) back to C++ variables.
87 return unwrap_variables(outputs.get());
// Builds a Python tuple with one element per entry of `c_variables` and
// returns it, releasing ownership of the tuple to the caller.
// NOTE(review): this listing omits the lines that create `tuple` and `var`
// (presumably wrapping each variable) and their error checks (original lines
// 94, 96-97, 99-100, 102) — confirm against the full file.
93 static PyObject *wrap_variables(
const variable_list& c_variables)
95 size_t num_vars = c_variables.size();
98 for (
size_t i = 0; i < num_vars; ++i) {
// PyTuple_SET_ITEM steals the reference, so `var` gives up ownership to
// the tuple rather than decrementing it.
101 PyTuple_SET_ITEM(tuple.get(), i, var.release());
// Hand the tuple's reference to the caller.
103 return tuple.release();
// Converts a Python tuple back into a variable_list of the same length.
// None elements and THPVariable elements are accepted. Any other element
// type leads to std::runtime_error("expected variable but got <type>").
// Assumes `py_variables` is a tuple: PyTuple_GET_SIZE / PyTuple_GET_ITEM do
// no argument checking.
// NOTE(review): this listing omits the bodies of the None and THPVariable
// branches, the `else` line of the error branch, and the final return
// (original lines 111, 113-115, 119-123) — confirm against the full file.
106 static variable_list unwrap_variables(PyObject* py_variables) {
107 variable_list results(PyTuple_GET_SIZE(py_variables));
108 for (
size_t i = 0; i < results.size(); i++) {
109 PyObject* item = PyTuple_GET_ITEM(py_variables, i);
110 if (item == Py_None) {
112 }
else if (THPVariable_Check(item)) {
// Error branch: the element is neither None nor a variable.
116 std::stringstream ss;
117 ss <<
"expected variable but got " << Py_TYPE(item)->tp_name;
118 throw std::runtime_error(ss.str());
// Validates the value a post-hook returned against the outputs it replaces:
//   - `result` must be a tuple; otherwise a Python TypeError is set.
//     NOTE(review): the line that then propagates it (original line 128) is
//     not shown — presumably a throw.
//   - its length must equal that of `prev`; otherwise std::runtime_error
//     names the hook and both lengths.
//   - each element pair is then checked with check_single_result.
// Assumes `prev` is a tuple (PyTuple_GET_SIZE does no type checking).
124 static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
125 if (!PyTuple_Check(result)) {
126 PyErr_Format(PyExc_TypeError,
"expected tuple, but hook returned '%s'",
127 THPUtils_typename(result));
131 auto prev_size = PyTuple_GET_SIZE(prev);
132 auto result_size = PyTuple_GET_SIZE(result);
133 if (prev_size != result_size) {
134 std::stringstream ss;
135 auto name = hook_name(hook);
136 ss <<
"hook '" << name <<
"' has returned an incorrect number ";
137 ss <<
"of values (got " << result_size <<
", but expected ";
138 ss << prev_size <<
")";
139 throw std::runtime_error(ss.str());
// Lengths match; validate element by element.
142 for (
auto i = 0; i < prev_size; i++) {
143 check_single_result(PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook);
// Validates a single value returned by a hook against the value it replaces.
// A None result is always accepted. Otherwise:
//   - replacing a None original throws std::runtime_error;
//   - a result that is not a THPVariableClass instance sets a Python
//     TypeError (NOTE(review): the propagating line, original line 157, is
//     not shown — presumably a throw);
//   - std::runtime_error naming the hook is thrown if the type ID, the
//     is_cuda() flag, or the sizes of `cdata.data()` differ between
//     original and result.
// NOTE(review): several closing-brace / `else` lines (original lines 152-153,
// 157-159, 170-171, 178, 180, 182-183) are omitted from this listing.
147 static void check_single_result(PyObject* _original, PyObject* _result, PyObject* hook) {
148 if (_result == Py_None)
return;
150 if (_original == Py_None) {
151 throw std::runtime_error(
"can't replace a None gradient with a non-None value");
154 if (!PyObject_IsInstance(_result, THPVariableClass)) {
155 PyErr_Format(PyExc_TypeError,
"expected Variable, but hook returned '%s'",
156 THPUtils_typename(_result));
// Compare the data held by both variables.
160 auto& original = ((
THPVariable*)_original)->cdata.data();
161 auto& result = ((
THPVariable*)_result)->cdata.data();
// Type check: message reports both type strings.
163 if (original.type().ID() != result.type().ID()) {
164 std::stringstream ss;
165 auto name = hook_name(hook);
166 ss <<
"hook '" << name <<
"' has changed the type of value (";
167 ss <<
"was " << original.toString() <<
" got ";
168 ss << result.toString() <<
")";
169 throw std::runtime_error(ss.str());
// Device check: CUDA vs CPU must match.
172 if (original.is_cuda() != result.is_cuda()) {
173 std::stringstream ss;
174 auto name = hook_name(hook);
175 ss <<
"hook '" << name <<
"' has changed the type of value";
176 if (original.is_cuda()) {
177 ss <<
" (was CUDA tensor got CPU tensor)";
179 ss <<
" (was CPU tensor got CUDA tensor)";
181 throw std::runtime_error(ss.str());
// Shape check: sizes must match exactly.
184 if (original.sizes().vec() != result.sizes().vec()) {
185 std::stringstream ss;
186 auto name = hook_name(hook);
187 ss <<
"hook '" << name <<
"' has changed the size of value";
188 throw std::runtime_error(ss.str());
192 static std::string hook_name(PyObject* hook) {
193 THPObjectPtr name(PyObject_GetAttrString(hook,
"__name__"));
194 if (name && THPUtils_checkString(name.get())) {
195 return THPUtils_unpackString(name.get());
Variable: A Variable augments a Tensor with the ability to interact with our autograd machinery...