Caffe2 - C++ API
A deep learning, cross platform ML framework
python_hook.cpp
1 #include <torch/csrc/autograd/python_hook.h>
2 
3 #include <sstream>
4 
5 #include <torch/csrc/THP.h>
6 #include <torch/csrc/autograd/python_variable.h>
7 #include <torch/csrc/utils/auto_gil.h>
8 #include <torch/csrc/utils/object_ptr.h>
9 #include <torch/csrc/utils/python_strings.h>
10 #include <torch/csrc/Exceptions.h>
11 
12 using torch::autograd::variable_list;
14 
15 static PyObject* wrap_variables(const variable_list& c_variables);
16 static variable_list unwrap_variables(PyObject* py_variables);
17 static std::string hook_name(PyObject* hook);
18 static void check_result(PyObject* original, PyObject* result, PyObject* hook);
19 static void check_single_result(PyObject* original, PyObject* result, PyObject* hook);
20 
21 
22 namespace torch { namespace autograd {
23 
24 PyFunctionPreHook::PyFunctionPreHook(PyObject* dict, int value_idx)
25  : dict(dict)
26  , value_idx(value_idx)
27 {
28  Py_INCREF(dict);
29 }
30 
31 PyFunctionPreHook::~PyFunctionPreHook() {
32  AutoGIL gil;
33  Py_DECREF(dict);
34 }
35 
36 auto PyFunctionPreHook::operator()(const variable_list& values) -> variable_list
37 {
38  AutoGIL gil;
39 
40  THPObjectPtr value(THPVariable_Wrap(values.at(value_idx)));
41  if (!value) throw python_error();
42 
43  PyObject *key, *hook;
44  Py_ssize_t pos = 0;
45  while (PyDict_Next(dict, &pos, &key, &hook)) {
46  THPObjectPtr res(PyObject_CallFunctionObjArgs(hook, value.get(), nullptr));
47  if (!res) throw python_error();
48  if (res == Py_None) continue;
49  check_single_result(value.get(), res.get(), hook);
50  value = std::move(res);
51  }
52 
53  variable_list results(values);
54  if (value != Py_None) results[value_idx] = ((THPVariable*)value.get())->cdata;
55  return results;
56 }
57 
58 PyFunctionPostHook::PyFunctionPostHook(PyObject* dict) : dict(dict) {
59  Py_INCREF(dict);
60 }
61 
62 PyFunctionPostHook::~PyFunctionPostHook() {
63  AutoGIL gil;
64  Py_DECREF(dict);
65 }
66 
67 auto PyFunctionPostHook::operator()(
68  const variable_list& _outputs, /* grad_inputs */
69  const variable_list& _inputs /* grad_outputs */) -> variable_list
70 {
71  AutoGIL gil;
72 
73  THPObjectPtr outputs(wrap_variables(_outputs));
74  THPObjectPtr inputs(wrap_variables(_inputs));
75 
76  PyObject *key, *hook;
77  Py_ssize_t pos = 0;
78  while (PyDict_Next(dict, &pos, &key, &hook)) {
79  THPObjectPtr res(PyObject_CallFunctionObjArgs(
80  hook, outputs.get(), inputs.get(), nullptr));
81  if (!res) throw python_error();
82  if (res == Py_None) continue;
83  check_result(outputs, res, hook);
84  outputs = std::move(res);
85  }
86 
87  return unwrap_variables(outputs.get());
88 }
89 
90 }} // namespace torch::autograd
91 
92 
93 static PyObject *wrap_variables(const variable_list& c_variables)
94 {
95  size_t num_vars = c_variables.size();
96  THPObjectPtr tuple(PyTuple_New(num_vars));
97  if (!tuple) throw python_error();
98  for (size_t i = 0; i < num_vars; ++i) {
99  THPObjectPtr var(THPVariable_Wrap(c_variables[i]));
100  if (!var) throw python_error();
101  PyTuple_SET_ITEM(tuple.get(), i, var.release());
102  }
103  return tuple.release();
104 }
105 
106 static variable_list unwrap_variables(PyObject* py_variables) {
107  variable_list results(PyTuple_GET_SIZE(py_variables));
108  for (size_t i = 0; i < results.size(); i++) {
109  PyObject* item = PyTuple_GET_ITEM(py_variables, i);
110  if (item == Py_None) {
111  continue;
112  } else if (THPVariable_Check(item)) {
113  results[i] = ((THPVariable*)item)->cdata;
114  } else {
115  // this should never happen, but just in case...
116  std::stringstream ss;
117  ss << "expected variable but got " << Py_TYPE(item)->tp_name;
118  throw std::runtime_error(ss.str());
119  }
120  }
121  return results;
122 }
123 
124 static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
125  if (!PyTuple_Check(result)) {
126  PyErr_Format(PyExc_TypeError, "expected tuple, but hook returned '%s'",
127  THPUtils_typename(result));
128  throw python_error();
129  }
130 
131  auto prev_size = PyTuple_GET_SIZE(prev);
132  auto result_size = PyTuple_GET_SIZE(result);
133  if (prev_size != result_size) {
134  std::stringstream ss;
135  auto name = hook_name(hook);
136  ss << "hook '" << name << "' has returned an incorrect number ";
137  ss << "of values (got " << result_size << ", but expected ";
138  ss << prev_size << ")";
139  throw std::runtime_error(ss.str());
140  }
141 
142  for (auto i = 0; i < prev_size; i++) {
143  check_single_result(PyTuple_GET_ITEM(prev, i), PyTuple_GET_ITEM(result, i), hook);
144  }
145 }
146 
147 static void check_single_result(PyObject* _original, PyObject* _result, PyObject* hook) {
148  if (_result == Py_None) return;
149 
150  if (_original == Py_None) {
151  throw std::runtime_error("can't replace a None gradient with a non-None value");
152  }
153 
154  if (!PyObject_IsInstance(_result, THPVariableClass)) {
155  PyErr_Format(PyExc_TypeError, "expected Variable, but hook returned '%s'",
156  THPUtils_typename(_result));
157  throw python_error();
158  }
159 
160  auto& original = ((THPVariable*)_original)->cdata.data();
161  auto& result = ((THPVariable*)_result)->cdata.data();
162 
163  if (original.type().ID() != result.type().ID()) {
164  std::stringstream ss;
165  auto name = hook_name(hook);
166  ss << "hook '" << name << "' has changed the type of value (";
167  ss << "was " << original.toString() << " got ";
168  ss << result.toString() << ")";
169  throw std::runtime_error(ss.str());
170  }
171 
172  if (original.is_cuda() != result.is_cuda()) {
173  std::stringstream ss;
174  auto name = hook_name(hook);
175  ss << "hook '" << name << "' has changed the type of value";
176  if (original.is_cuda()) {
177  ss << " (was CUDA tensor got CPU tensor)";
178  } else {
179  ss << " (was CPU tensor got CUDA tensor)";
180  }
181  throw std::runtime_error(ss.str());
182  }
183 
184  if (original.sizes().vec() != result.sizes().vec()) {
185  std::stringstream ss;
186  auto name = hook_name(hook);
187  ss << "hook '" << name << "' has changed the size of value";
188  throw std::runtime_error(ss.str());
189  }
190 }
191 
192 static std::string hook_name(PyObject* hook) {
193  THPObjectPtr name(PyObject_GetAttrString(hook, "__name__"));
194  if (name && THPUtils_checkString(name.get())) {
195  return THPUtils_unpackString(name.get());
196  }
197  return "<unknown>";
198 }
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17