Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_function.h
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/function.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <torch/csrc/autograd/saved_variable.h>
9 #include <torch/csrc/utils/object_ptr.h>
10 
11 #include <c10/util/Optional.h>
12 #include <c10/core/DeviceGuard.h>
13 
14 #include <vector>
15 #include <utility>
16 #include <memory>
17 
18 namespace torch { namespace jit { struct Graph; }}
19 namespace torch { namespace autograd {
20 
21 struct VariableInfo {
22  explicit VariableInfo(const Variable& var);
23 
24  Variable zeros(at::OptionalDeviceGuard& device_guard) const;
25 
26  at::Type* type;
27  at::Device device = at::kCPU;
28  std::vector<int64_t> size;
29  bool requires_grad;
30 };
31 
32 // A Function which is implemented by a Python object (i.e., a THPFunction).
33 // Calls to 'apply' are forwarded to the Python method implementation.
34 struct PyFunction : public Function {
35  PyFunction(PyObject* obj) : obj(obj) {}
36 
37  variable_list apply(variable_list&& inputs) override;
38  variable_list legacy_apply(const variable_list& inputs);
39 
40  void release_variables() override;
41  std::string name() const override;
42  std::shared_ptr<Function> get_shared_ptr() override;
43  bool is_traceable() override;
44 
45  // THPFunction this Function is wrapping.
46  PyObject* obj;
47 };
48 
53 inline bool ensure_tuple(THPObjectPtr& obj) {
54  if (PyTuple_Check(obj.get()))
55  return false;
56 
57  PyObject *tuple = PyTuple_New(1);
58  if (!tuple) throw python_error();
59  PyTuple_SET_ITEM(tuple, 0, obj.release());
60  obj = tuple;
61  return true;
62 }
63 
64 }} // namespace torch::autograd
65 
// Instance layout of the Python-level autograd Function object. It begins with
// PyObject_HEAD, as CPython requires for object structs (presumably the layout
// of instances of THPFunctionType, declared below — confirm).
struct THPFunction {
  PyObject_HEAD

  // NOTE(review): undocumented here; presumably records, per forward() input,
  // whether a gradient is required — confirm against python_function.cpp.
  PyObject *needs_input_grad;

  // Python tuple of tensors whose variables we should save. Set
  // by Python with 'save_for_backward'. If nullptr, no tensors were
  // saved.
  PyObject *to_save;
  // Python tuple of tensors which are not differentiable. Set by
  // Python with 'mark_non_differentiable'. If nullptr, no tensors were
  // non-differentiable.
  PyObject *non_differentiable;
  // Python tuple of tensors which had inplace updates in the forward()
  // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
  // modified inplace.
  PyObject *dirty_tensors;

  // Metadata snapshots (VariableInfo) for the function's outputs and inputs.
  std::vector<torch::autograd::VariableInfo> output_info;
  std::vector<torch::autograd::VariableInfo> input_info;
  // NOTE(review): presumably the SavedVariables built from `to_save`; confirm.
  std::vector<torch::autograd::SavedVariable> saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> is_variable_input;
  // NOTE(review): boolean flag stored as char; presumably set once saved
  // buffers have been freed — confirm against release_variables().
  char has_freed_buffers;

  // The C++ wrapper for this Python function.
  // See a comment in THPFunction_asFunction for details about this field.
  // NOTE(review): the member this comment describes is absent from this
  // listing (the source line right after it appears to have been dropped);
  // restore it from the original header before relying on this layout.
};
95 
96 bool THPFunction_initModule(PyObject *module);
97 extern PyTypeObject THPFunctionType;
98 extern PyObject *THPFunctionClass;
99 
100 // XXX: this function requires the GIL (it can have side effects).
101 std::shared_ptr<torch::autograd::PyFunction> THPFunction_asFunction(THPFunction* self);
102 
103 inline bool THPFunction_Check(PyObject* obj) {
104  return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
105 }
Represents a compute device on which a tensor is located.
Definition: Device.h:30
An OptionalDeviceGuard is an RAII class that sets a device to some value on initialization and resets the device to its original value on destruction.
Definition: DeviceGuard.h:119
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...