#include <torch/csrc/jit/autodiff.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/symbolic_script.h>
#include <torch/csrc/jit/symbolic_variable.h>

#include <c10/util/Exception.h>

using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;
// Wraps a negative dimension index into the valid [0, sizes.size()) range.
void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
  if (dim < 0) {
    dim += sizes.size();
  }
}
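
// Ops whose extra outputs (e.g. the indices returned by kthvalue/topk) never
// receive a meaningful gradient; for these only the gradient of the first
// output is kept when building the backward computation.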
bool needTrimGrad(Node* n) {
  static OperatorSet need_trim_grad_ops = {
      "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
      "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
  };
  if (need_trim_grad_ops.find(n)) {
    return true;
  }
  return false;
}
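
// Returns true if we know how to build a gradient for this node, either via a
// TorchScript gradient definition (symbolic_script) or via one of the
// hand-written symbolic gradients below.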
bool isDifferentiable(Node* n) {
  static OperatorSet differentiable_ops = {
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::t(Tensor self) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
      "aten::type_as(Tensor self, Tensor other) -> Tensor",
      "aten::unsqueeze(Tensor self, int dim) -> Tensor",
      "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
      "aten::mm(Tensor self, Tensor mat2) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Scalar other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::remainder(Tensor self, Scalar other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
      "aten::log_softmax(Tensor self, int dim) -> Tensor",
      "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
      "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
      "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
  };
  if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
    return true;
  if (differentiable_ops.find(n))
    return true;

  if (n->matches(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
    return n->get<bool>(attr::train).value();
  }

  auto schema = n->maybeSchema();
  if (schema && hasGradientInfoForSchema(*schema)) {
    return true;
  }

  if (n->matches(
          "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
    return n->get<std::vector<int64_t>>(attr::size) &&
        n->is_constant(attr::implicit) &&
        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
  }
  if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) {
    return n->get<std::vector<int64_t>>(attr::size) &&
        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
  }
  if (n->matches(
          "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
    // Only supported when no weight tensor is given.
    return n->namedInput(attr::weight)->node()->mustBeNone();
  }
  // A GradOf block is differentiable if every node in its body is.
  if (n->kind() == prim::GradOf) {
    auto body = n->blocks().at(0);
    return std::all_of(
        body->nodes().begin(),
        body->nodes().end(),
        static_cast<bool (*)(Node*)>(isDifferentiable));
  }

  return false;
}
bool isDifferentiable(Graph& g) {
  return std::all_of(
      g.nodes().begin(),
      g.nodes().end(),
      static_cast<bool (*)(Node*)>(isDifferentiable));
}
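
// Try to differentiate the node using a gradient pair defined in TorchScript
// (see symbolic_script.h). If one exists, the forward graph is inlined in
// place of the node (producing the original outputs plus an extra context
// value), the backward graph is inlined with that context and the incoming
// gradients, and the resulting input gradients are returned. Returns nullopt
// when no TorchScript gradient is registered for this schema.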
static c10::optional<std::vector<Value*>> build_script_grad(
    Node* node,
    const ArrayRef<Value*>& grads) {
  auto graph = node->owningGraph();

  auto compiled_graphs = gradientInfoForSchema(node->schema());
  if (!compiled_graphs) {
    return c10::nullopt;
  }

  // Replace the node with the compiled forward graph.
  value_list new_outputs;
  {
    WithInsertPoint guard(node->next());
    auto fw_graph = compiled_graphs->forward;
    new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), true);
    auto outputs = node->outputs();
    AT_ASSERT(new_outputs.size() == outputs.size() + 1);
    for (size_t i = 0; i < outputs.size(); ++i) {
      new_outputs.at(i)->setType(outputs[i]->type());
      outputs[i]->replaceAllUsesWith(new_outputs.at(i));
    }
  }

  // Splice in the compiled backward graph, feeding it the extra forward output
  // followed by the incoming gradients.
  auto bw_graph = compiled_graphs->backward;
  auto grad_vec = grads.vec();
  if (needTrimGrad(node)) {
    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
  }
  auto it = grad_vec.begin();
  grad_vec.insert(it, new_outputs.back());
  ArrayRef<Value*> grad(grad_vec);
  auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, true);
  return grad_inputs;
}
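
// Builds the gradient expressions for a single differentiable node. The
// returned vector holds one Value* per input of the node; inputs that do not
// receive a gradient get a nullptr entry.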
class GradientHelper {
 public:
  GradientHelper(Node* n) : node(n) {}

  std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
    if (!isDifferentiable(node)) {
      throw std::runtime_error(
          std::string("differentiation of ") + node->kind().toDisplayString() +
          " is not supported, or it is missing necessary type information");
    }
    // Prefer a gradient defined in TorchScript (symbolic_script) when one
    // exists for this schema.
    auto script_grads = build_script_grad(node, grad_values);
    if (script_grads)
      return *script_grads;
    // Otherwise fall back to the hand-written symbolic gradients below.
    auto sym_grads = buildSymbolicGradient(fmap<SymbolicVariable>(grad_values));
    return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); });
  }

 private:
  Node* node;
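
  // Wraps a gradient in aten::_grad_sum_to_size so that a gradient coming
  // from a broadcasted result is summed back down to the size of the named
  // input.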
  SymbolicVariable gradSumToSizeOf(SymbolicVariable v, Symbol input_name) {
    Value* size;
    {
      WithInsertPoint insert_guard{node};
      size = SymbolicVariable(node->namedInput(input_name)).size();
    }
    return v.gradSumToSize(size);
  }
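
  // Hand-written symbolic gradients, expressed with SymbolicVariable so the
  // formulas read like tensor math while emitting graph nodes.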
  std::vector<SymbolicVariable> buildSymbolicGradient(
      const std::vector<SymbolicVariable>& grads) {
    static const OperatorSet comparison_ops = {
        "aten::lt(Tensor self, Tensor other) -> Tensor",
        "aten::le(Tensor self, Tensor other) -> Tensor",
        "aten::gt(Tensor self, Tensor other) -> Tensor",
        "aten::ge(Tensor self, Tensor other) -> Tensor",
        "aten::eq(Tensor self, Tensor other) -> Tensor",
        "aten::ne(Tensor self, Tensor other) -> Tensor",
        "aten::lt(Tensor self, Scalar other) -> Tensor",
        "aten::le(Tensor self, Scalar other) -> Tensor",
        "aten::gt(Tensor self, Scalar other) -> Tensor",
        "aten::ge(Tensor self, Scalar other) -> Tensor",
        "aten::eq(Tensor self, Scalar other) -> Tensor",
        "aten::ne(Tensor self, Scalar other) -> Tensor",
    };
    auto inputs = fmap<SymbolicVariable>(node->inputs());
    auto outputs = fmap<SymbolicVariable>(node->outputs());
    if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
      return {gradSumToSizeOf(grads.at(0), attr::self),
              gradSumToSizeOf(
                  grads.at(0) * node->namedInput(attr::alpha), attr::other),
              nullptr};

    } else if (
        node->matches(
            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
      return {grads.at(0), nullptr, nullptr};

    } else if (node->kind() == prim::AutogradAdd) {
      // NB: AutogradAdd does not broadcast, so both summands receive the
      // incoming gradient unchanged.
      return {grads.at(0), grads.at(0)};
    } else if (
        node->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
      return {gradSumToSizeOf(grads.at(0), attr::self),
              gradSumToSizeOf(
                  -grads.at(0) * node->namedInput(attr::alpha), attr::other),
              nullptr};

    } else if (
        node->matches(
            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
      return {grads.at(0), nullptr, nullptr};

    } else if (node->matches(
                   "aten::mul(Tensor self, Tensor other) -> Tensor")) {
      return {gradSumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
              gradSumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};

    } else if (node->matches(
                   "aten::mul(Tensor self, Scalar other) -> Tensor")) {
      return {grads.at(0) * inputs.at(1), nullptr};
    } else if (node->matches(
                   "aten::div(Tensor self, Tensor other) -> Tensor")) {
      return {gradSumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
              gradSumToSizeOf(
                  -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)),
                  attr::other)};

    } else if (node->matches(
                   "aten::div(Tensor self, Scalar other) -> Tensor")) {
      return {grads.at(0) / inputs.at(1), nullptr};

    } else if (node->matches(
                   "aten::max(Tensor self, Tensor other) -> Tensor")) {
      return {
          gradSumToSizeOf(
              grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
              attr::self),
          gradSumToSizeOf(
              grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)),
              attr::other)};

    } else if (node->matches(
                   "aten::min(Tensor self, Tensor other) -> Tensor")) {
      return {
          gradSumToSizeOf(
              grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
              attr::self),
          gradSumToSizeOf(
              grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)),
              attr::other)};
    } else if (
        node->matches(
            "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
      return {nullptr,
              gradSumToSizeOf(
                  grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
              gradSumToSizeOf(
                  grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)),
                  attr::other)};

    } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
      return {(1 - outputs.at(0)) * outputs.at(0) * grads.at(0)};

    } else if (node->matches("aten::tanh(Tensor self) -> Tensor")) {
      return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};

    } else if (node->matches("aten::relu(Tensor self) -> Tensor")) {
      return {grads.at(0) *
              (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
    } else if (
        node->matches(
            "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
      // Handle the cases where min and/or max is None.
      Value* min = inputs.at(1);
      bool min_must_be_none = min->mustBeNone();
      Value* max = inputs.at(2);
      bool max_must_be_none = max->mustBeNone();
      if (!min_must_be_none && !max_must_be_none) {
        return {grads.at(0) *
                    (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) *
                    (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
                nullptr,
                nullptr};
      } else if (max_must_be_none) {
        return {grads.at(0) *
                    (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))),
                nullptr,
                nullptr};
      } else if (min_must_be_none) {
        return {grads.at(0) *
                    (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
                nullptr,
                nullptr};
      } else {
        return {grads.at(0), nullptr, nullptr};
      }

    } else if (
        node->matches(
            "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
      auto threshold = node->get<at::Scalar>(attr::threshold).value();
      return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)),
              nullptr,
              nullptr};
    } else if (node->matches("aten::erf(Tensor self) -> Tensor")) {
      return {grads.at(0) * 1.12837916709551 *
              (-inputs.at(0) * inputs.at(0)).exp()};

    } else if (node->matches("aten::erfc(Tensor self) -> Tensor")) {
      return {-grads.at(0) * 1.12837916709551 *
              (-inputs.at(0) * inputs.at(0)).exp()};

    } else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
      return {grads.at(0) * (outputs.at(0))};

    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
      return {grads.at(0).t()};

    } else if (node->matches("aten::neg(Tensor self) -> Tensor")) {
      return {-grads.at(0)};

    } else if (node->matches("aten::abs(Tensor self) -> Tensor")) {
      return {grads.at(0) * inputs.at(0).sign()};

    } else if (node->matches("aten::acos(Tensor self) -> Tensor")) {
      return {grads.at(0) *
              -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};

    } else if (node->matches("aten::asin(Tensor self) -> Tensor")) {
      return {grads.at(0) *
              (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};

    } else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
      return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
    } else if (
        node->matches(
            "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
      Value* self_size;
      {
        WithInsertPoint insert_guard{node};
        self_size = inputs.at(0).size();
      }
      return {grads.at(0).expand(self_size), nullptr};
    } else if (node->matches("aten::ceil(Tensor self) -> Tensor")) {
      return {SymbolicVariable::zeros_like(grads.at(0))};

    } else if (node->matches("aten::cos(Tensor self) -> Tensor")) {
      return {grads.at(0) * -inputs.at(0).sin()};

    } else if (node->matches("aten::cosh(Tensor self) -> Tensor")) {
      return {grads.at(0) * inputs.at(0).sinh()};

    } else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
      return {grads.at(0) * outputs.at(0)};

    } else if (node->matches("aten::expm1(Tensor self) -> Tensor")) {
      return {grads.at(0) * (outputs.at(0) + at::Scalar(1))};

    } else if (node->matches("aten::floor(Tensor self) -> Tensor")) {
      return {SymbolicVariable::zeros_like(grads.at(0))};

    } else if (node->matches(
                   "aten::fmod(Tensor self, Scalar other) -> Tensor")) {
      return {grads.at(0), nullptr};

    } else if (node->matches("aten::frac(Tensor self) -> Tensor")) {
      return {grads.at(0)};

    } else if (node->matches("aten::log(Tensor self) -> Tensor")) {
      return {grads.at(0) / inputs.at(0)};

    } else if (node->matches("aten::log10(Tensor self) -> Tensor")) {
      return {grads.at(0) / (inputs.at(0) * 2.3025850929940456)};

    } else if (node->matches("aten::log1p(Tensor self) -> Tensor")) {
      return {grads.at(0) / (inputs.at(0) + at::Scalar(1))};

    } else if (node->matches("aten::log2(Tensor self) -> Tensor")) {
      return {grads.at(0) / (inputs.at(0) * 0.6931471805599453)};

    } else if (node->matches("aten::reciprocal(Tensor self) -> Tensor")) {
      return {-grads.at(0) * outputs.at(0) * outputs.at(0)};

    } else if (node->matches(
                   "aten::remainder(Tensor self, Scalar other) -> Tensor")) {
      return {grads.at(0), nullptr};

    } else if (node->matches("aten::round(Tensor self) -> Tensor")) {
      return {SymbolicVariable::zeros_like(grads.at(0))};

    } else if (node->matches("aten::rsqrt(Tensor self) -> Tensor")) {
      return {grads.at(0) * outputs.at(0).pow(3.) * -0.5};

    } else if (node->matches("aten::sin(Tensor self) -> Tensor")) {
      return {grads.at(0) * inputs.at(0).cos()};

    } else if (node->matches("aten::sinh(Tensor self) -> Tensor")) {
      return {grads.at(0) * inputs.at(0).cosh()};

    } else if (node->matches("aten::tan(Tensor self) -> Tensor")) {
      return {grads.at(0) * (1. + outputs.at(0) * outputs.at(0))};

    } else if (node->matches("aten::trunc(Tensor self) -> Tensor")) {
      return {SymbolicVariable::zeros_like(grads.at(0))};
    } else if (node->kind() == prim::ConstantChunk) {
      return {SymbolicVariable::cat(grads, node->i(attr::dim))};

    } else if (
        node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
        node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
      // Reshape the incoming gradient back to the statically known sizes of
      // self.
      auto sizes = node->namedInput(attr::self)
                       ->type()
                       ->expect<CompleteTensorType>()
                       ->sizes();
      return {grads.at(0).reshape(sizes), nullptr};
    } else if (node->matches(
                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      return {grads.at(0).type_as(inputs.at(0)), nullptr};

    } else if (node->matches("aten::rand_like(Tensor self) -> Tensor")) {
      return {nullptr};

    } else if (node->matches(
                   "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
      return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
    } else if (
        node->matches(
            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
      return {gradSumToSizeOf(
                  grads.at(0) * node->namedInput(attr::beta), attr::self),
              grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
              inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
              nullptr,
              nullptr};

    } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return {grads.at(0).mm(inputs.at(1).t()),
              inputs.at(0).t().mm(grads.at(0))};
    } else if (
        node->matches(
            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
      const auto& input_sizes = inputs.at(0).sizes();
      if (input_sizes.size() == 0)
        return {grads.at(0).sum(), nullptr, nullptr};
      auto grad_sizes = node->get<std::vector<int64_t>>(attr::size).value();
      auto grad = grads.at(0);
      while (grad_sizes.size() > input_sizes.size()) {
        grad = grad.sum(0, false);
        grad_sizes.erase(grad_sizes.begin());
      }
      for (size_t i = 0; i < input_sizes.size(); ++i) {
        if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
          grad = grad.sum(i, true);
        }
      }
      return {grad, nullptr, nullptr};
    } else if (node->matches("aten::squeeze(Tensor self) -> Tensor")) {
      const auto& sizes = inputs.at(0).sizes();
      std::vector<size_t> squeezed_dims;
      for (size_t i = 0; i < sizes.size(); ++i) {
        if (sizes[i] != 1)
          continue;
        squeezed_dims.push_back(i);
      }
      SymbolicVariable returned_grad = grads.at(0);
      for (const auto& dim : squeezed_dims) {
        returned_grad = returned_grad.unsqueeze(dim);
      }
      return {returned_grad};
    } else if (node->matches(
                   "aten::squeeze(Tensor self, int dim) -> Tensor",
                   /*const_inputs=*/attr::dim)) {
      int64_t dim = *node->get<int64_t>(attr::dim);
      const auto& sizes = inputs.at(0).sizes();
      wrapDim(dim, sizes);
      if (sizes.size() == 0) {
        return {grads.at(0), nullptr};
      }
      return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
              nullptr};
    } else if (comparison_ops.find(node)) {
      return {nullptr, nullptr};
    } else if (
        node->matches(
            "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
      AT_ASSERT(grads.size() == 1);
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::avg_pool2d_backward,
          {grads.at(0).value(),
           node->namedInput(attr::self),
           node->namedInput(attr::kernel_size),
           node->namedInput(attr::stride),
           node->namedInput(attr::padding),
           node->namedInput(attr::ceil_mode),
           node->namedInput(attr::count_include_pad)});
      return {backward_value->node()->output(0),
              nullptr,
              nullptr,
              nullptr,
              nullptr,
              nullptr};
    } else if (
        node->matches(
            "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
      AT_ASSERT(grads.size() == 2);
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::max_pool2d_with_indices_backward,
          {grads.at(0).value(),
           node->namedInput(attr::self),
           node->namedInput(attr::kernel_size),
           node->namedInput(attr::stride),
           node->namedInput(attr::padding),
           node->namedInput(attr::dilation),
           node->namedInput(attr::ceil_mode),
           outputs.at(1).value()});
      return {backward_value->node()->output(0),
              nullptr,
              nullptr,
              nullptr,
              nullptr,
              nullptr};
    } else if (
        node->matches(
            "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::thnn_conv2d_backward,
          {grads.at(0).value(),
           inputs.at(0).value(),
           inputs.at(1).value(),
           node->namedInput(attr::kernel_size),
           node->namedInput(attr::stride),
           node->namedInput(attr::padding),
           outputs.at(1).value(),
           outputs.at(2).value(),
           graph->insertConstant(std::vector<bool>{true, true, true})});
      // graph->insert returns a tuple for multi-output ops, so unpack it
      // before wiring up the individual gradients.
      Node* tuple_unpack_node =
          graph->insertNode(graph->createTupleUnpack(backward_value));
      auto tuple_outputs = tuple_unpack_node->outputs();
      AT_ASSERT(tuple_outputs.size() == size_t(3));
      return {tuple_outputs[0],
              tuple_outputs[1],
              nullptr,
              tuple_outputs[2],
              nullptr,
              nullptr};
    } else if (
        node->matches(
            "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::native_batch_norm_backward,
          {grads.at(0).value(),
           inputs.at(0).value(),
           inputs.at(1).value(),
           inputs.at(3).value(),
           inputs.at(4).value(),
           outputs.at(1).value(),
           outputs.at(2).value(),
           inputs.at(5).value(),
           inputs.at(7).value(),
           graph->insertConstant(std::vector<bool>{true, true, true})});
      // graph->insert returns a tuple for multi-output ops, so unpack it
      // before wiring up the individual gradients.
      Node* tuple_unpack_node =
          graph->insertNode(graph->createTupleUnpack(backward_value));
      auto tuple_outputs = tuple_unpack_node->outputs();
      AT_ASSERT(tuple_outputs.size() == size_t(3));
      return {tuple_outputs[0],
              tuple_outputs[1],
              tuple_outputs[2],
              nullptr,
              nullptr,
              nullptr,
              nullptr,
              nullptr};
    } else if (
        node->matches(
            "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
      auto graph = node->owningGraph();
      auto total_weight = graph->insertNode(graph->createAutogradZero());
      auto weight = graph->insertNode(graph->createNone(TensorType::get()));
      auto backward_value = graph->insert(
          aten::nll_loss_backward,
          {grads.at(0).value(),
           inputs.at(0).value(),
           inputs.at(1).value(),
           weight->output(),
           inputs.at(3).value(),
           inputs.at(4).value(),
           total_weight->output()});
      return {backward_value->node()->output(0),
              nullptr,
              nullptr,
              nullptr,
              nullptr};
    } else if (node->matches(
                   "aten::log_softmax(Tensor self, int dim) -> Tensor")) {
      AT_ASSERT(grads.size() == 1);
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::_log_softmax_backward_data,
          {grads.at(0).value(),
           outputs.at(0).value(),
           node->namedInput(attr::dim),
           node->namedInput(attr::self)});
      return {backward_value->node()->output(0), nullptr};
    } else if (
        node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
      return {};
    }

    throw std::runtime_error(
        std::string("failed to differentiate `") +
        node->kind().toDisplayString() + "`");
  }
};
// Emits the gradient of `node` as a prim::GradOf node whose block maps the
// output gradients to input gradients. Inputs with no gradient get a nullptr
// entry in the returned list rather than a block output.
static std::vector<Value*> linearGradientForNode(
    Node* node,
    ArrayRef<Value*> grad_values) {
  auto& graph = *node->owningGraph();

  // Ops in need_trim_grad_ops only use the gradient of their first output.
  if (needTrimGrad(node)) {
    grad_values = grad_values.at(0);
  }
  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
  // To make gradient graphs easier to read, remember the name of the forward op.
  linear->s_(attr::name, node->kind().toDisplayString());
  auto block = linear->addBlock();
  WithInsertPoint guard(block);
  auto results = GradientHelper(node).gradient(grad_values);
  return fmap(results, [block, linear](Value* grad) -> Value* {
    if (!grad)
      return nullptr;
    block->registerOutput(grad);
    return linear->addOutput()->copyMetadata(grad);
  });
}
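
// Bookkeeping produced by addReverseInline: the mapping from primal values to
// their gradients, and the block holding the reverse computation.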
struct ReverseDetails {
  ReverseDetails(value_map&& grad_map, Block* reverse_block)
      : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}

  value_map grad_map;
  Block* reverse_block;
};
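
// prim::AutogradAdd adds two gradients while treating undefined
// (AutogradZero) operands as zero, so partial sums of vjps never need a
// materialized zero tensor.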
static Value* createAutogradAdd(Value* a, Value* b) {
  auto graph = a->owningGraph();
  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
}
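
// Builds the reverse (vjp) computation for grad_desc.f inline, inside a
// detached prim::Reverse block: adds a vjp input for every output that
// requires grad, walks the primal nodes in reverse order emitting GradOf
// blocks, and registers a vjp output for every primal input that requires
// grad. Also fills in grad_desc.df_input_vjps and df_output_vjps.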
static ReverseDetails addReverseInline(Gradient& grad_desc) {
  auto& graph = *grad_desc.f;
  // Note: reverse_node is intentionally not inserted into the graph, so that
  // passes running over the primal graph leave it alone; its block is split
  // off into df later by lambdaLiftReverse.
  auto reverse_node = graph.create(prim::Reverse, 0);
  auto reverse_block = reverse_node->addBlock();
  WithInsertPoint insert_point_guard{reverse_block};

  value_map grad_map; // x -> dx mapping
  const auto get_grad = [&](Value* v) -> Value* {
    auto it = grad_map.find(v);
    if (it == grad_map.end()) {
      auto undef = graph.insertNode(graph.createAutogradZero());
      std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
    }
    return it->second;
  };
  const auto set_grad = [&](Value* x, Value* dx) {
    if (Value* prev_grad = grad_map[x]) {
      grad_map[x] = createAutogradAdd(prev_grad, dx);
    } else {
      grad_map[x] = dx;
    }
  };
  auto outputs = graph.outputs();
  for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
    Value* output = outputs[i];
    if (!output->requires_grad())
      continue;
    Value* output_grad = reverse_block->addInput()->setType(output->type());
    set_grad(output, output_grad);
    grad_desc.df_input_vjps.push_back(i);
  }
  for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
       ++it) {
    Node* node = *it;
    auto inputs = node->inputs();
    auto outputs = node->outputs();
    if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) {
          return !v->requires_grad();
        })) {
      continue;
    }

    value_list grad_inputs =
        linearGradientForNode(node, fmap(node->outputs(), get_grad));
    LowerSimpleTuples(reverse_block);

    AT_ASSERT(grad_inputs.size() == node->inputs().size());
    for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
      if (!inputs[i]->requires_grad())
        continue;
      // Not producing a gradient for an input that requires grad is normal if
      // the input is only used in a non-differentiable position (e.g. as the
      // second argument of aten::type_as).
      if (!grad_inputs[i])
        continue;
      set_grad(inputs[i], grad_inputs[i]);
    }
  }
  auto inputs = graph.inputs();
  for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
    Value* input = inputs[i];
    if (!input->requires_grad())
      continue;
    // An input that requires grad may still have no gradient if it was only
    // used in non-differentiable contexts; in that case it is simply skipped.
    if (grad_map.count(input) == 0)
      continue;
    reverse_block->registerOutput(get_grad(input));
    grad_desc.df_output_vjps.push_back(i);
  }

  return ReverseDetails(std::move(grad_map), reverse_block);
}
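
// Collects, in topological order, every primal value (graph input or node
// output) that is also used outside the primal block, i.e. inside the reverse
// block; these are the values that have to be captured for the backward pass.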
static value_list getReverseCaptures(Gradient& grad_desc) {
  auto& graph = *grad_desc.f;
  auto primal_block = graph.block();

  value_set reverse_captures_set;
  value_list reverse_captures; // Invariant: topologically sorted
  auto check_uses = [&](Value* v) {
    for (auto use : v->uses()) {
      if (use.user->owningBlock() == primal_block)
        continue;
      if (reverse_captures_set.emplace(v).second) {
        reverse_captures.push_back(v);
      }
    }
  };
  for (Value* input : graph.inputs()) {
    check_uses(input);
  }
  for (Node* node : graph.nodes()) {
    for (Value* output : node->outputs())
      check_uses(output);
  }
  return reverse_captures;
}
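
// Constants used inside GradOf blocks are cloned into the reverse block
// itself, so they never have to be captured from the primal graph.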
static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
  static const auto err = [](Value*) -> Value* {
    throw std::runtime_error("unexpected input");
  };
  auto& graph = *grad_desc.f;
  Block* reverse_block = rev_info.reverse_block;

  for (Node* top_node : reverse_block->nodes()) {
    AT_ASSERT(
        top_node->kind() == prim::GradOf ||
        top_node->kind() == prim::AutogradAdd ||
        top_node->kind() == prim::AutogradZero);
    if (top_node->kind() != prim::GradOf)
      continue;
    Block* grad_body = top_node->blocks().at(0);
    for (Node* node : grad_body->nodes()) {
      for (Value* input : node->inputs()) {
        if (input->node()->kind() != prim::Constant)
          continue;
        if (input->node()->owningBlock() == grad_body)
          continue;
        Node* lifted_constant = graph.createClone(input->node(), err);
        reverse_block->prependNode(lifted_constant);
        node->replaceInputWith(input, lifted_constant->output());
      }
    }
  }
}
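
// If both a tensor and its aten::size are captured, and the size is only used
// in the reverse graph, recompute the size there instead of capturing it too.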
static void deduplicateSizeCaptures(
    Gradient& grad_desc,
    ReverseDetails& rev_info) {
  Block* primal_block = grad_desc.f->block();
  const auto usedOnlyInReverse = [primal_block](Value* v) {
    const auto& uses = v->uses();
    return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
      return u.user->owningBlock() != primal_block;
    });
  };
  auto captures = getReverseCaptures(grad_desc);
  value_set capture_set(captures.begin(), captures.end());
  for (Value* capture : captures) {
    Node* node = capture->node();
    if (!node->matches("aten::size(Tensor self) -> int[]")) {
      continue;
    }
    if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
      WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
      // Re-derive the size from the (already captured) tensor inside the
      // reverse block and drop the redundant size capture.
      capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
      node->destroy();
    }
  }
}
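
// Runs dead code elimination over the reverse block, dropping grad_map
// entries whose gradient values were deleted so later passes never touch
// freed values.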
static void eliminateDeadCode(ReverseDetails& rev_info) {
  std::function<void(const std::unordered_set<const Value*>&)> cb =
      [&](const std::unordered_set<const Value*>& live_values) {
        std::vector<Value*> to_erase;
        for (auto& entry : rev_info.grad_map) {
          if (!live_values.count(entry.second)) {
            to_erase.push_back(entry.first);
          }
        }
        for (Value* v : to_erase) {
          rev_info.grad_map.erase(v);
        }
      };
  EliminateDeadCode(rev_info.reverse_block, std::move(cb));
}
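
// Post-processing of the freshly built reverse block: lift constants, run CSE
// over the primal graph, deduplicate size captures, and clean up dead
// gradients.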
static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
  liftConstants(grad_desc, rev_info);
  EliminateCommonSubexpression(grad_desc.f);
  deduplicateSizeCaptures(grad_desc, rev_info);
  eliminateDeadCode(rev_info);
}
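
// Splits the reverse block off into its own graph grad_desc.df. Intermediates
// that the backward pass needs are appended as extra outputs of f and become
// captured inputs of df; df_input_vjps is extended with vjps for those
// temporaries. See the Gradient struct in autodiff.h for the field meanings.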
static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
  auto& graph = *grad_desc.f;
  auto reverse_block = rev_info.reverse_block;

  // Step 1: find all values produced in f and used in the reverse block; they
  // have to be captured. Invariant: topologically sorted.
  value_list reverse_captures = getReverseCaptures(grad_desc);

  // Step 2: prepare the input/output lists of f and df. Primal outputs become
  // [original outputs] + [captured temporaries]; reverse inputs become
  // [output vjps] + [temporary vjps] + [captures].
  grad_desc.f_real_outputs = graph.outputs().size();

  std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
  std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
  // NOTE: emplace is used so that a repeated output keeps its first index.
  for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
    orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
  for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
    orig_primal_inputs_idx[graph.inputs()[i]] = i;
  // reverse_captures is deduplicated and in topological order.
  for (Value* capture_val : reverse_captures) {
    // If it's already an output, nothing has to be added to f; just record
    // that this output needs to be captured.
    if (orig_primal_outputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captured_outputs.push_back(
          orig_primal_outputs_idx[capture_val]);
      // If it's an input, capture it directly instead of adding a new output.
    } else if (orig_primal_inputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captured_inputs.push_back(
          orig_primal_inputs_idx.at(capture_val));
      // Otherwise it's an intermediate value: append it as a new output of f.
    } else {
      graph.registerOutput(capture_val);
      grad_desc.df_input_captured_outputs.emplace_back(
          graph.outputs().size() - 1);
    }
  }
  // Add vjp inputs for the temporaries appended above and extend
  // df_input_vjps accordingly.
  for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
    Value* tmp = graph.outputs().at(i);
    // Only add a vjp input for temporaries that actually have a gradient;
    // grad_map is a more faithful source than tmp->requires_grad(), which is
    // an overapproximation.
    if (rev_info.grad_map.count(tmp) == 0)
      continue;
    Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
    Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
    // We can't sum first and then replace the uses of tmp_vjp_prev (that would
    // also replace its use inside the sum), so build a placeholder sum that
    // uses the new vjp twice, replace the uses, and then fix up the sum.
    Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
    new_vjp->node()->moveAfter(tmp_vjp_prev->node());
    tmp_vjp_prev->replaceAllUsesWith(new_vjp);
    new_vjp->node()->replaceInput(1, tmp_vjp_prev);
    grad_desc.df_input_vjps.emplace_back(i);
  }
  // Add the captures as formal inputs of the reverse block, so that grad_map
  // expressions refer to formal captures once the block is cloned into df.
  std::unordered_map<Value*, size_t> capture_to_formal_index;
  const auto& add_capture = [&](Value* captured) {
    capture_to_formal_index[captured] = reverse_block->inputs().size();
    reverse_block->addInput()->copyMetadata(captured);
  };
  for (auto& offset : grad_desc.df_input_captured_inputs)
    add_capture(graph.inputs()[offset]);
  for (auto& offset : grad_desc.df_input_captured_outputs)
    add_capture(graph.outputs()[offset]);

  grad_desc.df = std::make_shared<Graph>();
  grad_desc.df->block()->cloneFrom(reverse_block, [&](Value* v) {
    return grad_desc.df->inputs()[capture_to_formal_index.at(v)];
  });
  // reverse_block was only held by the dummy Reverse node; it can go now.
  reverse_block->owningNode()->destroy();
}
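
// Entry point: splits `graph` into a forward graph f and a backward graph df
// (see the Gradient struct in autodiff.h). A rough usage sketch follows; the
// real call sites live in the graph executor, and the names below are
// illustrative only:
//
//   std::shared_ptr<Graph> copy = original->copy();   // sole owner required
//   Gradient grad_spec = differentiate(copy);
//   // Run grad_spec.f, keep its extra (captured) outputs, then feed the
//   // output vjps plus the captures to grad_spec.df to get the input vjps.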
Gradient differentiate(std::shared_ptr<Graph>& graph) {
  Gradient grad_desc;
  // Take ownership of the graph.
  AT_CHECK(
      graph.use_count() == 1,
      "differentiate will mutate and destroy the graph, so it requires "
      "graph.use_count() == 1, but found %d",
      graph.use_count());
  std::swap(graph, grad_desc.f);

  WithInsertPoint guard(grad_desc.f->block());
  // Fills in df_input_vjps and df_output_vjps.
  auto rev_info = addReverseInline(grad_desc);
  Optimize(grad_desc, rev_info);
  // Clean up nodes that were replaced by inlined forward graphs from
  // TorchScript gradient definitions.
  EliminateDeadCode(grad_desc.f->block());

  // Fills in f, df, f_real_outputs and df_input_captures; also extends
  // df_input_vjps with vjps for captured temporaries.
  lambdaLiftReverse(grad_desc, rev_info);
  // The same constants may have been cloned multiple times; deduplicate them.
  ConstantPooling(grad_desc.df);
  return grad_desc;
}