#include <torch/csrc/jit/autodiff.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/symbolic_script.h>
#include <torch/csrc/jit/symbolic_variable.h>

#include <c10/util/Exception.h>

namespace torch {
namespace jit {

using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;

// Wrap a possibly-negative dimension index into the valid [0, sizes.size())
// range, following the usual PyTorch negative-indexing convention.
void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
  if (dim < 0) {
    dim += sizes.size();
  }
}
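
// need_trim_grad_ops lists ops that return multiple outputs in forward but
// where only the first output participates in autodiff, so the extra incoming
// gradients are dropped before building the backward.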
bool needTrimGrad(Node* n) {
  static OperatorSet need_trim_grad_ops = {
      "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
      "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
  };
  if (need_trim_grad_ops.find(n)) {
    return true;
  }
  return false;
}
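
// Returns true if autodiff knows how to differentiate this node, either via a
// symbolic gradient below, via a TorchScript gradient registered in
// symbolic_script, or because the node is one of the autograd primitives.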
bool isDifferentiable(Node* n) {
  static OperatorSet differentiable_ops = {
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::t(Tensor self) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
      "aten::type_as(Tensor self, Tensor other) -> Tensor",
      "aten::unsqueeze(Tensor self, int dim) -> Tensor",
      "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
      "aten::mm(Tensor self, Tensor mat2) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Scalar other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::remainder(Tensor self, Scalar other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
      "aten::log_softmax(Tensor self, int dim) -> Tensor",
      "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
      "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
      "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
  };
  if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
    return true;
  if (differentiable_ops.find(n))
    return true;

  if (n->matches(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
          attr::train)) {
    return n->get<bool>(attr::train).value();
  }

  auto schema = n->maybeSchema();
  if (schema && hasGradientInfoForSchema(*schema)) {
    return true;
  }

  if (n->matches(
          "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
    return n->get<std::vector<int64_t>>(attr::size) &&
        n->is_constant(attr::implicit) &&
        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
  }
  if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) {
    return n->get<std::vector<int64_t>>(attr::size) &&
        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
  }
  if (n->matches(
          "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
    return n->namedInput(attr::weight)->node()->mustBeNone();
  }

  // Linear blocks (prim::GradOf) may appear as inputs to graph executors;
  // they are differentiable iff every node in their body is.
  if (n->kind() == prim::GradOf) {
    auto body = n->blocks().at(0);
    return std::all_of(
        body->nodes().begin(),
        body->nodes().end(),
        static_cast<bool (*)(Node*)>(isDifferentiable));
  }

  return false;
}

bool isDifferentiable(Graph& g) {
  return std::all_of(
      g.nodes().begin(),
      g.nodes().end(),
      static_cast<bool (*)(Node*)>(isDifferentiable));
}
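
// Looks up a TorchScript-defined gradient for this node's schema (see
// symbolic_script.h). If one exists, the compiled forward graph is inlined in
// place of the node and the compiled backward graph is inlined into the
// reverse block; otherwise c10::nullopt is returned so the caller can fall
// back to the symbolic gradients below.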
static c10::optional<std::vector<Value*>> build_script_grad(
    Node* node,
    const ArrayRef<Value*>& grads) {
  auto graph = node->owningGraph();

  auto compiled_graphs = gradientInfoForSchema(node->schema());
  if (!compiled_graphs) {
    return c10::nullopt;
  }

  // Use the compiled forward graph to replace the node in grad_desc.f.
  value_list new_outputs;
  {
    WithInsertPoint guard(node->next());
    auto fw_graph = compiled_graphs->forward;
    new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), true);
    auto outputs = node->outputs();
    AT_ASSERT(new_outputs.size() == outputs.size() + 1);
    for (size_t i = 0; i < outputs.size(); ++i) {
      new_outputs.at(i)->setType(outputs[i]->type());
      outputs[i]->replaceAllUsesWith(new_outputs.at(i));
    }
  }

  // Use the compiled backward graph to construct the reverse block.
  auto bw_graph = compiled_graphs->backward;
  auto grad_vec = grads.vec();
  if (needTrimGrad(node)) {
    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
  }
  auto it = grad_vec.begin();
  grad_vec.insert(it, new_outputs.back());
  ArrayRef<Value*> grad(grad_vec);
  auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, true);
  return grad_inputs;
}
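
// GradientHelper computes the gradient of a single node: it first tries the
// TorchScript gradients from build_script_grad and otherwise falls back to the
// hand-written symbolic gradients in buildSymbolicGradient below.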
class GradientHelper {
 public:
  GradientHelper(Node* n) : node(n) {}

  std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
    if (!isDifferentiable(node)) {
      throw std::runtime_error(
          std::string("differentiation of ") + node->kind().toDisplayString() +
          " is not supported, or it is missing necessary type information");
    }
    // If a TorchScript AD formula is registered for this op, prefer it.
    auto script_grads = build_script_grad(node, grad_values);
    if (script_grads)
      return *script_grads;
    // Otherwise fall back to the symbolic gradient formulas below.
    auto sym_grads = buildSymbolicGradient(fmap<SymbolicVariable>(grad_values));
    return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); });
  }

 private:
  Node* node;
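
  // Sums a gradient back down to the size of the named forward input, which is
  // needed when the forward op broadcast that input.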
  SymbolicVariable gradSumToSizeOf(SymbolicVariable v, Symbol input_name) {
    Value* size;
    {
      WithInsertPoint insert_guard{node};
      size = SymbolicVariable(node->namedInput(input_name)).size();
    }
    return v.gradSumToSize(size);
  }
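
  // Hand-written symbolic gradients, one branch per supported schema. Each
  // branch returns one gradient entry per forward input (nullptr for inputs
  // that do not receive a gradient).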
  std::vector<SymbolicVariable> buildSymbolicGradient(
      const std::vector<SymbolicVariable>& grads) {
    static const OperatorSet comparison_ops = {
        "aten::lt(Tensor self, Tensor other) -> Tensor",
        "aten::le(Tensor self, Tensor other) -> Tensor",
        "aten::gt(Tensor self, Tensor other) -> Tensor",
        "aten::ge(Tensor self, Tensor other) -> Tensor",
        "aten::eq(Tensor self, Tensor other) -> Tensor",
        "aten::ne(Tensor self, Tensor other) -> Tensor",
        "aten::lt(Tensor self, Scalar other) -> Tensor",
        "aten::le(Tensor self, Scalar other) -> Tensor",
        "aten::gt(Tensor self, Scalar other) -> Tensor",
        "aten::ge(Tensor self, Scalar other) -> Tensor",
        "aten::eq(Tensor self, Scalar other) -> Tensor",
        "aten::ne(Tensor self, Scalar other) -> Tensor",
    };
    auto inputs = fmap<SymbolicVariable>(node->inputs());
    auto outputs = fmap<SymbolicVariable>(node->outputs());
294 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
295 return {gradSumToSizeOf(grads.at(0), attr::self),
297 grads.at(0) * node->namedInput(attr::alpha), attr::other),
302 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
303 return {grads.at(0),
nullptr,
nullptr};
305 }
else if (node->kind() == prim::AutogradAdd) {
307 return {grads.at(0), grads.at(0)};
311 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
312 return {gradSumToSizeOf(grads.at(0), attr::self),
314 -grads.at(0) * node->namedInput(attr::alpha), attr::other),
319 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
320 return {grads.at(0),
nullptr,
nullptr};
322 }
else if (node->matches(
323 "aten::mul(Tensor self, Tensor other) -> Tensor")) {
324 return {gradSumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
325 gradSumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
327 }
else if (node->matches(
328 "aten::mul(Tensor self, Scalar other) -> Tensor")) {
329 return {grads.at(0) * inputs.at(1),
nullptr};
331 }
else if (node->matches(
332 "aten::div(Tensor self, Tensor other) -> Tensor")) {
333 return {gradSumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
335 -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)),
338 }
else if (node->matches(
339 "aten::div(Tensor self, Scalar other) -> Tensor")) {
340 return {grads.at(0) / inputs.at(1),
nullptr};
342 }
else if (node->matches(
343 "aten::max(Tensor self, Tensor other) -> Tensor")) {
346 grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
349 grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)),
352 }
else if (node->matches(
353 "aten::min(Tensor self, Tensor other) -> Tensor")) {
356 grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
359 grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)),
364 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
367 grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
369 grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)),
372 }
else if (node->matches(
"aten::sigmoid(Tensor self) -> Tensor")) {
376 return {(1 - outputs.at(0)) * outputs.at(0) * grads.at(0)};
378 }
else if (node->matches(
"aten::tanh(Tensor self) -> Tensor")) {
379 return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};
381 }
else if (node->matches(
"aten::relu(Tensor self) -> Tensor")) {
382 return {grads.at(0) *
383 (outputs.at(0) >
at::Scalar(0)).type_as(outputs.at(0))};
387 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
389 Value* min = inputs.at(1);
390 bool min_must_be_none = min->mustBeNone();
391 Value* max = inputs.at(2);
392 bool max_must_be_none = max->mustBeNone();
397 if (!min_must_be_none && !max_must_be_none) {
398 return {grads.at(0) *
399 (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) *
400 (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
403 }
else if (max_must_be_none) {
404 return {grads.at(0) *
405 (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))),
408 }
else if (min_must_be_none) {
409 return {grads.at(0) *
410 (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
414 return {grads.at(0),
nullptr,
nullptr};
418 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
419 auto threshold = node->get<
at::Scalar>(attr::threshold).value();
420 return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)),
424 }
else if (node->matches(
"aten::erf(Tensor self) -> Tensor")) {
425 return {grads.at(0) * 1.12837916709551 *
426 (-inputs.at(0) * inputs.at(0)).exp()};
428 }
else if (node->matches(
"aten::erfc(Tensor self) -> Tensor")) {
429 return {-grads.at(0) * 1.12837916709551 *
430 (-inputs.at(0) * inputs.at(0)).exp()};
432 }
else if (node->matches(
"aten::exp(Tensor self) -> Tensor")) {
433 return {grads.at(0) * (outputs.at(0))};
435 }
else if (node->matches(
"aten::t(Tensor self) -> Tensor")) {
436 return {grads.at(0).t()};
438 }
else if (node->matches(
"aten::neg(Tensor self) -> Tensor")) {
439 return {-grads.at(0)};
441 }
else if (node->matches(
"aten::abs(Tensor self) -> Tensor")) {
442 return {grads.at(0) * inputs.at(0).sign()};
444 }
else if (node->matches(
"aten::acos(Tensor self) -> Tensor")) {
445 return {grads.at(0) *
446 -((-inputs.at(0) * inputs.at(0) +
at::Scalar(1)).rsqrt())};
448 }
else if (node->matches(
"aten::asin(Tensor self) -> Tensor")) {
449 return {grads.at(0) *
450 (-inputs.at(0) * inputs.at(0) +
at::Scalar(1)).rsqrt()};
452 }
else if (node->matches(
"aten::atan(Tensor self) -> Tensor")) {
453 return {grads.at(0) / (inputs.at(0) * inputs.at(0) +
at::Scalar(1))};
457 "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
460 WithInsertPoint insert_guard{node};
461 self_size = inputs.at(0).size();
463 return {grads.at(0).expand(self_size),
nullptr};
465 }
else if (node->matches(
"aten::ceil(Tensor self) -> Tensor")) {
466 return {SymbolicVariable::zeros_like(grads.at(0))};
468 }
else if (node->matches(
"aten::cos(Tensor self) -> Tensor")) {
469 return {grads.at(0) * -inputs.at(0).sin()};
471 }
else if (node->matches(
"aten::cosh(Tensor self) -> Tensor")) {
472 return {grads.at(0) * inputs.at(0).sinh()};
474 }
else if (node->matches(
"aten::exp(Tensor self) -> Tensor")) {
475 return {grads.at(0) * outputs.at(0)};
477 }
else if (node->matches(
"aten::expm1(Tensor self) -> Tensor")) {
478 return {grads.at(0) * (outputs.at(0) +
at::Scalar(1))};
480 }
else if (node->matches(
"aten::floor(Tensor self) -> Tensor")) {
481 return {SymbolicVariable::zeros_like(grads.at(0))};
483 }
else if (node->matches(
484 "aten::fmod(Tensor self, Scalar other) -> Tensor")) {
485 return {grads.at(0),
nullptr};
487 }
else if (node->matches(
"aten::frac(Tensor self) -> Tensor")) {
488 return {grads.at(0)};
490 }
else if (node->matches(
"aten::log(Tensor self) -> Tensor")) {
491 return {grads.at(0) / inputs.at(0)};
493 }
else if (node->matches(
"aten::log10(Tensor self) -> Tensor")) {
494 return {grads.at(0) / (inputs.at(0) * 2.3025850929940456)};
496 }
else if (node->matches(
"aten::log1p(Tensor self) -> Tensor")) {
497 return {grads.at(0) / (inputs.at(0) +
at::Scalar(1))};
499 }
else if (node->matches(
"aten::log2(Tensor self) -> Tensor")) {
500 return {grads.at(0) / (inputs.at(0) * 0.6931471805599453)};
502 }
else if (node->matches(
"aten::reciprocal(Tensor self) -> Tensor")) {
503 return {-grads.at(0) * outputs.at(0) * outputs.at(0)};
505 }
else if (node->matches(
506 "aten::remainder(Tensor self, Scalar other) -> Tensor")) {
507 return {grads.at(0),
nullptr};
509 }
else if (node->matches(
"aten::round(Tensor self) -> Tensor")) {
510 return {SymbolicVariable::zeros_like(grads.at(0))};
512 }
else if (node->matches(
"aten::rsqrt(Tensor self) -> Tensor")) {
513 return {grads.at(0) * outputs.at(0).pow(3.) * -0.5};
515 }
else if (node->matches(
"aten::sin(Tensor self) -> Tensor")) {
516 return {grads.at(0) * inputs.at(0).cos()};
518 }
else if (node->matches(
"aten::sinh(Tensor self) -> Tensor")) {
519 return {grads.at(0) * inputs.at(0).cosh()};
521 }
else if (node->matches(
"aten::tan(Tensor self) -> Tensor")) {
522 return {grads.at(0) * (1. + outputs.at(0) * outputs.at(0))};
524 }
else if (node->matches(
"aten::trunc(Tensor self) -> Tensor")) {
525 return {SymbolicVariable::zeros_like(grads.at(0))};
527 }
else if (node->kind() == prim::ConstantChunk) {
528 return {SymbolicVariable::cat(grads, node->i(attr::dim))};

    } else if (
        node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
        node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
      auto sizes = node->namedInput(attr::self)
                       ->type()
                       ->expect<CompleteTensorType>()
                       ->sizes();
      return {grads.at(0).reshape(sizes), nullptr};

    } else if (node->matches(
                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      return {grads.at(0).type_as(inputs.at(0)), nullptr};

    } else if (node->matches("aten::rand_like(Tensor self) -> Tensor")) {
      return {nullptr};

    } else if (node->matches(
                   "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
      return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};

    } else if (
        node->matches(
            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
      return {gradSumToSizeOf(
                  grads.at(0) * node->namedInput(attr::beta), attr::self),
              grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
              inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
              nullptr,
              nullptr};

    } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return {grads.at(0).mm(inputs.at(1).t()),
              inputs.at(0).t().mm(grads.at(0))};

    } else if (
        node->matches(
            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
      const auto& input_sizes = inputs.at(0).sizes();
      if (input_sizes.size() == 0)
        return {grads.at(0).sum(), nullptr, nullptr};
      auto grad_sizes = node->get<std::vector<int64_t>>(attr::size).value();
      auto grad = grads.at(0);
      while (grad_sizes.size() > input_sizes.size()) {
        grad = grad.sum(0, false);
        grad_sizes.erase(grad_sizes.begin());
      }
      for (size_t i = 0; i < input_sizes.size(); ++i) {
        if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
          grad = grad.sum(i, true);
        }
      }
      return {grad, nullptr, nullptr};

    } else if (node->matches("aten::squeeze(Tensor self) -> Tensor")) {
      const auto& sizes = inputs.at(0).sizes();
      std::vector<size_t> squeezed_dims;
      for (size_t i = 0; i < sizes.size(); ++i) {
        if (sizes[i] != 1)
          continue;
        squeezed_dims.push_back(i);
      }
      SymbolicVariable returned_grad = grads.at(0);
      for (const auto& dim : squeezed_dims) {
        returned_grad = returned_grad.unsqueeze(dim);
      }
      return {returned_grad};

    } else if (node->matches(
                   "aten::squeeze(Tensor self, int dim) -> Tensor",
                   attr::dim)) {
      int64_t dim = *node->get<int64_t>(attr::dim);
      const auto& sizes = inputs.at(0).sizes();
      wrapDim(dim, sizes);
      if (sizes.size() == 0) {
        return {grads.at(0), nullptr};
      }
      return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
              nullptr};

    } else if (comparison_ops.find(node)) {
      return {nullptr, nullptr};
616 "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
617 AT_ASSERT(grads.size() == 1);
618 auto graph = node->owningGraph();
619 auto backward_value = graph->insert(
620 aten::avg_pool2d_backward,
621 {grads.at(0).value(),
622 node->namedInput(attr::self),
623 node->namedInput(attr::kernel_size),
624 node->namedInput(attr::stride),
625 node->namedInput(attr::padding),
626 node->namedInput(attr::ceil_mode),
627 node->namedInput(attr::count_include_pad)});
628 return {backward_value->node()->output(0),
637 "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
638 AT_ASSERT(grads.size() == 2);
639 auto graph = node->owningGraph();
640 auto backward_value = graph->insert(
641 aten::max_pool2d_with_indices_backward,
642 {grads.at(0).value(),
643 node->namedInput(attr::self),
644 node->namedInput(attr::kernel_size),
645 node->namedInput(attr::stride),
646 node->namedInput(attr::padding),
647 node->namedInput(attr::dilation),
648 node->namedInput(attr::ceil_mode),
649 outputs.at(1).value()});
650 return {backward_value->node()->output(0),
659 "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
660 auto graph = node->owningGraph();
661 auto backward_value = graph->insert(
662 aten::thnn_conv2d_backward,
663 {grads.at(0).value(),
664 inputs.at(0).value(),
665 inputs.at(1).value(),
666 node->namedInput(attr::kernel_size),
667 node->namedInput(attr::stride),
668 node->namedInput(attr::padding),
669 outputs.at(1).value(),
670 outputs.at(2).value(),
671 graph->insertConstant(std::vector<bool>{
true,
true,
true})});
674 Node* tuple_unpack_node =
675 graph->insertNode(graph->createTupleUnpack(backward_value));
676 auto tuple_outputs = tuple_unpack_node->outputs();
677 AT_ASSERT(tuple_outputs.size() == size_t(3));
678 return {tuple_outputs[0],
687 "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
688 auto graph = node->owningGraph();
689 auto backward_value = graph->insert(
690 aten::native_batch_norm_backward,
691 {grads.at(0).value(),
692 inputs.at(0).value(),
693 inputs.at(1).value(),
694 inputs.at(3).value(),
695 inputs.at(4).value(),
696 outputs.at(1).value(),
697 outputs.at(2).value(),
698 inputs.at(5).value(),
699 inputs.at(7).value(),
700 graph->insertConstant(std::vector<bool>{
true,
true,
true})});
703 Node* tuple_unpack_node =
704 graph->insertNode(graph->createTupleUnpack(backward_value));
705 auto tuple_outputs = tuple_unpack_node->outputs();
706 AT_ASSERT(tuple_outputs.size() == size_t(3));
707 return {tuple_outputs[0],
718 "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
719 auto graph = node->owningGraph();
720 auto total_weight = graph->insertNode(graph->createAutogradZero());
721 auto weight = graph->insertNode(graph->createNone(TensorType::get()));
722 auto backward_value = graph->insert(
723 aten::nll_loss_backward,
724 {grads.at(0).value(),
725 inputs.at(0).value(),
726 inputs.at(1).value(),
728 inputs.at(3).value(),
729 inputs.at(4).value(),
730 total_weight->output()});
731 return {backward_value->node()->output(0),
737 }
else if (node->matches(
738 "aten::log_softmax(Tensor self, int dim) -> Tensor")) {
739 AT_ASSERT(grads.size() == 1);
740 auto graph = node->owningGraph();
741 auto backward_value = graph->insert(
742 aten::_log_softmax_backward_data,
743 {grads.at(0).value(),
744 outputs.at(0).value(),
745 node->namedInput(attr::dim),
746 node->namedInput(attr::self)});
747 return {backward_value->node()->output(0),
nullptr};
750 node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
753 throw std::runtime_error(
754 std::string(
"failed to differentiate `") +
755 node->kind().toDisplayString() +
"`");
static std::vector<Value*> linearGradientForNode(
    Node* node,
    ArrayRef<Value*> grad_values) {
  auto& graph = *node->owningGraph();

  // FIXME: if forward has multiple outputs, only the first one may require
  // grad (see needTrimGrad above).
  if (needTrimGrad(node)) {
    grad_values = grad_values.at(0);
  }
  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
  // To make reading gradient graphs easier, remember the name of the forward op.
  linear->s_(attr::name, node->kind().toDisplayString());
  auto block = linear->addBlock();
  WithInsertPoint guard(block);
  auto results = GradientHelper(node).gradient(grad_values);
  return fmap(results, [block, linear](Value* grad) -> Value* {
    if (!grad)
      return nullptr;
    block->registerOutput(grad);
    return linear->addOutput()->copyMetadata(grad);
  });
}
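
// Bundles the results of addReverseInline: a map from each primal value to its
// accumulated gradient, and the block holding the reverse computation.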
struct ReverseDetails {
  ReverseDetails(value_map&& grad_map, Block* reverse_block)
      : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}

  value_map grad_map;
  Block* reverse_block;
};
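
// AutogradAdd is addition that tolerates undefined (AutogradZero) operands:
// AutogradAdd(a, b) == a + b when both are defined, and collapses to the
// defined operand (or stays undefined) otherwise.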
static Value* createAutogradAdd(Value* a, Value* b) {
  auto graph = a->owningGraph();
  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
}
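
// Builds the reverse pass: creates a Reverse node (kept off to the side rather
// than inserted into f) whose block takes vjps for every output of f that
// requires grad and produces vjps for every input of f that requires grad,
// filling in grad_desc.df_input_vjps and df_output_vjps along the way.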
824 auto& graph = *grad_desc.f;
828 auto reverse_node = graph.create(prim::Reverse, 0);
829 auto reverse_block = reverse_node->addBlock();
833 const auto get_grad = [&](
Value* v) ->
Value* {
834 auto it = grad_map.find(v);
835 if (it == grad_map.end()) {
836 auto undef = graph.insertNode(graph.createAutogradZero());
837 std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
841 const auto set_grad = [&](
Value* x,
Value* dx) {
842 if (
Value* prev_grad = grad_map[x]) {
843 grad_map[x] = createAutogradAdd(prev_grad, dx);
849 auto outputs = graph.outputs();
850 for (
size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
851 Value* output = outputs[i];
852 if (!output->requires_grad())
854 Value* output_grad = reverse_block->addInput()->setType(output->type());
855 set_grad(output, output_grad);
856 grad_desc.df_input_vjps.push_back(i);
859 for (
auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
862 auto inputs = node->inputs();
863 auto outputs = node->outputs();
864 if (std::all_of(outputs.begin(), outputs.end(), [](
Value* v) {
865 return !v->requires_grad();
870 value_list grad_inputs =
871 linearGradientForNode(node, fmap(node->outputs(), get_grad));
872 LowerSimpleTuples(reverse_block);
874 AT_ASSERT(grad_inputs.size() == node->inputs().size());
875 for (
size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
883 set_grad(inputs[i], grad_inputs[i]);
887 auto inputs = graph.inputs();
888 for (
size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
889 Value* input = inputs[i];
890 if (!input->requires_grad())
897 if (grad_map.count(input) == 0)
899 reverse_block->registerOutput(get_grad(input));
900 grad_desc.df_output_vjps.push_back(i);
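
// Returns a topologically sorted list of values that are produced in f (or are
// inputs of f) and used inside the reverse block; these are the values the
// backward graph will need to capture.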
static value_list getReverseCaptures(Gradient& grad_desc) {
  auto& graph = *grad_desc.f;
  auto primal_block = graph.block();

  value_set reverse_captures_set;
  value_list reverse_captures; // Invariant: topologically sorted
  auto check_uses = [&](Value* v) {
    for (auto use : v->uses()) {
      if (use.user->owningBlock() == primal_block)
        continue;
      if (reverse_captures_set.emplace(v).second) {
        reverse_captures.push_back(v);
      }
    }
  };
  for (Value* input : graph.inputs()) {
    check_uses(input);
  }
  for (Node* node : graph.nodes()) {
    for (Value* output : node->outputs())
      check_uses(output);
  }
  return reverse_captures;
}
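
// Constants used inside GradOf blocks would otherwise have to be captured from
// the primal graph; it is cheaper to clone them directly into the reverse
// block instead.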
static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
  static const auto err = [](Value*) -> Value* {
    throw std::runtime_error("unexpected input");
  };
  auto& graph = *grad_desc.f;
  Block* reverse_block = rev_info.reverse_block;

  for (Node* top_node : reverse_block->nodes()) {
    AT_ASSERT(
        top_node->kind() == prim::GradOf ||
        top_node->kind() == prim::AutogradAdd ||
        top_node->kind() == prim::AutogradZero);
    if (top_node->kind() != prim::GradOf)
      continue;
    Block* grad_body = top_node->blocks().at(0);
    for (Node* node : grad_body->nodes()) {
      for (Value* input : node->inputs()) {
        if (input->node()->kind() != prim::Constant)
          continue;
        if (input->node()->owningBlock() == grad_body)
          continue;
        Node* lifted_constant = graph.createClone(input->node(), err);
        reverse_block->prependNode(lifted_constant);
        node->replaceInputWith(input, lifted_constant->output());
      }
    }
  }
}
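
// Derivatives of broadcasting ops call aten::size on their inputs, and the
// same size is often captured several times. When a size capture is only used
// in the reverse block and its tensor is captured anyway, recompute the size
// inside the reverse block instead of capturing it separately.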
static void deduplicateSizeCaptures(
    Gradient& grad_desc,
    ReverseDetails& rev_info) {
  Block* primal_block = grad_desc.f->block();
  const auto usedOnlyInReverse = [primal_block](Value* v) {
    const auto& uses = v->uses();
    return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
      return u.user->owningBlock() != primal_block;
    });
  };
  auto captures = getReverseCaptures(grad_desc);
  value_set capture_set(captures.begin(), captures.end());
  for (Value* capture : captures) {
    Node* node = capture->node();
    if (!node->matches("aten::size(Tensor self) -> int[]")) {
      continue;
    }
    if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
      WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
      capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
      node->destroy();
    }
  }
}
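
// Runs dead code elimination on the reverse block and drops grad_map entries
// whose gradients were removed, so later passes never touch freed values.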
static void eliminateDeadCode(ReverseDetails& rev_info) {
  std::function<void(const std::unordered_set<const Value*>&)> cb =
      [&](const std::unordered_set<const Value*>& live_values) {
        std::vector<Value*> to_erase;
        for (auto& entry : rev_info.grad_map) {
          if (!live_values.count(entry.second)) {
            to_erase.push_back(entry.first);
          }
        }
        for (Value* v : to_erase) {
          rev_info.grad_map.erase(v);
        }
      };
  EliminateDeadCode(rev_info.reverse_block, std::move(cb));
}
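
// Post-processing of the freshly built reverse block: lift constants into it,
// deduplicate common subexpressions and size captures in the primal graph, and
// remove gradients that turned out to be unused.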
static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
  liftConstants(grad_desc, rev_info);
  // Derivatives of broadcasting ops emit many aten::size calls; deduplicate
  // them before deciding what needs to be captured.
  EliminateCommonSubexpression(grad_desc.f);
  deduplicateSizeCaptures(grad_desc, rev_info);
  eliminateDeadCode(rev_info);
}
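
// Splits the reverse block off into its own graph grad_desc.df. Every
// intermediate of f that the reverse computation needs is appended to the
// outputs of f and becomes a captured input of df; vjps for those temporaries
// are appended to df_input_vjps.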
1043 auto& graph = *grad_desc.f;
1044 auto reverse_block = rev_info.reverse_block;
1054 value_list reverse_captures = getReverseCaptures(grad_desc);
1071 grad_desc.f_real_outputs = graph.outputs().size();
1073 std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
1074 std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
1077 for (
size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
1078 orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
1079 for (
size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
1080 orig_primal_inputs_idx[graph.inputs()[i]] = i;
1083 for (
Value* capture_val : reverse_captures) {
1086 if (orig_primal_outputs_idx.count(capture_val) > 0) {
1087 grad_desc.df_input_captured_outputs.push_back(
1088 orig_primal_outputs_idx[capture_val]);
1091 }
else if (orig_primal_inputs_idx.count(capture_val) > 0) {
1092 grad_desc.df_input_captured_inputs.push_back(
1093 orig_primal_inputs_idx.at(capture_val));
1099 graph.registerOutput(capture_val);
1100 grad_desc.df_input_captured_outputs.emplace_back(
1101 graph.outputs().size() - 1);
1109 for (
size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
1110 Value* tmp = graph.outputs().at(i);
1118 if (rev_info.grad_map.count(tmp) == 0)
1120 Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
1121 Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
1126 Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
1127 new_vjp->node()->moveAfter(tmp_vjp_prev->node());
1128 tmp_vjp_prev->replaceAllUsesWith(new_vjp);
1129 new_vjp->node()->replaceInput(1, tmp_vjp_prev);
1130 grad_desc.df_input_vjps.emplace_back(i);
1137 std::unordered_map<Value*, size_t> capture_to_formal_index;
1138 const auto& add_capture = [&](
Value* captured) {
1139 capture_to_formal_index[captured] = reverse_block->inputs().size();
1140 reverse_block->addInput()->copyMetadata(captured);
1142 for (
auto& offset : grad_desc.df_input_captured_inputs)
1143 add_capture(graph.inputs()[offset]);
1144 for (
auto& offset : grad_desc.df_input_captured_outputs)
1145 add_capture(graph.outputs()[offset]);
1147 grad_desc.df = std::make_shared<Graph>();
1148 grad_desc.df->block()->cloneFrom(reverse_block, [&](
Value* v) {
1149 return grad_desc.df->inputs()[capture_to_formal_index.at(v)];
1153 reverse_block->owningNode()->destroy();
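
// Entry point: takes sole ownership of `graph`, builds and optimizes the
// reverse pass, and returns a Gradient holding the split forward graph f and
// backward graph df.
//
// A minimal usage sketch (assuming `graph` is a differentiable Graph whose
// inputs/outputs have requires_grad set; names below are illustrative):
//
//   std::shared_ptr<Graph> graph = ...;   // built elsewhere, sole owner
//   Gradient grad = differentiate(graph); // graph is moved from
//   // grad.f runs the forward pass (plus captured temporaries);
//   // grad.df consumes vjps + captures and produces input vjps.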
Gradient differentiate(std::shared_ptr<Graph>& graph) {
  Gradient grad_desc;
  // Take ownership of the graph.
  AT_CHECK(
      graph.use_count() == 1,
      "differentiate will mutate and destroy the graph, so it requires "
      "graph.use_count() == 1, but found %d",
      graph.use_count());
  std::swap(graph, grad_desc.f);

  // Fills in df_input_vjps and df_output_vjps.
  auto rev_info = addReverseInline(grad_desc);
  Optimize(grad_desc, rev_info);
  // Clean up nodes that were replaced by TorchScript forward graphs.
  EliminateDeadCode(grad_desc.f->block());

  // Fills in df, f_real_outputs and df_input_captures, and extends
  // df_input_vjps with vjps for the captured temporaries.
  lambdaLiftReverse(grad_desc, rev_info);
  // The same constants may have been cloned many times; pool them.
  ConstantPooling(grad_desc.df);
  return grad_desc;
}

} // namespace jit
} // namespace torch