Caffe2 - C++ API
A deep-learning, cross-platform ML framework
function_hook.h
#pragma once

#include <vector>

// Abstract hook interfaces that are called on gradients.
//
// Variable is only forward-declared here: these declarations name
// `variable_list` in parameter and return types, which does not require the
// full Variable definition, so this header avoids including it.

namespace torch { namespace autograd {

struct Variable;
using variable_list = std::vector<Variable>;

/// Hook invoked with a single list of gradients.
///
/// Implementations receive `grads` and return a variable_list.
/// NOTE(review): by its name this hook presumably runs before the owning
/// autograd function is applied, and its result presumably replaces `grads`
/// -- confirm against the caller.
struct FunctionPreHook {
  /// Virtual so concrete hooks can be destroyed through a base-class pointer.
  virtual ~FunctionPreHook() = default;

  /// @param grads  gradient list handed to the hook.
  /// @return       the variable list produced by the hook.
  virtual variable_list operator()(const variable_list& grads) = 0;
};

/// Hook invoked with both the input and output gradient lists of a function.
///
/// NOTE(review): by its name this hook presumably runs after the owning
/// autograd function is applied -- confirm against the caller.
struct FunctionPostHook {
  /// Virtual so concrete hooks can be destroyed through a base-class pointer.
  virtual ~FunctionPostHook() = default;

  /// @param grad_input   the `grad_input` gradient list.
  /// @param grad_output  the `grad_output` gradient list.
  /// @return             the variable list produced by the hook.
  virtual variable_list operator()(const variable_list& grad_input,
                                   const variable_list& grad_output) = 0;
};

}} // namespace torch::autograd
Definition: jit_type.h:17
Definition: function_hook.h:17