Caffe2 - C++ API
A deep-learning, cross-platform ML framework
utils.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/csrc/autograd/function.h>
5 #include <torch/csrc/autograd/variable.h>
6 #include <torch/csrc/utils/variadic.h>
7 
8 #include <ATen/ATen.h>
9 
10 #include <functional>
11 #include <memory>
12 #include <vector>
13 
14 namespace torch { namespace autograd {
15 
16 using function_constructor = std::function<std::shared_ptr<Function>(edge_list&&)>;
17 
22 TORCH_API variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
23  const function_constructor& ctr);
24 
27 TORCH_API void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args=-1);
28 
29 struct ComputeRequiresGrad : IterArgs<ComputeRequiresGrad> {
30  bool out = false;
32  void operator()(const at::Tensor& tensor) {
33  const auto& var = static_cast<const Variable&>(tensor);
34  if (var.defined() && var.requires_grad()) {
35  out = true;
36  }
37  }
38  bool short_circuit() {
39  return out;
40  }
41 };
42 
43 template <typename... Args>
44 inline bool compute_requires_grad(Args&&... args) {
45  if (!GradMode::is_enabled()) {
46  return false;
47  }
48  return ComputeRequiresGrad().apply(std::forward<Args>(args)...).out;
49 }
50 
51 inline void set_history(
52  at::Tensor& variable,
53  const std::shared_ptr<Function>& grad_fn) {
54  if (grad_fn) {
55  if (variable.defined()) {
56  auto output_nr =
57  grad_fn->add_input_metadata(variable);
58  as_variable_ref(variable).set_gradient_edge({grad_fn, output_nr});
59  } else {
60  grad_fn->add_input_metadata(Function::undefined_input());
61  }
62  }
63 }
64 
65 inline void set_history(
66  std::vector<Variable>&& variables,
67  const std::shared_ptr<Function>& grad_fn) {
68  for (auto& variable : variables) {
69  set_history(variable, grad_fn);
70  }
71 }
72 
73 inline void set_history(
74  std::vector<Variable>& variables,
75  const std::shared_ptr<Function>& grad_fn) {
76  for (auto& variable : variables) {
77  set_history(variable, grad_fn);
78  }
79 }
80 }}
void set_gradient_edge(Edge edge) noexcept
Set the gradient edge – i.e. the (grad_fn, input index) pair that links this Variable to the Function computing its gradient.
Definition: variable.h:682
Variable: a Variable augments a Tensor with the ability to interact with our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17