3 #include <torch/csrc/autograd/edge.h> 4 #include <torch/csrc/autograd/grad_mode.h> 5 #include <torch/csrc/autograd/anomaly_mode.h> 6 #include <torch/csrc/autograd/profiler.h> 7 #include <torch/csrc/autograd/saved_variable.h> 8 #include <torch/csrc/autograd/input_metadata.h> 9 #include <torch/csrc/autograd/variable.h> 10 #include <torch/csrc/utils/python_stub.h> 11 #include <torch/csrc/utils/variadic.h> 13 #include <ATen/ATen.h> 14 #include <c10/util/Exception.h> 18 #include <initializer_list> 24 namespace torch {
namespace autograd {
// Forward declarations of the hook types; Function stores them through
// std::unique_ptr, so the full definitions are not needed here.
27 struct FunctionPostHook;
28 struct FunctionPreHook;
// Shorthand aliases for the std::vector instantiations used in this header.
30 using tensor_list = std::vector<at::Tensor>;
31 using variable_list = std::vector<Variable>;
32 using edge_list = std::vector<Edge>;
33 using saved_variable_list = std::vector<SavedVariable>;
// A pair of size_t bounds; should_compute_output() walks it as the half-open
// interval [first, second).
34 using IndexRange = std::pair<size_t, size_t>;
// Declaration only; the definition lives elsewhere.
// NOTE(review): presumably a custom deleter for Function objects -- confirm
// against the definition.
37 void deleteFunction(Function*
function);
// Function derives from std::enable_shared_from_this so that get_shared_ptr()
// can return a shared_ptr to the object itself.
87 struct TORCH_API Function : std::enable_shared_from_this<Function> {
// Primary constructor (the start of its signature is not part of this
// excerpt): stores the supplied sequence number and moves `next_edges` into
// the next_edges_ member. `next_edges` defaults to an empty edge_list.
94 edge_list&& next_edges = edge_list())
95 : sequence_nr_(sequence_nr),
96 next_edges_(
std::move(next_edges)) {
// Only when anomaly mode is enabled: call store_stack() on this function's
// anomaly metadata.
97 if (AnomalyMode::is_enabled()) {
98 metadata()->store_stack();
// Delegating constructor: passes the current value of get_next_sequence_nr()
// as the sequence number and post-increments that counter (it is returned by
// reference), then forwards `next_edges` by move.
102 explicit Function(edge_list&& next_edges = edge_list())
103 : Function(get_next_sequence_nr()++,
std::move(next_edges)) {}
// Function objects are neither copyable nor movable: all four copy/move
// special members are deleted.
106 Function(
const Function& other) =
delete;
107 Function(Function&& other) =
delete;
108 Function& operator=(
const Function& other) =
delete;
109 Function& operator=(Function&& other) =
delete;
// Virtual destructor so derived functions can be destroyed through a base
// pointer.
110 virtual ~Function() =
default;
// Call operator: constructs a profiler::RecordFunction for this object, then
// forwards `inputs` by move to the virtual apply() and returns its result.
// `rec` is a named local, so it stays alive for the duration of the apply()
// call.
114 variable_list operator()(variable_list&& inputs) {
115 profiler::RecordFunction rec(
this);
116 return apply(std::move(inputs));
// Empty tag type that selects the add_input_metadata overload which appends a
// default-constructed entry.
126 struct undefined_input {};
// Overload taking (type, shape, device); the parameter list itself is not part
// of this excerpt. `input_nr` is read from size() BEFORE the append, so it is
// the index of the entry being added (size_t narrowed to uint32_t).
// NOTE(review): the return statement is not visible here -- presumably
// `input_nr` is returned; confirm.
130 uint32_t add_input_metadata(
134 uint32_t input_nr = input_metadata_.size();
135 input_metadata_.emplace_back(type, shape, device);
// Overload that constructs the metadata entry from a tensor.
// NOTE(review): declared noexcept although emplace_back may allocate.
139 uint32_t add_input_metadata(
const at::Tensor& t) noexcept {
140 uint32_t input_nr = input_metadata_.size();
141 input_metadata_.emplace_back(t);
// Overload for the undefined_input tag: appends a default-constructed entry.
// The tag argument `u` is not used in the visible body.
146 uint32_t add_input_metadata(undefined_input u) noexcept {
147 uint32_t input_nr = input_metadata_.size();
148 input_metadata_.emplace_back();
// Number of recorded input-metadata entries; size() is narrowed to uint32_t.
152 uint32_t num_inputs() const noexcept {
153 return input_metadata_.size();
// Read-only access to the entry at `index`. Uses an unchecked subscript, so
// the caller must pass a valid index.
156 const InputMetadata& input_metadata(
size_t index)
const {
157 return input_metadata_[index];
// Removes every recorded input-metadata entry.
160 void clear_input_metadata() {
161 input_metadata_.clear();
// Returns the edge stored at `index` in next_edges_; the subscript is not
// bounds-checked.
166 const Edge& next_edge(
size_t index)
const noexcept {
167 return next_edges_[index];
// Overwrites the edge already stored at `index`. operator[] does not grow the
// vector, so the slot has to exist beforehand.
170 void set_next_edge(
size_t index, Edge edge) {
171 next_edges_[index] = std::move(edge);
// Appends one edge to the end of next_edges_.
174 void add_next_edge(Edge edge) {
175 next_edges_.push_back(std::move(edge));
// Replaces the whole edge list, taking the argument by move.
178 void set_next_edges(edge_list&& next_edges) {
179 next_edges_ = std::move(next_edges);
// Const and mutable accessors for the edge list (their bodies are not part of
// this excerpt).
182 const edge_list& next_edges() const noexcept {
186 edge_list& next_edges() noexcept {
// Number of entries in next_edges_; size() is narrowed to uint32_t.
190 uint32_t num_outputs() const noexcept {
191 return next_edges_.size();
// Accessor for this function's sequence number (body not part of this
// excerpt; the backing member is sequence_nr_).
198 uint64_t sequence_nr() const noexcept {
// Returns a shared_ptr to this object via shared_from_this(). Virtual, so
// subclasses may supply a different implementation.
205 virtual std::shared_ptr<Function> get_shared_ptr() {
206 return shared_from_this();
// Declaration only; the definition lives elsewhere.
210 virtual std::string name()
const;
// Single-index form: AT_CHECK fails with "Index out of range" unless
// output_edge_index < num_outputs(); otherwise reports whether the edge at
// that position is_valid().
214 bool should_compute_output(
size_t output_edge_index)
const {
215 AT_CHECK(output_edge_index < num_outputs(),
"Index out of range");
216 return next_edges_[output_edge_index].is_valid();
// Range form: std::any_of over the supplied IndexRange values. Each range is
// scanned as the half-open interval [first, second) using the single-index
// overload above.
// NOTE(review): the rest of the lambda is not visible in this excerpt --
// presumably it yields true on the first hit and false otherwise; confirm.
220 bool should_compute_output(std::initializer_list<IndexRange> idxs)
const {
221 return std::any_of(idxs.begin(), idxs.end(), [
this](IndexRange range) {
222 for (
auto i = range.first; i < range.second; i++) {
223 if (should_compute_output(i))
// Getter and setter for a PyObject pointer (bodies not part of this excerpt;
// the backing member is pyobj_, which defaults to nullptr).
232 PyObject* pyobj() const noexcept {
237 void set_pyobj(PyObject* pyobj) noexcept {
// Declaration only. Returns a raw pointer to this function's AnomalyMetadata;
// the constructor calls store_stack() through it when anomaly mode is enabled.
243 AnomalyMetadata* metadata() noexcept;
// Takes ownership of `post_hook` and appends it to post_hooks_.
248 void add_post_hook(
std::unique_ptr<FunctionPostHook>&& post_hook) {
249 post_hooks_.push_back(std::move(post_hook));
// Const and mutable accessors for the post-hook list (bodies not part of this
// excerpt).
252 const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
257 std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
// Takes ownership of `pre_hook` and appends it to pre_hooks_.
261 void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
262 pre_hooks_.push_back(std::move(pre_hook));
// Const and mutable accessors for the pre-hook list (bodies not part of this
// excerpt).
265 const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() const
270 std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
// Virtual customisation points whose default implementations are empty.
278 virtual void release_variables() {}
283 virtual void will_release_variables() {}
// Virtual predicates; their default bodies are not part of this excerpt.
// TraceableFunction overrides is_traceable() as final.
288 virtual bool is_traceable() {
301 virtual bool passes_state_transparently() {
// Access to the sequence-number counter. get_next_sequence_nr() returns a
// mutable reference, which the delegating constructor post-increments;
// peek_at_next_sequence_nr() returns by value.
305 static uint64_t peek_at_next_sequence_nr();
308 static uint64_t& get_next_sequence_nr();
// Pure virtual: subclasses implement the actual computation. Invoked by
// operator() with the inputs moved in.
311 virtual variable_list apply(variable_list&& inputs) = 0;
// Declaration only; the definition lives elsewhere.
314 variable_list traced_apply(variable_list inputs);
// Set once in the constructor's initializer list; const thereafter.
318 const uint64_t sequence_nr_;
// Edge list read by next_edge()/num_outputs() and modified by the
// set_next_edge/add_next_edge/set_next_edges mutators.
320 edge_list next_edges_;
// Raw PyObject pointer, null by default.
321 PyObject* pyobj_ =
nullptr;
// Owned anomaly metadata, null by default.
322 std::unique_ptr<AnomalyMetadata> anomaly_metadata_ =
nullptr;
// Owned hook objects, appended by add_pre_hook() / add_post_hook().
323 std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
324 std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
// Function subclass that inherits the base constructors and seals
// is_traceable() with `final`, so further subclasses cannot override it.
// NOTE(review): the override's body is not part of this excerpt -- presumably
// it returns true; confirm.
329 struct TraceableFunction :
public Function {
330 using Function::Function;
331 bool is_traceable() final {
// IterArgs visitor that accumulates one Edge per visited Variable into
// `next_edges`. The using-declaration keeps the base-class operator()
// overloads visible alongside the one defined here.
342 struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
343 edge_list next_edges;
344 using IterArgs<MakeNextFunctionList>::operator();
// A defined variable contributes its gradient_edge().
// NOTE(review): the `else` line is missing from this excerpt -- presumably the
// emplace_back() of a default-constructed Edge is the branch taken for an
// undefined variable; confirm.
345 void operator()(
const Variable& variable) {
346 if (variable.defined()) {
347 next_edges.push_back(variable.gradient_edge());
349 next_edges.emplace_back();
// Connects `variable` to `function` (the `variable` parameter's declaration is
// not part of this excerpt). Order matters: the variable's metadata is first
// registered on the function to obtain `input_nr`, and only then is
// `function` moved into the new gradient edge.
366 inline void create_gradient_edge(
368 std::shared_ptr<Function>
function) {
370 const auto input_nr =
function->add_input_metadata(variable);
371 variable.set_gradient_edge({std::move(
function), input_nr});
// The predicate is true for a variable that is both defined and requires grad;
// defined() is tested first, so requires_grad() is never called on an
// undefined variable.
// NOTE(review): the algorithm call that consumes the iterator pair and the
// predicate is not part of this excerpt -- judging by the function name it is
// std::any_of; confirm.
375 inline bool any_variable_requires_grad(
const variable_list& variables) {
377 variables.begin(), variables.end(), [](
const Variable& variable) {
378 return variable.defined() && variable.requires_grad();
// Builds an edge_list from an arbitrary pack of arguments by running
// detail::MakeNextFunctionList over them with perfect forwarding.
// NOTE(review): the statement guarded by the GradMode check is not part of
// this excerpt -- presumably an early return of an empty list when grad mode
// is disabled; confirm.
383 template <
typename... Variables>
384 edge_list collect_next_edges(Variables&&... variables) {
385 if (!GradMode::is_enabled())
387 detail::MakeNextFunctionList make;
388 make.apply(std::forward<Variables>(variables)...);
// std::move is needed here: `make.next_edges` is a member of a local, not the
// local itself, so it would otherwise be copied on return.
389 return std::move(make.next_edges);
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Represents a compute device on which a tensor is located.