1 #include <torch/csrc/autograd/python_variable.h> 3 #include <torch/csrc/THP.h> 4 #include <torch/csrc/DynamicTypes.h> 5 #include <torch/csrc/Exceptions.h> 6 #include <torch/csrc/Device.h> 7 #include <torch/csrc/Size.h> 8 #include <torch/csrc/Types.h> 9 #include <torch/csrc/autograd/edge.h> 10 #include <torch/csrc/autograd/python_cpp_function.h> 11 #include <torch/csrc/autograd/python_hook.h> 12 #include <torch/csrc/autograd/python_variable_indexing.h> 13 #include <torch/csrc/autograd/variable.h> 14 #include <torch/csrc/autograd/functions/accumulate_grad.h> 15 #include <torch/csrc/autograd/function.h> 16 #include <torch/csrc/autograd/generated/VariableType.h> 17 #include <torch/csrc/autograd/utils/python_error_messages.h> 18 #include <torch/csrc/autograd/utils/wrap_outputs.h> 19 #include <torch/csrc/tensor/python_tensor.h> 20 #include <torch/csrc/utils/auto_gil.h> 21 #include <torch/csrc/utils/cuda_lazy_init.h> 22 #include <torch/csrc/utils/pybind.h> 23 #include <torch/csrc/utils/python_strings.h> 24 #include <torch/csrc/utils/python_arg_parser.h> 25 #include <torch/csrc/utils/tensor_new.h> 26 #include <torch/csrc/jit/tracer.h> 28 #include <ATen/ATen.h> 29 #include <pybind11/pybind11.h> 31 #include <structmember.h> 37 using namespace torch;
42 PyObject *THPVariableClass =
nullptr;
// Warning text emitted when Python code assigns to the removed `volatile`
// attribute (used by THPVariable_set_volatile).
// The pointer itself is declared const so this global cannot be re-pointed at
// runtime; the message text is unchanged.
static const char* const VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";
// Creates a new Python object of `type` that takes ownership of `var`.
//
// Storage comes from type->tp_alloc, the Variable is placement-constructed
// into v->cdata, and the Variable is linked back to its Python object via
// set_pyobj. THPVariable_dealloc runs ~Variable() explicitly to match the
// placement new.
// NOTE(review): the declaration of `v` (presumably `obj` cast to THPVariable*),
// any null check on the tp_alloc result, and the final return are on lines not
// visible here — confirm.
50 static PyObject* THPVariable_NewWithVar(PyTypeObject* type,
Variable var)
52 PyObject* obj = type->tp_alloc(type, 0);
// tp_alloc hands back raw object memory, so the C++ member must be
// constructed in place rather than assigned.
55 new (&v->cdata)
Variable(std::move(var));
56 v->cdata.set_pyobj(obj);
// When the grad_fn is a PyFunction, re-point the gradient edge at the function
// returned by THPFunction_asFunction for its Python object, keeping the same
// output_nr.
57 if (
auto fn = dynamic_cast<PyFunction*>(v->cdata.grad_fn_unsafe())) {
60 const auto output_nr = v->cdata.output_nr();
61 auto grad_fn = THPFunction_asFunction((
THPFunction*)fn->obj);
62 v->cdata.set_gradient_edge({std::move(grad_fn), output_nr});
// Returns a Python object wrapping `var`.
//
// If `var` already has an associated Python object (var.pyobj()), that branch
// is taken; its body is not visible here and presumably returns the existing
// object with a new reference — confirm. Otherwise a new object of
// THPVariableClass is created via THPVariable_NewWithVar.
// NOTE(review): lines 69-73 are not visible; they may handle an undefined
// `var` before this point.
68 PyObject * THPVariable_Wrap(
Variable var)
74 if (
auto obj = var.pyobj()) {
79 return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var));
// GC traversal slot (installed as tp_traverse in THPVariableType).
//
// Reports the Python objects this tensor keeps alive to the cyclic GC:
// the backward_hooks object, and, for defined tensors, the `dict` held by
// every PyFunctionPreHook registered on the Variable. Hooks of other types are
// skipped by the dynamic_cast.
82 static int THPVariable_traverse(
THPVariable *
self, visitproc visit,
void *arg)
84 Py_VISIT(self->backward_hooks);
// Only a defined Variable has a hooks() list to walk.
99 if (self->cdata.defined()) {
100 for (
const auto& hook : self->cdata.hooks()) {
101 if (
auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
102 Py_VISIT(pyhook->dict);
// Body fragment of THPVariable_clear (installed as tp_clear in THPVariableType
// and called from THPVariable_dealloc). Its signature line is not visible here.
//
// Drops the reference to backward_hooks and clears the pre-hooks of the
// gradient accumulator, if one exists. It then detaches the Variable from this
// Python object with set_pyobj(nullptr).
// NOTE(review): whether set_pyobj(nullptr) sits inside the defined() branch
// cannot be told from the visible lines — confirm.
111 Py_CLEAR(self->backward_hooks);
112 if (self->cdata.defined()) {
113 if (
auto grad_acc = self->cdata.try_get_grad_accumulator()) {
114 grad_acc->pre_hooks().clear();
116 self->cdata.set_pyobj(
nullptr);
// Body fragment of THPVariable_dealloc (installed as tp_dealloc in
// THPVariableType). Its signature line is not visible here.
//
// Order matters:
//   1. untrack from the GC so the collector never sees a half-destroyed object;
//   2. release Python references via THPVariable_clear;
//   3. run ~Variable() explicitly to match the placement new in
//      THPVariable_NewWithVar;
//   4. release the memory through the type's tp_free.
124 PyObject_GC_UnTrack(
self);
125 THPVariable_clear(
self);
126 self->cdata.~Variable();
127 Py_TYPE(
self)->tp_free((PyObject*)
self);
// Python constructor for the tensor type (`torch.Tensor(...)`-style calls).
//
// Emits a JIT tracer warning for "torch.Tensor" with WARN_CONSTRUCTOR. It then
// builds a tensor with torch::utils::legacy_tensor_ctor using the current
// default tensor type and wraps it in a new object of `type`.
130 static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
133 jit::tracer::warn(
"torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
134 auto& default_type = torch::tensors::get_default_tensor_type();
135 auto tensor = torch::utils::legacy_tensor_ctor(default_type, args, kwargs);
136 return THPVariable_NewWithVar(type, std::move(tensor));
// Static method `_make_subclass(cls, data, require_grad=False)`, registered in
// extra_methods.
//
// Takes `data`'s underlying tensor (as_variable_ref(...).data()) and wraps it in
// a new Variable created with make_variable(data, require_grad). The result is
// an instance of `cls`. Raises TypeError if `cls` is not a type object.
// NOTE(review): the keyword is spelled `require_grad` here, unlike the
// `requires_grad` attribute. Renaming it would break keyword callers, so it is
// left as-is.
// NOTE(review): the parser and parsed_args declarations (lines 142-146) are not
// visible.
141 static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, PyObject* kwargs) {
144 "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False)",
147 auto r = parser.parse(args, kwargs, parsed_args);
148 PyObject* cls = r.pyobject(0);
149 if (!PyType_Check(cls)) {
150 throw TypeError(
"cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
152 auto& data = as_variable_ref(r.tensor(1)).data();
153 auto var = make_variable(data, r.toBool(2));
154 return THPVariable_NewWithVar((PyTypeObject*)cls, std::move(var));
158 typedef PyObject *(*getter)(PyObject *,
void *);
159 typedef int (*setter)(PyObject *, PyObject *,
void *);
// Body fragment of THPVariable_get_cdata (the `_cdata` property). Its signature
// line is not visible here.
// Returns the address of the underlying TensorImpl as a Python int.
164 auto& var =
self->cdata;
165 return PyLong_FromVoidPtr(var.data().unsafeGetTensorImpl());
// Getter for the `_version` property: the Variable's current_version() as a
// Python int.
169 PyObject *THPVariable_get_version(
THPVariable *
self)
172 auto& var =
self->cdata;
173 return PyInt_FromLong(var.current_version());
// Getter for the `grad_fn` / `_grad_fn` properties.
// When the Variable has no grad_fn, a hidden branch (lines 182-183) runs;
// presumably it returns None — confirm. Otherwise the grad_fn is converted
// with functionToPyObject.
177 PyObject *THPVariable_get_grad_fn(
THPVariable *
self)
180 auto& var =
self->cdata;
181 if (!var.grad_fn()) {
184 return functionToPyObject(var.grad_fn());
// Setter for `_grad_fn`.
// Deleting the attribute (obj == nullptr) is rejected. The only accepted value
// is None, which detaches the Variable in place via detach_().
// Returns -1 on error (END_HANDLE_TH_ERRORS_RET(-1)).
188 static int THPVariable_set_grad_fn(
THPVariable *
self, PyObject *obj)
191 THPUtils_assertRet(-1, obj,
"Deletion of _grad_fn not allowed. Detach tensor instead!");
192 THPUtils_assertRet(-1, obj == Py_None,
"_grad_fn can be only set to None");
193 self->cdata.detach_();
195 END_HANDLE_TH_ERRORS_RET(-1)
// Getter for `is_leaf`: True exactly when the Variable has no grad_fn.
198 static PyObject *THPVariable_is_leaf(
THPVariable *
self)
201 return PyBool_FromLong(!self->cdata.grad_fn());
// Getter for `.data`: wraps the Variable's underlying tensor (cdata.data()) in a
// brand-new Variable and returns it. Autograd history is not attached here.
// NOTE(review): the two `false` arguments are presumably requires_grad and a
// second flag of make_variable. Check make_variable's declaration; lines
// 206-215 are not visible.
205 static PyObject * THPVariable_get_data(
THPVariable *
self)
216 auto var = make_variable(self->cdata.data(),
false,
false);
217 return THPVariable_Wrap(var);
// Setter for `.data`.
// Rejects deletion, raises TypeError if the value is not a tensor
// (THPVariable_Check), and otherwise swaps in the new tensor's data via
// set_data(THPVariable_UnpackData(data)). Returns -1 on error.
221 int THPVariable_set_data(
THPVariable *
self, PyObject *data)
224 THPUtils_assertRet(-1, data,
"Deleting tensor data is not allowed. Delete tensor instead!");
225 if (!THPVariable_Check(data)) {
226 throw torch::TypeError(
"Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
229 self->cdata.set_data(THPVariable_UnpackData(data));
231 END_HANDLE_TH_ERRORS_RET(-1)
// Body fragment of THPVariable_get_grad (the `grad` / `_grad` properties); its
// signature is not visible. Wraps the Variable's grad() for Python.
237 return THPVariable_Wrap(self->cdata.grad());
// Setter for `grad` / `_grad`.
//
// Deletion or None takes the hidden branch at lines 246-249, presumably
// resetting the grad — confirm. Any other value must satisfy all of:
//   * it is a tensor (THPVariable_Check);
//   * it is not this tensor itself;
//   * its type equals var.type(), OR it is the sparse counterpart of var's
//     type (same scalar type, SparseCUDA/SparseCPU backend);
//   * it is on the same device as var;
//   * it has the same sizes as var.
// Returns -1 on a failed check.
// NOTE(review): `grad` is declared on a hidden line (presumably the cdata of
// py_grad). The use of typeOpt (line 259, presumably a guard) and the final
// assignment are also not visible.
241 int THPVariable_set_grad(
THPVariable *
self, PyObject *py_grad)
244 auto& var =
self->cdata;
245 if (!py_grad || py_grad == Py_None) {
250 THPUtils_assertRet(-1, THPVariable_Check(py_grad),
251 "expected Variable or None (got %s)", THPUtils_typename(py_grad));
252 THPUtils_assertRet(-1,
self != (
THPVariable*)py_grad,
253 "can't assign Variable as its own grad");
// A sparse gradient is allowed for a dense variable: compare grad's
// non-variable type with the sparse type matching var's device and dtype.
256 bool gradIsSparse =
false;
257 auto backend = var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU;
258 auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.scalar_type());
260 auto& sparseType = at::globalContext().getNonVariableType(backend, var.scalar_type());
261 auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.scalar_type());
262 gradIsSparse = gradType == sparseType;
265 THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
266 "assigned grad has data of a different type");
268 THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
269 "assigned grad has data located on a different device");
271 THPUtils_assertRet(-1, grad.sizes().equals(var.sizes()),
272 "assigned grad has data of a different size");
276 END_HANDLE_TH_ERRORS_RET(-1)
// Getter for the removed `volatile` attribute: emits a UserWarning. Per the
// message, the value is always False; the return line itself is not visible.
// NOTE(review): the result of PyErr_WarnEx is discarded. If warnings are
// configured as errors it returns -1 with an exception set. Returning a value
// afterwards would leave that exception pending ("returned a result with an
// error set"). Consider checking the result and propagating the error.
279 PyObject *THPVariable_get_volatile(
THPVariable *
self)
281 const char* msg =
"volatile was removed (Variable.volatile is always False)";
282 PyErr_WarnEx(PyExc_UserWarning, msg, 1);
// Setter for the removed `volatile` attribute: the value is ignored and a
// UserWarning (VOLATILE_WARNING) is emitted. The PyErr_WarnEx result (0, or -1
// if the warning was raised as an exception) is returned directly, which
// matches the setter convention of -1 on error.
286 int THPVariable_set_volatile(
THPVariable *
self, PyObject *obj)
288 return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
// Getter for `output_nr`: the Variable's output_nr(), cast to long and returned
// as a Python int.
291 PyObject *THPVariable_get_output_nr(
THPVariable *
self)
294 const auto output_nr =
static_cast<long>(
self->cdata.output_nr());
295 return PyInt_FromLong(output_nr);
// Getter for `requires_grad`: the Variable's requires_grad() as a Python bool.
299 PyObject *THPVariable_get_requires_grad(
THPVariable *
self)
302 return PyBool_FromLong(self->cdata.requires_grad());
// Setter for `requires_grad`.
// The value must be a Python bool. It is rejected, with a Python error set,
// when the Variable is not a leaf, and when requesting gradients on a
// non-floating-point tensor. Otherwise set_requires_grad is applied.
// NOTE(review): the local `requires_grad` (line 311) and the error returns
// after each THPUtils_setError (lines 314-315, 318-319) are not visible;
// presumably they are `obj == Py_True` and `return -1` — confirm.
306 int THPVariable_set_requires_grad(
THPVariable *
self, PyObject *obj)
309 THPUtils_assertRet(-1, obj && PyBool_Check(obj),
"requires_grad must be a bool");
310 auto& var =
self->cdata;
312 if (!var.is_leaf()) {
313 THPUtils_setError(autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
316 if (requires_grad && !var.is_floating_point()) {
317 THPUtils_setError(
"only Tensors of floating point dtype can require gradients");
320 var.set_requires_grad(requires_grad);
322 END_HANDLE_TH_ERRORS_RET(-1)
// Body fragment of THPVariable_get_name (the `name` property); its signature is
// not visible.
// An empty name takes a hidden branch (line 328, presumably returning None).
// Otherwise the name is returned as a Python string.
// NOTE(review): `name().empty()` would state the intent more directly than
// comparing with "".
327 if (self->cdata.name() ==
"")
329 return THPUtils_packString(self->cdata.name().c_str());
// Getter for `_backward_hooks`: returns a new reference to the stored hooks
// object when one is set. The fallback path is not visible; presumably it
// returns None.
332 PyObject *THPVariable_get_backwards_hooks(
THPVariable *
self)
335 if (self->backward_hooks) {
336 Py_INCREF(self->backward_hooks);
337 return self->backward_hooks;
// Setter for `_backward_hooks`.
// Deletion is rejected. The previous hooks object is released and replaced by
// `obj`. The Variable's C++ hooks are then cleared and a PyFunctionPreHook
// wrapping `obj` (second argument 0) is registered.
// NOTE(review): the None handling (lines 348-350) is not visible, nor is any
// incref of `obj` or a condition on add_hook (line 354). Confirm that a
// reference to `obj` is taken before it is stored in self->backward_hooks.
343 int THPVariable_set_backwards_hooks(
THPVariable *
self, PyObject *obj)
346 THPUtils_assertRet(-1, obj,
"Deletion of _backwards_hooks not allowed!");
347 if (obj == Py_None) {
351 Py_XDECREF(self->backward_hooks);
352 self->backward_hooks = obj;
353 self->cdata.clear_hooks();
355 self->cdata.add_hook(std::make_shared<PyFunctionPreHook>(obj, 0));
358 END_HANDLE_TH_ERRORS_RET(-1)
// Body fragments of THPVariable_get_base (`_base`) and THPVariable_get_shape
// (`shape`); their signatures are not visible.
// _base: for a view, wraps and returns the base Variable. The non-view path is
// not visible.
364 if (self->cdata.is_view()) {
365 return THPVariable_Wrap(self->cdata.base());
// shape: builds a torch.Size from the Variable.
374 return THPSize_New(self->cdata);
// Body fragments of THPVariable_is_cuda (`is_cuda`) and THPVariable_is_sparse
// (`is_sparse`); their signatures are not visible. Each wraps the
// corresponding boolean query for Python.
381 auto& self_ =
self->cdata;
382 return torch::autograd::utils::wrap(self_.is_cuda());
389 auto& self_ =
self->cdata;
390 return torch::autograd::utils::wrap(self_.is_sparse());
// Getter for `dtype`: maps the Variable's scalar_type() to the torch dtype
// object via torch::getDtype.
394 static PyObject *THPVariable_dtype(
THPVariable *
self)
397 auto& self_ =
self->cdata;
398 return torch::autograd::utils::wrap(torch::getDtype(self_.scalar_type()));
// Getter for `layout`: derives the layout from the tensor type's backend via
// torch::getLayout.
402 static PyObject * THPVariable_layout(
THPVariable*
self) {
404 auto& self_ =
self->cdata;
405 return torch::autograd::utils::wrap(torch::getLayout(self_.type().backend()));
// Getter for `device`: wraps the Variable's device() in a torch.device object.
409 static PyObject * THPVariable_device(
THPVariable*
self) {
411 return THPDevice_New(self->cdata.device());
// Attribute table for the tensor type (installed in THPVariableType).
// Each entry is a PyGetSetDef: {name, getter, setter, doc, closure}. A nullptr
// setter makes the attribute read-only.
// Notes:
//   * `_grad_fn` is the writable twin of the read-only `grad_fn`.
//   * `_grad` and `grad` share the same getter and setter.
//   * `volatile` is kept only so old code gets a warning instead of an
//     AttributeError.
// NOTE(review): the terminating sentinel entry (line 436) is not visible here.
415 static struct PyGetSetDef THPVariable_properties[] = {
416 {
"_cdata", (getter)THPVariable_get_cdata,
nullptr,
nullptr,
nullptr},
417 {
"_version", (getter)THPVariable_get_version,
nullptr,
nullptr,
nullptr},
418 {
"grad_fn", (getter)THPVariable_get_grad_fn,
nullptr,
nullptr,
nullptr},
419 {
"_grad_fn", (getter)THPVariable_get_grad_fn, (setter)THPVariable_set_grad_fn,
nullptr,
nullptr},
420 {
"is_leaf", (getter)THPVariable_is_leaf,
nullptr,
nullptr,
nullptr},
421 {
"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data,
nullptr,
nullptr},
422 {
"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad,
nullptr,
nullptr},
423 {
"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad,
nullptr,
nullptr},
424 {
"_base", (getter)THPVariable_get_base,
nullptr,
nullptr,
nullptr},
425 {
"volatile", (getter)THPVariable_get_volatile, (setter)THPVariable_set_volatile,
nullptr,
nullptr},
426 {
"output_nr", (getter)THPVariable_get_output_nr,
nullptr,
nullptr,
nullptr},
427 {
"requires_grad", (getter)THPVariable_get_requires_grad, (setter)THPVariable_set_requires_grad,
nullptr,
nullptr},
428 {
"_backward_hooks", (getter)THPVariable_get_backwards_hooks, (setter)THPVariable_set_backwards_hooks,
nullptr,
nullptr},
429 {
"name", (getter)THPVariable_get_name,
nullptr,
nullptr,
nullptr},
430 {
"shape", (getter)THPVariable_get_shape,
nullptr,
nullptr,
nullptr},
431 {
"is_cuda", (getter)THPVariable_is_cuda,
nullptr,
nullptr,
nullptr},
432 {
"is_sparse", (getter)THPVariable_is_sparse,
nullptr,
nullptr,
nullptr},
433 {
"dtype", (getter)THPVariable_dtype,
nullptr,
nullptr,
nullptr},
434 {
"layout", (getter)THPVariable_layout,
nullptr,
nullptr,
nullptr},
435 {
"device", (getter)THPVariable_device,
nullptr,
nullptr,
nullptr},
// Mapping-protocol slots for the tensor type (installed in THPVariableType).
// The entries (lines 440-443) are not visible. Presumably they are the
// indexing functions from python_variable_indexing.h — confirm.
439 static PyMappingMethods THPVariable_as_mapping = {
// Methods defined in this file that are appended to the generated
// variable_methods table in THPVariable_initModule.
// The PyCFunction cast is needed because THPVariable_make_subclass also takes
// kwargs (METH_KEYWORDS). METH_STATIC makes it callable on the class.
// NOTE(review): the terminating sentinel entry is not visible here.
445 static PyMethodDef extra_methods[] = {
446 {
"_make_subclass", (PyCFunction)THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
// Type object for torch._C._TensorBase. Only some slots are visible here.
// Py_TPFLAGS_BASETYPE lets Python code subclass it. Py_TPFLAGS_HAVE_GC requires
// the traverse/clear pair supplied below. tp_methods is filled in at module
// init (THPVariable_initModule) rather than statically.
// NOTE(review): the other slots (e.g. tp_new) sit on lines not visible here.
450 PyTypeObject THPVariableType = {
451 PyVarObject_HEAD_INIT(
nullptr, 0)
452 "torch._C._TensorBase",
455 (destructor)THPVariable_dealloc,
463 &THPVariable_as_mapping,
470 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
472 (traverseproc)THPVariable_traverse,
473 (inquiry)THPVariable_clear,
480 THPVariable_properties,
// Declarations of the generated method table and torch-function initializer,
// plus pybind11 bindings for moving raw TensorImpl handles in and out of
// Python.
491 namespace
torch {
namespace autograd {
493 extern PyMethodDef variable_methods[];
494 extern void initTorchFunctions(PyObject *module);
// Registers `_wrap_tensor_impl(ptr)`: takes a raw TensorImpl pointer and checks
// that it is defined and not already a Variable. It then wraps the impl in an
// at::Tensor and returns a Variable built with make_variable(tensor, false).
// NOTE(review): `p` is declared on hidden lines (499-500), presumably by
// reclaiming `ptr` into an intrusive_ptr — confirm.
496 void initTensorImplConversion(PyObject* module) {
497 auto m = py::handle(module).cast<py::module>();
498 m.def(
"_wrap_tensor_impl", [](
void* ptr) {
501 AT_CHECK(p.defined(),
"Can't wrap undefined tensor");
502 AT_CHECK(!p->is_variable(),
"Can wrap only non-variable tensor");
503 auto tensor = at::Tensor::wrap_tensor_impl(std::move(p));
505 torch::autograd::make_variable(std::move(tensor),
false)));
// NOTE(review): fragment of a second binding; its header and remainder are
// not visible. It takes the intrusive pointer to the tensor's data impl.
509 auto p = t->data().getIntrusivePtr();
// Finishes and registers the tensor base type on `module`.
// Steps:
//   1. merge the generated variable_methods with extra_methods into tp_methods;
//   2. ready the type;
//   3. expose it as `_TensorBase`;
//   4. run the torch-function and TensorImpl-conversion initializers.
// `methods` is static because tp_methods keeps a raw pointer (methods.data())
// into it, so the vector must outlive the type object.
// NOTE(review): the PyModule_AddObject result is ignored. On failure it does
// not steal the reference taken by Py_INCREF, and the error is not
// propagated. Consider checking it.
517 bool THPVariable_initModule(PyObject *module)
519 static std::vector<PyMethodDef> methods;
520 THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
521 THPUtils_addPyMethodDefs(methods, extra_methods);
522 THPVariableType.tp_methods = methods.data();
523 if (PyType_Ready(&THPVariableType) < 0)
525 Py_INCREF(&THPVariableType);
526 PyModule_AddObject(module,
"_TensorBase", (PyObject *)&THPVariableType);
527 torch::autograd::initTorchFunctions(module);
528 torch::autograd::initTensorImplConversion(module);
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with `requires_grad` set to the given value...
Flush-To-Zero and Denormals-Are-Zero mode.
static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget *raw_ptr)
Turn a non-owning raw pointer into an intrusive_ptr.