// PyTorch (Caffe2) C++ API — a deep learning, cross-platform ML framework.
// File: torch/csrc/autograd/python_legacy_variable.cpp
// Implements the legacy torch._C._LegacyVariableBase Python type.
1 #include <torch/csrc/autograd/python_legacy_variable.h>
2 
3 #include <ATen/ATen.h>
4 
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/python_function.h>
7 #include <torch/csrc/autograd/python_variable.h>
8 #include <torch/csrc/tensor/python_tensor.h>
9 #include <torch/csrc/jit/tracer.h>
10 
11 using namespace at;
12 
13 namespace torch { namespace autograd {
14 
15 static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject *kwds) {
16  HANDLE_TH_ERRORS
17  THPObjectPtr _data;
18  PyObject *data = nullptr;
19  PyObject *grad_fn = nullptr;
20  char is_volatile = 0;
21  char requires_grad = 0;
22  const char* name = nullptr;
23 
24  const char *accepted_args[] = {"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
25  if (!PyArg_ParseTupleAndKeywords(args, kwds, "|ObbOz", (char**)accepted_args,
26  &data, &requires_grad, &is_volatile, &grad_fn, &name))
27  return nullptr;
28 
29  if (grad_fn == Py_None)
30  grad_fn = nullptr;
31 
32  if (is_volatile) {
33  PyErr_WarnEx(PyExc_UserWarning,
34  "volatile was removed and now has no effect. Use `with torch.no_grad():` "
35  "instead.", 1);
36  }
37 
38  if (is_volatile && requires_grad) {
39  throw ValueError("Variable can't be volatile and require_grad at the same time!");
40  }
41  if (grad_fn && !THPFunction_Check(grad_fn)) {
42  throw TypeError("_grad_fn has to be a Function object or None, but got %s",
43  Py_TYPE(grad_fn)->tp_name);
44  }
45  Tensor tensor;
46  if (!data || data == Py_None) {
47  // For legacy serialization code, create an empty tensor. This is also used
48  // by nn.Parameter() with no arguments.
49  auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options());
50  tensor = static_cast<Variable&>(var).data();
51  } else if (THPVariable_Check(data)) {
52  tensor = ((THPVariable*)data)->cdata.data();
53  } else {
54  throw torch::TypeError("Variable data has to be a tensor, but got %s",
55  Py_TYPE(data)->tp_name);
56  }
57 
58  Variable var;
59  if (grad_fn) {
60  auto grad_fn_ = THPFunction_asFunction((THPFunction*)grad_fn);
61  Edge edge(grad_fn_, grad_fn_->add_input_metadata(tensor));
62  var = make_variable(std::move(tensor), std::move(edge));
63  } else {
64  var = make_variable(std::move(tensor), requires_grad);
65  }
66 
67  if (name) {
68  var.set_name(name);
69  }
70 
71  if (jit::tracer::isTracing() && data && data != Py_None && THPVariable_Check(data)) {
72  if (auto *v = jit::tracer::getValueTrace(((THPVariable*)data)->cdata)) {
73  jit::tracer::setValueTrace(var, v);
74  }
75  }
76 
77  return THPVariable_Wrap(std::move(var));
78  END_HANDLE_TH_ERRORS
79 }
80 
// Python type object for torch._C._LegacyVariableBase. All behavior slots are
// left null (inherited); only tp_new is provided, so subclasses (the legacy
// Python Variable) construct through THPVariable_pynew. Py_TPFLAGS_BASETYPE
// allows Python-level subclassing. Slot order is positional and must match
// the PyTypeObject layout.
PyTypeObject THPLegacyVariableType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "torch._C._LegacyVariableBase",        /* tp_name */
  0,                                     /* tp_basicsize */
  0,                                     /* tp_itemsize */
  nullptr,                               /* tp_dealloc */
  nullptr,                               /* tp_print */
  nullptr,                               /* tp_getattr */
  nullptr,                               /* tp_setattr */
  nullptr,                               /* tp_reserved */
  nullptr,                               /* tp_repr */
  nullptr,                               /* tp_as_number */
  nullptr,                               /* tp_as_sequence */
  nullptr,                               /* tp_as_mapping */
  nullptr,                               /* tp_hash */
  nullptr,                               /* tp_call */
  nullptr,                               /* tp_str */
  nullptr,                               /* tp_getattro */
  nullptr,                               /* tp_setattro */
  nullptr,                               /* tp_as_buffer */
  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
  nullptr,                               /* tp_doc */
  nullptr,                               /* tp_traverse */
  nullptr,                               /* tp_clear */
  nullptr,                               /* tp_richcompare */
  0,                                     /* tp_weaklistoffset */
  nullptr,                               /* tp_iter */
  nullptr,                               /* tp_iternext */
  nullptr,                               /* tp_methods */
  nullptr,                               /* tp_members */
  nullptr,                               /* tp_getset */
  nullptr,                               /* tp_base */
  nullptr,                               /* tp_dict */
  nullptr,                               /* tp_descr_get */
  nullptr,                               /* tp_descr_set */
  0,                                     /* tp_dictoffset */
  nullptr,                               /* tp_init */
  nullptr,                               /* tp_alloc */
  THPVariable_pynew                      /* tp_new */
};
121 
122 void init_legacy_variable(PyObject *module) {
123  if (PyType_Ready(&THPLegacyVariableType) < 0) {
124  throw python_error();
125  }
126  auto obj = (PyObject*)&THPLegacyVariableType;
127  Py_INCREF(obj);
128  if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
129  throw python_error();
130  }
131 }
132 
133 }} // namespace torch::autograd
// (Scraped doxygen cross-references, repaired:)
// - Definition: jit_type.h:17
// - at::TensorOptions requires_grad(bool requires_grad = true): convenience
//   function that returns a TensorOptions object with requires_grad set to
//   the given value.
// - Flush-To-Zero and Denormals-Are-Zero mode.