3 #include "test/cpp/jit/test_base.h" 4 #include "test/cpp/jit/test_utils.h" 5 #include "torch/csrc/jit/argument_spec.h" 6 #include "torch/csrc/jit/autodiff.h" 7 #include "torch/csrc/jit/passes/common_subexpression_elimination.h" 8 #include "torch/csrc/jit/passes/constant_propagation.h" 9 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" 10 #include "torch/csrc/jit/passes/dead_code_elimination.h" 11 #include "torch/csrc/jit/passes/graph_fuser.h" 12 #include "torch/csrc/jit/passes/lower_grad_of.h" 13 #include "torch/csrc/jit/passes/requires_grad_analysis.h" 14 #include "torch/csrc/jit/passes/shape_analysis.h" 15 #include "torch/csrc/jit/passes/utils/subgraph_utils.h" 16 #include "torch/csrc/jit/tracer.h" 18 #include <ATen/ATen.h> 19 #include "torch/csrc/autograd/engine.h" 20 #include "torch/csrc/autograd/generated/variable_factories.h" 21 #include "torch/csrc/autograd/variable.h" 29 using var_meta_type = std::vector<int64_t>;
30 using var_meta_list = std::vector<var_meta_type>;
31 using test_fn_type = std::function<variable_list(const variable_list&)>;
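
// ADTestSpec describes one differentiation test case: a name, the shapes of
// the inputs to generate, and a forward function to run on them.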
struct ADTestSpec {
  ADTestSpec(const char* name, var_meta_list input_meta, test_fn_type test_fn)
      : name(name), input_meta(input_meta), test_fn(test_fn) {}
  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  }
  std::vector<Variable> make_vars() const {
    std::vector<Variable> out;
    for (const auto& m : input_meta) {
      // Fresh leaf variables that require grad, one per input shape.
      out.push_back(torch::randn(m, at::requires_grad(true)));
    }
    return out;
  }
  const char* name;
  var_meta_list input_meta;
  test_fn_type test_fn;
};
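
// Random cotangents (grad_outputs) matching the shape and options of each
// forward output; these seed the backward pass in both implementations.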
variable_list get_grad_outputs(const variable_list& vars) {
  return fmap(vars, [](const Variable& v) -> Variable {
    return at::randn(v.sizes(), v.options());
  });
}
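
// Traces the test's forward function into a JIT Graph by running it under the
// tracer and recording the operations it performs.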
std::shared_ptr<Graph> trace(
    const ADTestSpec& test,
    const variable_list& vars_in) {
  std::shared_ptr<tracer::TracingState> state;
  Stack trace_stack_in;
  std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
  variable_list trace_vars_in = fmap(
      trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
  auto trace_vars_out = test(trace_vars_in);
  tracer::exit(fmap<IValue>(trace_vars_out));
  return state->graph;
}
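
// Reference gradients: drive the eager autograd engine directly over the
// gradient edges of the given outputs and inputs.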
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs) {
  const auto get_edge = [](const Variable& v) { return v.gradient_edge(); };
  auto& engine = torch::autograd::Engine::get_default_engine();
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      fmap(inputs, get_edge));
}
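
// Checks each AD formula by comparing the JIT's differentiated graph against
// gradients computed by eager-mode autograd on the same inputs.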
void testADFormulas() {
  const auto unwrap = [](const Variable& v) { return v.data(); };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; }},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL { return {v[0].view({3, 2})}; }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
  };
  for (const auto& test : ad_tests) {
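    // Get reference values from eager-mode autograd.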
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);
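
    // Trace the op, then differentiate the traced graph.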
    auto graph = trace(test, vars_in);
    EliminateDeadCode(graph);
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);
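
    // Get outputs from the JIT interpreter.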
    auto tensors_in = fmap(vars_in, unwrap);
    auto tensor_grads_in = fmap(var_grads_in, unwrap);
    tensor_list tensors_out, tensor_grads_out;
    std::tie(tensors_out, tensor_grads_out) =
        runGradient(grad_spec, tensors_in, tensor_grads_in);
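
    // Compare interpreter results against the autograd reference.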
    auto expected_tensors_out = fmap(vars_out, unwrap);
    auto expected_tensor_grads_out = fmap(var_grads_out, unwrap);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}
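
// Differentiates c = a * b * a + b and checks how the Gradient spec splits
// work between the forward graph f and the backward graph df. Elementwise,
// dc/da = 2 * a * b and dc/db = a * a + 1, so df needs the original inputs
// (captured inputs {0, 1}) plus the intermediate products (captured outputs).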
void testDifferentiate() {
  auto graph = std::make_shared<Graph>();
  at::ScalarType s = at::ScalarType::Float;
  auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
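
  // Build up a fake graph.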
  auto a = SymbolicVariable::asNewInput(*graph, type);
  auto b = SymbolicVariable::asNewInput(*graph, type);
  auto c = a * b * a + b;
  graph->registerOutput(c.value());
  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
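
  // Structural checks: f should contain the two multiplies plus a size query
  // saved for the backward pass, while df should consist of GradOf blocks with
  // AutogradAdd nodes accumulating the two contributions to each gradient.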
  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}
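
// Same exercise, but with requires_grad propagated from real inputs: b does
// not require grad, so d = b * b + b is constant as far as df is concerned and
// only the path through a (de/da = d + 2 * a, elementwise) is differentiated.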
void testDifferentiateWithRequiresGrad() {
  auto graph = std::make_shared<Graph>();
  auto a = SymbolicVariable::asNewInput(*graph);
  auto b = SymbolicVariable::asNewInput(*graph);

  // Build up a fake graph.
  auto d = b * b + b;
  auto e = (d + a) * a + b;
  graph->registerOutput(d.value());
  graph->registerOutput(e.value());
  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()),
      /*requires_grad=*/true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()),
      /*requires_grad=*/false);
  setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
  PropagateInputShapes(graph);
  PropagateRequiresGrad(graph);
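
  // Only e's vjp and the intermediate (d + a)'s vjp feed df, and only a's
  // gradient is produced, since b does not require grad.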
  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_input_vjps = {1, 2};
  std::vector<size_t> expected_output_vjps = {0};
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly=*/true)
      ->run(*grad_spec.df);
}

} // namespace test
} // namespace jit
} // namespace torch