Caffe2 - C++ API
A deep learning, cross-platform ML framework
autodiff.cpp
1 #include <torch/csrc/jit/autodiff.h>
2 
3 #include <ATen/core/functional.h>
4 #include <torch/csrc/jit/operator.h>
5 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
6 #include <torch/csrc/jit/passes/constant_pooling.h>
7 #include <torch/csrc/jit/passes/dead_code_elimination.h>
8 #include <torch/csrc/jit/passes/lower_tuples.h>
9 #include <torch/csrc/jit/script/compiler.h>
10 #include <torch/csrc/jit/symbolic_script.h>
11 #include <torch/csrc/jit/symbolic_variable.h>
12 
13 #include <c10/util/Exception.h>
14 
15 #include <algorithm>
16 #include <memory>
17 
18 namespace torch {
19 namespace jit {
20 
21 using value_map = std::unordered_map<Value*, Value*>;
22 using value_set = std::unordered_set<Value*>;
23 
24 void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
25  if (dim < 0) {
26  dim += sizes.size();
27  }
28 }
29 
30 // need_trim_grad_ops contains functions that return multiple outputs in
31 // forward, but only the first one requires grad.
32 // Example:
33 // kthvalue returns (kthvalue, index of kthvalue), currently autodiff only
34 // supports at most one output that requires grad. Thus we need to remove
35 // the grad for the index output, which doesn't require grad.
36 bool needTrimGrad(Node* n) {
37  static OperatorSet need_trim_grad_ops = {
38  "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
39  "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
40  };
41  if (need_trim_grad_ops.find(n)) {
42  return true;
43  }
44  return false;
45 }
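// For example (illustrative IR; value names are invented):
//   %values, %indices = aten::topk(%x, %k, %dim, %largest, %sorted)
// Only d(%values) is propagated: build_script_grad and linearGradientForNode
// below consult needTrimGrad and keep just the first incoming grad, dropping
// the one for %indices.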
46 
47 bool isDifferentiable(Node* n) {
48  // TODO: scalar-tensor ops should be canonicalized
49  static OperatorSet differentiable_ops = {
50  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
51  "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
52  "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
53  "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
54  "aten::mul(Tensor self, Tensor other) -> Tensor",
55  "aten::mul(Tensor self, Scalar other) -> Tensor",
56  "aten::div(Tensor self, Tensor other) -> Tensor",
57  "aten::div(Tensor self, Scalar other) -> Tensor",
58  "aten::max(Tensor self, Tensor other) -> Tensor",
59  "aten::min(Tensor self, Tensor other) -> Tensor",
60  "aten::sigmoid(Tensor self) -> Tensor",
61  "aten::tanh(Tensor self) -> Tensor",
62  "aten::relu(Tensor self) -> Tensor",
63  "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
64  "aten::erf(Tensor self) -> Tensor",
65  "aten::erfc(Tensor self) -> Tensor",
66  "aten::exp(Tensor self) -> Tensor",
67  "aten::t(Tensor self) -> Tensor",
68  "aten::neg(Tensor self) -> Tensor",
69  "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
70  "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
71  "aten::type_as(Tensor self, Tensor other) -> Tensor",
72  "aten::unsqueeze(Tensor self, int dim) -> Tensor",
73  "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
74  "aten::mm(Tensor self, Tensor mat2) -> Tensor",
75  "aten::lt(Tensor self, Tensor other) -> Tensor",
76  "aten::le(Tensor self, Tensor other) -> Tensor",
77  "aten::gt(Tensor self, Tensor other) -> Tensor",
78  "aten::ge(Tensor self, Tensor other) -> Tensor",
79  "aten::eq(Tensor self, Tensor other) -> Tensor",
80  "aten::ne(Tensor self, Tensor other) -> Tensor",
81  "aten::lt(Tensor self, Scalar other) -> Tensor",
82  "aten::le(Tensor self, Scalar other) -> Tensor",
83  "aten::gt(Tensor self, Scalar other) -> Tensor",
84  "aten::ge(Tensor self, Scalar other) -> Tensor",
85  "aten::eq(Tensor self, Scalar other) -> Tensor",
86  "aten::ne(Tensor self, Scalar other) -> Tensor",
87  "aten::abs(Tensor self) -> Tensor",
88  "aten::acos(Tensor self) -> Tensor",
89  "aten::asin(Tensor self) -> Tensor",
90  "aten::atan(Tensor self) -> Tensor",
91  "aten::ceil(Tensor self) -> Tensor",
92  "aten::cos(Tensor self) -> Tensor",
93  "aten::cosh(Tensor self) -> Tensor",
94  "aten::exp(Tensor self) -> Tensor",
95  "aten::expm1(Tensor self) -> Tensor",
96  "aten::floor(Tensor self) -> Tensor",
97  "aten::fmod(Tensor self, Scalar other) -> Tensor",
98  "aten::frac(Tensor self) -> Tensor",
99  "aten::log(Tensor self) -> Tensor",
100  "aten::log10(Tensor self) -> Tensor",
101  "aten::log1p(Tensor self) -> Tensor",
102  "aten::log2(Tensor self) -> Tensor",
103  "aten::rand_like(Tensor self) -> Tensor",
104  "aten::reciprocal(Tensor self) -> Tensor",
105  "aten::remainder(Tensor self, Scalar other) -> Tensor",
106  "aten::round(Tensor self) -> Tensor",
107  "aten::rsqrt(Tensor self) -> Tensor",
108  "aten::sin(Tensor self) -> Tensor",
109  "aten::sinh(Tensor self) -> Tensor",
110  "aten::tan(Tensor self) -> Tensor",
111  "aten::trunc(Tensor self) -> Tensor",
112  "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
113  "aten::log_softmax(Tensor self, int dim) -> Tensor",
114  "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
115  "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
116  "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
117  "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
118  };
119 
120  // TODO: add support for the following fusible operators.
121  // They're a little tricky to implement; max/min require mutability for best
 122  // perf: "aten::atan2(Tensor self) -> Tensor", "aten::max(Tensor self) -> Tensor",
 123  // "aten::min(Tensor self) -> Tensor"
124 
125  if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
126  n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
127  return true;
128  if (differentiable_ops.find(n))
129  return true;
130 
131  if (n->matches(
132  "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
133  return n->get<bool>(attr::train).value();
134  }
135 
136  auto schema = n->maybeSchema();
137  if (schema && hasGradientInfoForSchema(*schema)) {
138  return true;
139  }
140 
141  if (n->matches(
142  "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
143  return n->get<std::vector<int64_t>>(attr::size) &&
144  n->is_constant(attr::implicit) &&
145  n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
146  }
147  if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) {
148  return n->get<std::vector<int64_t>>(attr::size) &&
149  n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
150  }
151  if (n->matches(
152  "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
153  // TODO(asuhan): support weight
154  return n->namedInput(attr::weight)->node()->mustBeNone();
155  }
156 
157  // linear blocks may appear as inputs to graph executors, but they are removed
158  // before differentiation occurs
159  if (n->kind() == prim::GradOf) {
160  auto body = n->blocks().at(0);
161  return std::all_of(
162  body->nodes().begin(),
163  body->nodes().end(),
164  static_cast<bool (*)(Node*)>(isDifferentiable));
165  }
166 
167  return false;
168 }
169 
170 bool isDifferentiable(Graph& g) {
171  return std::all_of(
172  g.nodes().begin(),
173  g.nodes().end(),
174  static_cast<bool (*)(Node*)>(isDifferentiable));
175 }
176 
177 // NB: Write gradient using torchscript
178 // For example, node aten::mul() should be defined as follows
179 // def forward(x, y):
180 // return x*y, (x, y)
181 // def backward(ctx, grad_output):
182 // x, y = ctx
183 // return (y * grad_output).sum_to_size(x), (x * grad_output).sum_to_size(y)
184 //
185 // Here ctx is a tuple that carries all input/intermediate results needed in
186 // backward from forward pass.
187 //
188 // This python code is compiled into a GradientPair which includes a forward
189 // graph and a backward graph. Forward graph will be used to replace the node in
190 // grad_desc.f, and backward graph will be used to construct GradOf(node) in
 191 // reverse_block. Grad_values (a.k.a. gradOutputs) are propagated through
 192 // node->owningGraph() in **reverse** order, so GradientPair.forward should
 193 // be inserted **after** the node being replaced, so that we don't traverse
 194 // the graph an infinite number of times.
195 //
196 // The output of compiled forward graph is [real_outputs, ctx]
197 // The input of compiled backward graph is [ctx, grad_values]
 198 // We run LowerSimpleTuples afterwards to eliminate all tuples generated in
 199 // this process. The original node and the TupleConstruct nodes in the forward
 200 // graph will be cleaned up later using EliminateDeadCode(block). TupleUnpack
 201 // nodes in the backward graph will be removed in eliminateDeadCode(ReverseDetails)
 202 // defined in this file.
203 static c10::optional<std::vector<Value*>> build_script_grad(
204  Node* node,
205  const ArrayRef<Value*>& grads) {
206  auto graph = node->owningGraph();
207 
208  auto compiled_graphs = gradientInfoForSchema(node->schema());
209  if (!compiled_graphs) {
210  return c10::nullopt;
211  }
212  // Use forward graph to replace node in grad_desc.f
213  value_list new_outputs;
214  {
215  WithInsertPoint guard(node->next());
216  auto fw_graph = compiled_graphs->forward;
217  new_outputs = inlineCallTo(
218  *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
219  auto outputs = node->outputs();
220  AT_ASSERT(new_outputs.size() == outputs.size() + 1);
221  for (size_t i = 0; i < outputs.size(); ++i) {
222  new_outputs.at(i)->setType(outputs[i]->type());
223  outputs[i]->replaceAllUsesWith(new_outputs.at(i));
224  }
225  }
226 
227  // Use backward graph to construct reverse_block
228  auto bw_graph = compiled_graphs->backward;
229  auto grad_vec = grads.vec();
230  if (needTrimGrad(node)) {
231  grad_vec.erase(grad_vec.begin()+1, grad_vec.end());
232  }
233  auto it = grad_vec.begin();
234  grad_vec.insert(it, new_outputs.back());
235  ArrayRef<Value*> grad(grad_vec);
236  auto grad_inputs =
237  inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
238  return grad_inputs;
239 };
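// A rough sketch of how the two inlined graphs line up for aten::mul (purely
// illustrative; the real definitions live in the symbolic script registry):
//   forward:  (x, y)          -> (x * y, ctx)     with ctx = (x, y)
//   backward: (ctx, grad_out) -> ((y * grad_out).sum_to_size(x),
//                                 (x * grad_out).sum_to_size(y))
// The extra ctx output of forward is new_outputs.back() above, which is why it
// gets prepended to grad_vec before the backward graph is inlined.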
240 
241 namespace {
242 class GradientHelper {
243  public:
244  GradientHelper(Node* n) : node(n) {}
245 
246  std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
247  if (!isDifferentiable(node)) {
248  throw std::runtime_error(
249  std::string("differentiation of ") + node->kind().toDisplayString() +
250  " is not supported, or it is missing necessary type information");
251  }
252  // If AD is defined using torchscript, use it instead of symbolic
253  auto script_grads = build_script_grad(node, grad_values);
254  if (script_grads)
255  return *script_grads;
 256  // Definition not found in torchscript; fall back to buildSymbolicGradient.
257  // TODO: migrate all to using torchscript
258  auto sym_grads = buildSymbolicGradient(fmap<SymbolicVariable>(grad_values));
259  return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); });
260  }
261 
262  private:
263  Node* node;
264 
265  SymbolicVariable gradSumToSizeOf(SymbolicVariable v, Symbol input_name) {
266  Value* size;
267  {
268  WithInsertPoint insert_guard{node};
269  size = SymbolicVariable(node->namedInput(input_name)).size();
270  }
271  return v.gradSumToSize(size);
272  };
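  // For example (shapes are invented): if `self` has shape [3, 1] and the
  // broadcasted result (and therefore the incoming grad) has shape [3, 4],
  // gradSumToSizeOf sums the grad over the broadcast dimension to recover a
  // [3, 1] gradient matching `self`.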
273 
274  std::vector<SymbolicVariable> buildSymbolicGradient(
275  const std::vector<SymbolicVariable>& grads) {
276  static const OperatorSet comparison_ops = {
277  "aten::lt(Tensor self, Tensor other) -> Tensor",
278  "aten::le(Tensor self, Tensor other) -> Tensor",
279  "aten::gt(Tensor self, Tensor other) -> Tensor",
280  "aten::ge(Tensor self, Tensor other) -> Tensor",
281  "aten::eq(Tensor self, Tensor other) -> Tensor",
282  "aten::ne(Tensor self, Tensor other) -> Tensor",
283  "aten::lt(Tensor self, Scalar other) -> Tensor",
284  "aten::le(Tensor self, Scalar other) -> Tensor",
285  "aten::gt(Tensor self, Scalar other) -> Tensor",
286  "aten::ge(Tensor self, Scalar other) -> Tensor",
287  "aten::eq(Tensor self, Scalar other) -> Tensor",
288  "aten::ne(Tensor self, Scalar other) -> Tensor",
289  };
290  auto inputs = fmap<SymbolicVariable>(node->inputs());
291  auto outputs = fmap<SymbolicVariable>(node->outputs());
292 
293  if (node->matches(
294  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
295  return {gradSumToSizeOf(grads.at(0), attr::self),
296  gradSumToSizeOf(
297  grads.at(0) * node->namedInput(attr::alpha), attr::other),
298  nullptr};
299 
300  } else if (
301  node->matches(
302  "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
303  return {grads.at(0), nullptr, nullptr};
304 
305  } else if (node->kind() == prim::AutogradAdd) {
306  // NB: AutogradAdds don't broadcast
307  return {grads.at(0), grads.at(0)};
308 
309  } else if (
310  node->matches(
311  "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
312  return {gradSumToSizeOf(grads.at(0), attr::self),
313  gradSumToSizeOf(
314  -grads.at(0) * node->namedInput(attr::alpha), attr::other),
315  nullptr};
316 
317  } else if (
318  node->matches(
319  "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
320  return {grads.at(0), nullptr, nullptr};
321 
322  } else if (node->matches(
323  "aten::mul(Tensor self, Tensor other) -> Tensor")) {
324  return {gradSumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
325  gradSumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
326 
327  } else if (node->matches(
328  "aten::mul(Tensor self, Scalar other) -> Tensor")) {
329  return {grads.at(0) * inputs.at(1), nullptr};
330 
331  } else if (node->matches(
332  "aten::div(Tensor self, Tensor other) -> Tensor")) {
333  return {gradSumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
334  gradSumToSizeOf(
335  -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)),
336  attr::other)};
337 
338  } else if (node->matches(
339  "aten::div(Tensor self, Scalar other) -> Tensor")) {
340  return {grads.at(0) / inputs.at(1), nullptr};
341 
342  } else if (node->matches(
343  "aten::max(Tensor self, Tensor other) -> Tensor")) {
344  return {
345  gradSumToSizeOf(
346  grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
347  attr::self),
348  gradSumToSizeOf(
349  grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)),
350  attr::other)};
351 
352  } else if (node->matches(
353  "aten::min(Tensor self, Tensor other) -> Tensor")) {
354  return {
355  gradSumToSizeOf(
356  grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
357  attr::self),
358  gradSumToSizeOf(
359  grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)),
360  attr::other)};
361 
362  } else if (
363  node->matches(
364  "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
365  return {nullptr,
366  gradSumToSizeOf(
367  grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
368  gradSumToSizeOf(
369  grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)),
370  attr::other)};
371 
372  } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
 373  // TODO: The order of operations matters in this case. This
374  // works for ppc64le and x86_64. Need to look at why the
375  // order matters.
376  return {(1 - outputs.at(0)) * outputs.at(0) * grads.at(0)};
377 
378  } else if (node->matches("aten::tanh(Tensor self) -> Tensor")) {
379  return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};
380 
381  } else if (node->matches("aten::relu(Tensor self) -> Tensor")) {
382  return {grads.at(0) *
383  (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
384 
385  } else if (
386  node->matches(
387  "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
388  // handle the case that min/max is None
389  Value* min = inputs.at(1);
390  bool min_must_be_none = min->mustBeNone();
391  Value* max = inputs.at(2);
392  bool max_must_be_none = max->mustBeNone();
 393  // XXX - this formula is wrong when min or max are not strictly a constant
394  // None but may be None dynamically. In this case an internal compiler
395  // error will get thrown when trying to generate expressions involving the
396  // values of min/max
397  if (!min_must_be_none && !max_must_be_none) {
398  return {grads.at(0) *
399  (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) *
400  (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
401  nullptr,
402  nullptr};
403  } else if (max_must_be_none) {
404  return {grads.at(0) *
405  (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))),
406  nullptr,
407  nullptr};
408  } else if (min_must_be_none) {
409  return {grads.at(0) *
410  (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
411  nullptr,
412  nullptr};
413  } else {
414  return {grads.at(0), nullptr, nullptr};
415  }
416  } else if (
417  node->matches(
418  "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
419  auto threshold = node->get<at::Scalar>(attr::threshold).value();
420  return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)),
421  nullptr,
422  nullptr};
423 
424  } else if (node->matches("aten::erf(Tensor self) -> Tensor")) {
425  return {grads.at(0) * 1.12837916709551 *
426  (-inputs.at(0) * inputs.at(0)).exp()};
427 
428  } else if (node->matches("aten::erfc(Tensor self) -> Tensor")) {
429  return {-grads.at(0) * 1.12837916709551 *
430  (-inputs.at(0) * inputs.at(0)).exp()};
431 
432  } else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
433  return {grads.at(0) * (outputs.at(0))};
434 
435  } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
436  return {grads.at(0).t()};
437 
438  } else if (node->matches("aten::neg(Tensor self) -> Tensor")) {
439  return {-grads.at(0)};
440 
441  } else if (node->matches("aten::abs(Tensor self) -> Tensor")) {
442  return {grads.at(0) * inputs.at(0).sign()};
443 
444  } else if (node->matches("aten::acos(Tensor self) -> Tensor")) {
445  return {grads.at(0) *
446  -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};
447 
448  } else if (node->matches("aten::asin(Tensor self) -> Tensor")) {
449  return {grads.at(0) *
450  (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};
451 
452  } else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
453  return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
454 
455  } else if (
456  node->matches(
457  "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
458  Value* self_size;
459  {
460  WithInsertPoint insert_guard{node};
461  self_size = inputs.at(0).size();
462  }
463  return {grads.at(0).expand(self_size), nullptr};
464 
465  } else if (node->matches("aten::ceil(Tensor self) -> Tensor")) {
466  return {SymbolicVariable::zeros_like(grads.at(0))};
467 
468  } else if (node->matches("aten::cos(Tensor self) -> Tensor")) {
469  return {grads.at(0) * -inputs.at(0).sin()};
470 
471  } else if (node->matches("aten::cosh(Tensor self) -> Tensor")) {
472  return {grads.at(0) * inputs.at(0).sinh()};
473 
474  } else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
475  return {grads.at(0) * outputs.at(0)};
476 
477  } else if (node->matches("aten::expm1(Tensor self) -> Tensor")) {
478  return {grads.at(0) * (outputs.at(0) + at::Scalar(1))};
479 
480  } else if (node->matches("aten::floor(Tensor self) -> Tensor")) {
481  return {SymbolicVariable::zeros_like(grads.at(0))};
482 
483  } else if (node->matches(
484  "aten::fmod(Tensor self, Scalar other) -> Tensor")) {
485  return {grads.at(0), nullptr};
486 
487  } else if (node->matches("aten::frac(Tensor self) -> Tensor")) {
488  return {grads.at(0)};
489 
490  } else if (node->matches("aten::log(Tensor self) -> Tensor")) {
491  return {grads.at(0) / inputs.at(0)};
492 
493  } else if (node->matches("aten::log10(Tensor self) -> Tensor")) {
494  return {grads.at(0) / (inputs.at(0) * 2.3025850929940456)};
495 
496  } else if (node->matches("aten::log1p(Tensor self) -> Tensor")) {
497  return {grads.at(0) / (inputs.at(0) + at::Scalar(1))};
498 
499  } else if (node->matches("aten::log2(Tensor self) -> Tensor")) {
500  return {grads.at(0) / (inputs.at(0) * 0.6931471805599453)};
501 
502  } else if (node->matches("aten::reciprocal(Tensor self) -> Tensor")) {
503  return {-grads.at(0) * outputs.at(0) * outputs.at(0)};
504 
505  } else if (node->matches(
506  "aten::remainder(Tensor self, Scalar other) -> Tensor")) {
507  return {grads.at(0), nullptr};
508 
509  } else if (node->matches("aten::round(Tensor self) -> Tensor")) {
510  return {SymbolicVariable::zeros_like(grads.at(0))};
511 
512  } else if (node->matches("aten::rsqrt(Tensor self) -> Tensor")) {
513  return {grads.at(0) * outputs.at(0).pow(3.) * -0.5};
514 
515  } else if (node->matches("aten::sin(Tensor self) -> Tensor")) {
516  return {grads.at(0) * inputs.at(0).cos()};
517 
518  } else if (node->matches("aten::sinh(Tensor self) -> Tensor")) {
519  return {grads.at(0) * inputs.at(0).cosh()};
520 
521  } else if (node->matches("aten::tan(Tensor self) -> Tensor")) {
522  return {grads.at(0) * (1. + outputs.at(0) * outputs.at(0))};
523 
524  } else if (node->matches("aten::trunc(Tensor self) -> Tensor")) {
525  return {SymbolicVariable::zeros_like(grads.at(0))};
526 
527  } else if (node->kind() == prim::ConstantChunk) {
528  return {SymbolicVariable::cat(grads, node->i(attr::dim))};
529 
530  } else if (
531  node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
532  node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
533  // TODO: if sizes are not available statically, add an operator that
 534  // returns them as a tuple
535  auto sizes = node->namedInput(attr::self)
536  ->type()
537  ->expect<CompleteTensorType>()
538  ->sizes();
539  return {grads.at(0).reshape(sizes), nullptr};
540 
541  } else if (node->matches(
542  "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
543  return {grads.at(0).type_as(inputs.at(0)), nullptr};
544 
545  } else if (node->matches("aten::rand_like(Tensor self) -> Tensor")) {
546  return {nullptr};
547 
548  } else if (node->matches(
549  "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
550  return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
551 
552  } else if (
553  node->matches(
554  "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
555  return {gradSumToSizeOf(
556  grads.at(0) * node->namedInput(attr::beta), attr::self),
557  grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
558  inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
559  nullptr,
560  nullptr};
561 
562  } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
563  return {grads.at(0).mm(inputs.at(1).t()),
564  inputs.at(0).t().mm(grads.at(0))};
565 
566  } else if (
567  node->matches(
568  "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
569  const auto& input_sizes = inputs.at(0).sizes();
570  if (input_sizes.size() == 0)
571  return {grads.at(0).sum(), nullptr, nullptr};
572  auto grad_sizes = node->get<std::vector<int64_t>>(attr::size).value();
573  auto grad = grads.at(0);
574  while (grad_sizes.size() > input_sizes.size()) {
575  grad = grad.sum(0, false);
576  grad_sizes.erase(grad_sizes.begin());
577  }
578  for (size_t i = 0; i < input_sizes.size(); ++i) {
579  if (input_sizes[i] == 1 && grad_sizes[i] > 1) {
580  grad = grad.sum(i, true);
581  }
582  }
583  return {grad, nullptr, nullptr};
584 
585  } else if (node->matches("aten::squeeze(Tensor self) -> Tensor")) {
586  const auto& sizes = inputs.at(0).sizes();
587  std::vector<size_t> squeezed_dims;
588  for (size_t i = 0; i < sizes.size(); ++i) {
589  if (sizes[i] != 1)
590  continue;
591  squeezed_dims.push_back(i);
592  }
593  SymbolicVariable returned_grad = grads.at(0);
594  for (const auto& dim : squeezed_dims) {
595  returned_grad = returned_grad.unsqueeze(dim);
596  }
597  return {returned_grad};
598 
599  } else if (node->matches(
600  "aten::squeeze(Tensor self, int dim) -> Tensor",
601  /*const_inputs=*/attr::dim)) {
602  int64_t dim = *node->get<int64_t>(attr::dim);
603  const auto& sizes = inputs.at(0).sizes();
604  wrapDim(dim, sizes);
605  if (sizes.size() == 0) {
606  return {grads.at(0), nullptr};
607  }
608  return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
609  nullptr};
610 
611  } else if (comparison_ops.find(node)) {
612  return {nullptr, nullptr};
613 
614  } else if (
615  node->matches(
616  "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
617  AT_ASSERT(grads.size() == 1);
618  auto graph = node->owningGraph();
619  auto backward_value = graph->insert(
620  aten::avg_pool2d_backward,
621  {grads.at(0).value(),
622  node->namedInput(attr::self),
623  node->namedInput(attr::kernel_size),
624  node->namedInput(attr::stride),
625  node->namedInput(attr::padding),
626  node->namedInput(attr::ceil_mode),
627  node->namedInput(attr::count_include_pad)});
628  return {backward_value->node()->output(0),
629  nullptr,
630  nullptr,
631  nullptr,
632  nullptr,
633  nullptr};
634 
635  } else if (
636  node->matches(
637  "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
638  AT_ASSERT(grads.size() == 2);
639  auto graph = node->owningGraph();
640  auto backward_value = graph->insert(
641  aten::max_pool2d_with_indices_backward,
642  {grads.at(0).value(),
643  node->namedInput(attr::self),
644  node->namedInput(attr::kernel_size),
645  node->namedInput(attr::stride),
646  node->namedInput(attr::padding),
647  node->namedInput(attr::dilation),
648  node->namedInput(attr::ceil_mode),
649  outputs.at(1).value()});
650  return {backward_value->node()->output(0),
651  nullptr,
652  nullptr,
653  nullptr,
654  nullptr,
655  nullptr};
656 
657  } else if (
658  node->matches(
659  "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
660  auto graph = node->owningGraph();
661  auto backward_value = graph->insert(
662  aten::thnn_conv2d_backward,
663  {grads.at(0).value(),
664  inputs.at(0).value(),
665  inputs.at(1).value(),
666  node->namedInput(attr::kernel_size),
667  node->namedInput(attr::stride),
668  node->namedInput(attr::padding),
669  outputs.at(1).value(),
670  outputs.at(2).value(),
671  graph->insertConstant(std::vector<bool>{true, true, true})});
672  // graph->insert returns a tuple automatically if multiple outputs are
673  // returned. So unpack them again.
674  Node* tuple_unpack_node =
675  graph->insertNode(graph->createTupleUnpack(backward_value));
676  auto tuple_outputs = tuple_unpack_node->outputs();
677  AT_ASSERT(tuple_outputs.size() == size_t(3));
678  return {tuple_outputs[0],
679  tuple_outputs[1],
680  nullptr,
681  tuple_outputs[2],
682  nullptr,
683  nullptr};
684 
685  } else if (
686  node->matches(
687  "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
688  auto graph = node->owningGraph();
689  auto backward_value = graph->insert(
690  aten::native_batch_norm_backward,
691  {grads.at(0).value(),
692  inputs.at(0).value(),
693  inputs.at(1).value(),
694  inputs.at(3).value(),
695  inputs.at(4).value(),
696  outputs.at(1).value(),
697  outputs.at(2).value(),
698  inputs.at(5).value(),
699  inputs.at(7).value(),
700  graph->insertConstant(std::vector<bool>{true, true, true})});
701  // graph->insert returns a tuple automatically if multiple outputs are
702  // returned. So unpack them again.
703  Node* tuple_unpack_node =
704  graph->insertNode(graph->createTupleUnpack(backward_value));
705  auto tuple_outputs = tuple_unpack_node->outputs();
706  AT_ASSERT(tuple_outputs.size() == size_t(3));
707  return {tuple_outputs[0],
708  tuple_outputs[1],
709  tuple_outputs[2],
710  nullptr,
711  nullptr,
712  nullptr,
713  nullptr,
714  nullptr};
715 
716  } else if (
717  node->matches(
718  "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
719  auto graph = node->owningGraph();
720  auto total_weight = graph->insertNode(graph->createAutogradZero());
721  auto weight = graph->insertNode(graph->createNone(TensorType::get()));
722  auto backward_value = graph->insert(
723  aten::nll_loss_backward,
724  {grads.at(0).value(),
725  inputs.at(0).value(),
726  inputs.at(1).value(),
727  weight->output(),
728  inputs.at(3).value(),
729  inputs.at(4).value(),
730  total_weight->output()});
731  return {backward_value->node()->output(0),
732  nullptr,
733  nullptr,
734  nullptr,
735  nullptr};
736 
737  } else if (node->matches(
738  "aten::log_softmax(Tensor self, int dim) -> Tensor")) {
739  AT_ASSERT(grads.size() == 1);
740  auto graph = node->owningGraph();
741  auto backward_value = graph->insert(
742  aten::_log_softmax_backward_data,
743  {grads.at(0).value(),
744  outputs.at(0).value(),
745  node->namedInput(attr::dim),
746  node->namedInput(attr::self)});
747  return {backward_value->node()->output(0), nullptr};
748 
749  } else if (
750  node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
751  return {};
752  }
753  throw std::runtime_error(
754  std::string("failed to differentiate `") +
755  node->kind().toDisplayString() + "`");
756  }
757 };
758 } // namespace
759 
760 // If we have a function y = f(x) with jacobian J, the backwards of f is dx =
761 // J^t dy. Note that because the backwards always implements this matrix
762 // multiply, we know that it maps an input vector of zeros to an output vector
 763 // of zeros regardless of what operations it chooses to do inside to actually
764 // implement the matrix multiply (most use some optimized form and never
765 // generate J^t). More generally, we know that all of the backward computations
766 // are linear and can use this property to do more aggressive optimizations
767 // later. It is ok to replace any backward function with known-zero inputs with
768 // something that produces known-zero outputs. This function encloses each
 769 // known-linear backward function in a 'GradOf' sub-block so that we can perform
770 // optimizations using this information. In particular, specializeUndef will
771 // observe if all the inputs to the linear block are Undef, which the autograd
772 // uses to represent zeros, and then propagate the undefs to the outputs of the
773 // block.
774 static std::vector<Value*> linearGradientForNode(
775  Node* node,
776  ArrayRef<Value*> grad_values) {
777  auto& graph = *node->owningGraph();
778 
 779  // FIXME: when forward has multiple outputs, we only support the first one requiring grad
780  if (needTrimGrad(node)) {
781  grad_values = grad_values.at(0);
782  }
783  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
784  // to make reading gradient graphs easier, remember the name of the forward op
785  linear->s_(attr::name, node->kind().toDisplayString());
786  auto block = linear->addBlock();
787  WithInsertPoint guard(block);
788  auto results = GradientHelper(node).gradient(grad_values);
789  return fmap(results, [block, linear](Value* grad) -> Value* {
790  if (!grad)
791  return nullptr;
792  block->registerOutput(grad);
793  return linear->addOutput()->copyMetadata(grad);
794  });
795 }
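// For example, differentiating %z = aten::mul(%x, %y) yields a block shaped
// roughly like this (illustrative IR; value names are invented):
//   %dx, %dy = prim::GradOf[name="aten::mul"](%dz)
//     block0():
//       %dx0 = aten::mul(%dz, %y)
//       %dy0 = aten::mul(%dz, %x)
//       -> (%dx0, %dy0)
// If %dz is known to be undefined (AutogradZero), specializeUndef can
// propagate zeros to %dx and %dy without executing the block.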
 796 
 797 struct ReverseDetails {
 798  ReverseDetails(value_map&& grad_map, Block* reverse_block)
799  : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}
800 
801  value_map grad_map;
802  Block* reverse_block;
803 };
804 
805 // AutogradAdd is a special addition function that handles Undef
806 // AutogradAdd(a, b) == a + b if defined(a) and defined(b)
807 // AutogradAdd(Undef, b) == b
808 // AutogradAdd(a, Undef) == a
809 // AutogradAdd(Undef, Undef) == Undef
810 static Value* createAutogradAdd(Value* a, Value* b) {
811  auto graph = a->owningGraph();
812  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
813 }
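// AutogradAdd is what accumulates gradients when a value is used more than
// once in the forward graph. Sketch (illustrative): if %x feeds two ops whose
// GradOf blocks produce %dx1 and %dx2, set_grad in addReverseInline records
//   %dx = prim::AutogradAdd(%dx1, %dx2)
// where either operand may still be an undefined (AutogradZero) placeholder.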
814 
815 // Before:
816 // - grad_desc has field f initialized to the original 0-stage graph
817 // After:
818 // - the last node of f (f->nodes().reverse()[0]) is a gradient node
819 // whose block has vjp inputs for all outputs that require_grad
820 // and vjp outputs for all primal inputs that require_grad
821 // - grad_desc has df_input_vjps and df_output_vjps set
822 // (but df_input_vjps will be modified later as well)
823 static ReverseDetails addReverseInline(Gradient& grad_desc) {
824  auto& graph = *grad_desc.f;
825  // note: reverse_node is intentionally not inserted to avoid
 826  // accidentally acting on it (e.g. in eliminate dead code);
 827  // use `std::cout << *reverse_node` to view its state.
828  auto reverse_node = graph.create(prim::Reverse, 0);
829  auto reverse_block = reverse_node->addBlock();
830  WithInsertPoint guard(reverse_block);
831 
832  value_map grad_map; // x -> dx mapping
833  const auto get_grad = [&](Value* v) -> Value* {
834  auto it = grad_map.find(v);
835  if (it == grad_map.end()) {
836  auto undef = graph.insertNode(graph.createAutogradZero());
837  std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
838  }
839  return it->second;
840  };
841  const auto set_grad = [&](Value* x, Value* dx) {
842  if (Value* prev_grad = grad_map[x]) {
843  grad_map[x] = createAutogradAdd(prev_grad, dx);
844  } else {
845  grad_map[x] = dx;
846  }
847  };
848 
849  auto outputs = graph.outputs();
850  for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
851  Value* output = outputs[i];
852  if (!output->requires_grad())
853  continue;
854  Value* output_grad = reverse_block->addInput()->setType(output->type());
855  set_grad(output, output_grad);
856  grad_desc.df_input_vjps.push_back(i);
857  }
858 
859  for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
860  ++it) {
861  Node* node = *it;
862  auto inputs = node->inputs();
863  auto outputs = node->outputs();
864  if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) {
865  return !v->requires_grad();
866  })) {
867  continue;
868  }
869 
870  value_list grad_inputs =
871  linearGradientForNode(node, fmap(node->outputs(), get_grad));
872  LowerSimpleTuples(reverse_block);
873 
874  AT_ASSERT(grad_inputs.size() == node->inputs().size());
875  for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
876  if (!inputs[i]->requires_grad())
877  continue;
878  // NB: Not returning a gradient w.r.t. a value that requires grad is
879  // normal if the input is non-differentiable. This happens e.g. in the
880  // aten::type_as case.
881  if (!grad_inputs[i])
882  continue;
883  set_grad(inputs[i], grad_inputs[i]);
884  }
885  }
886 
887  auto inputs = graph.inputs();
888  for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
889  Value* input = inputs[i];
890  if (!input->requires_grad())
891  continue;
892  // NB: Not having a gradient defined w.r.t. an input to the graph which
893  // requires grad can happen and is not an error. It might have been used
894  // only in non-differentiable contexts (e.g. as second input to
895  // aten::type_as). In that case we simply ignore it as an output, because it
896  // won't ever produce any meaningful values.
897  if (grad_map.count(input) == 0)
898  continue;
899  reverse_block->registerOutput(get_grad(input));
900  grad_desc.df_output_vjps.push_back(i);
901  }
902 
903  return ReverseDetails(std::move(grad_map), reverse_block);
904 }
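// A small worked example (assumed graph, for illustration only): given
//   graph(%a, %b):  %y = aten::mul(%a, %b)  return (%y)
// with %a and %b requiring grad, the reverse block receives one vjp input d%y,
// emits GradOf("aten::mul") producing d%a and d%b, and registers both as its
// outputs. The bookkeeping ends up as df_input_vjps = [0] (index of %y among
// the outputs) and df_output_vjps = [0, 1] (indices of %a and %b among the
// inputs).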
905 
906 // Returns a topologically-sorted list of values produced in f, and used in its
907 // reverse program.
908 static value_list getReverseCaptures(Gradient& grad_desc) {
909  auto& graph = *grad_desc.f;
910  auto primal_block = graph.block();
911 
912  value_set reverse_captures_set;
913  value_list reverse_captures; // Invariant: topo sorted
914  auto check_uses = [&](Value* v) {
915  for (auto use : v->uses()) {
916  if (use.user->owningBlock() == primal_block)
917  continue;
918  if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
919  reverse_captures.push_back(v);
920  }
921  }
922  };
923  for (Value* input : graph.inputs()) {
924  check_uses(input);
925  }
926  for (Node* node : graph.nodes()) {
927  for (Value* output : node->outputs())
928  check_uses(output);
929  }
930  return reverse_captures;
931 }
932 
 933 // Any temporary value from the primal graph needs to be captured for later use
934 // in the reverse graph, to avoid costly recomputations. However, a lot of the
935 // nodes we have in our graphs are simply constants, which are cheap to execute
936 // and replicate, and so it's better to just copy them into the reverse graph,
937 // without polluting the output lists unnecessarily.
938 static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
939  static const auto err = [](Value*) -> Value* {
940  throw std::runtime_error("unexpected input");
941  };
942  auto& graph = *grad_desc.f;
943  Block* reverse_block = rev_info.reverse_block;
944 
945  for (Node* top_node : reverse_block->nodes()) {
946  AT_ASSERT(
947  top_node->kind() == prim::GradOf ||
948  top_node->kind() == prim::AutogradAdd ||
949  top_node->kind() == prim::AutogradZero);
950  if (top_node->kind() != prim::GradOf)
951  continue;
952  Block* grad_body = top_node->blocks().at(0);
953  for (Node* node : grad_body->nodes()) {
954  for (Value* input : node->inputs()) {
955  if (input->node()->kind() != prim::Constant)
956  continue;
957  if (input->node()->owningBlock() == grad_body)
958  continue;
959  Node* lifted_constant = graph.createClone(input->node(), err);
960  reverse_block->prependNode(lifted_constant);
961  node->replaceInputWith(input, lifted_constant->output());
962  }
963  }
964  }
965 }
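// For instance (illustrative): a prim::Constant[value=1]() defined in the
// primal graph but referenced inside a GradOf body is cloned and prepended to
// reverse_block here, so it never has to become an extra output of f that is
// captured for the backward pass.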
966 
967 static void deduplicateSizeCaptures(
968  Gradient& grad_desc,
969  ReverseDetails& rev_info) {
970  Block* primal_block = grad_desc.f->block();
971  const auto usedOnlyInReverse = [primal_block](Value* v) {
972  const auto& uses = v->uses();
973  return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
974  return u.user->owningBlock() != primal_block;
975  });
976  };
977  auto captures = getReverseCaptures(grad_desc);
978  value_set capture_set(captures.begin(), captures.end());
979  for (Value* capture : captures) {
980  Node* node = capture->node();
981  if (!node->matches("aten::size(Tensor self) -> int[]")) {
982  continue;
983  }
984  if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
985  WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
986  capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
987  node->destroy();
988  }
989  }
990 }
991 
992 static void eliminateDeadCode(ReverseDetails& rev_info) {
993  // addReverseInline has to call gradientForNode if *any* of the inputs
994  // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
995  // unnecessary nodes. Additionally, requires_grad() on intermediates is an
996  // overapproximation of the real state, so we might have emitted some
997  // gradients, only to realize that they were unnecessary once we reach a
998  // point that doesn't require grad.
999  // Of course, we need to filter out corresponding entries of grad_map, because
1000  // we don't want to accidentally access freed pointers later.
1001  std::function<void(const std::unordered_set<const Value*>&)> cb =
1002  [&](const std::unordered_set<const Value*>& live_values) {
1003  std::vector<Value*> to_erase;
1004  for (auto& entry : rev_info.grad_map) {
1005  if (!live_values.count(entry.second)) {
1006  to_erase.push_back(entry.first);
1007  }
1008  }
1009  for (Value* v : to_erase) {
1010  rev_info.grad_map.erase(v);
1011  }
1012  };
1013  EliminateDeadCode(rev_info.reverse_block, std::move(cb));
1014 }
1015 
1016 static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
1017  // TODO: we are sometimes emitting expressions like
 1018  // _grad_sum_to_size(_grad_sum_to_size(x, s1), s2), which are equivalent to
1019  // _grad_sum_to_size(x, s2), and could save us some
1020  // captures, but I'm not 100% sure how to optimize this at this stage, since
1021  // we don't know which GradOf blocks will be stitched together to form the
1022  // derivative. I guess a smart analysis could implement this, but I didn't
1023  // have time before the 1.0 release, so I put this only as a peephole
1024  // optimization.
1025  liftConstants(grad_desc, rev_info);
1026  // We generally add a lot of aten::size calls (for derivatives of broadcasting
1027  // operators), and they often end up duplicated, and would get captured
1028  // multiple times. Make sure we deduplicate them before lifting.
1029  EliminateCommonSubexpression(grad_desc.f);
1030  deduplicateSizeCaptures(grad_desc, rev_info);
1031  eliminateDeadCode(rev_info);
1032 }
1033 
1034 // Takes a grad_desc.f returned from `addReverseInline` and splits off the
1035 // reverse_block into its own graph, storing it in df.
1036 // All intermediates needed in the second stage are added to
1037 // outputs of f, and taken as inputs in df. For a more
1038 // detailed description see Note [Gradient graphs] in autodiff.h.
1039 // This function also initializes the fields in grad_desc that were undefined
1040 // after `addReverseInline` (and extends `df_input_vjps` with vjps for captured
1041 // temporaries).
1042 static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
1043  auto& graph = *grad_desc.f;
1044  auto reverse_block = rev_info.reverse_block;
1045 
1046  // --------------------------------------------------------------------------
1047  // 1. Find values of f that need to be captured.
1048  // --------------------------------------------------------------------------
1049  // First, we need to find all values that are produced in f,
1050  // and used in df. They will need to be added as inputs of the df
1051  // and some of them may also need to be appended as outputs of f if
1052  // they are not already an input or an output of f
1053  // Invariant: topo sorted
1054  value_list reverse_captures = getReverseCaptures(grad_desc);
1055 
1056  // --------------------------------------------------------------------------
1057  // 2. Prepare input/outputs lists for f and df
1058  // --------------------------------------------------------------------------
1059  // It's simple to construct primal_inputs/reverse_outputs,
1060  // but primal_outputs/reverse_inputs are much more subtle.
 1061  // Here's a summary of what they are supposed to look like:
1062  //
1063  // Primal outputs:
1064  // [original outputs], [temporaries]
1065  //
1066  // Reverse inputs:
1067  // [output vjps (aka grad_outputs)], [temporary vjps]
1068  // [captured primal values, in topological order],
1069 
1070  // -- Construct primal_outputs, df_input_captures, f_real_outputs ----
1071  grad_desc.f_real_outputs = graph.outputs().size();
1072 
1073  std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
1074  std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
1075  // NOTE: we use emplace to avoid replacing an existing index if an output is
1076  // repeated
1077  for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
1078  orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
1079  for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
1080  orig_primal_inputs_idx[graph.inputs()[i]] = i;
1081 
1082  // NB: reverse_captures are already deduplicated, and in topo order
1083  for (Value* capture_val : reverse_captures) {
1084  // If it's already an output we don't have to add anything,
1085  // but register the fact that it needs to be captured.
1086  if (orig_primal_outputs_idx.count(capture_val) > 0) {
1087  grad_desc.df_input_captured_outputs.push_back(
1088  orig_primal_outputs_idx[capture_val]);
1089  // If it's an input, we could add it as an output but in fact it's
1090  // more efficient to use a special kind of capture.
1091  } else if (orig_primal_inputs_idx.count(capture_val) > 0) {
1092  grad_desc.df_input_captured_inputs.push_back(
1093  orig_primal_inputs_idx.at(capture_val));
1094  // Otherwise it's just a regular intermediate value that we need to add as
1095  // an output
1096  } else {
1097  // we need to create a new temporary output for this capture because it
 1098  // wasn't available.
1099  graph.registerOutput(capture_val);
1100  grad_desc.df_input_captured_outputs.emplace_back(
1101  graph.outputs().size() - 1);
1102  }
1103  }
1104 
1105  // -- Add VJPs for temporaries, adjust df_input_vjps -------------------------
1106  // NB [possible optimization]: use the newly added vjp input as soon as the
1107  // first vjp for that value is generated, to reduce the lifespan of this input
1108  // (currently we add it to the final vjp after all adds).
1109  for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
1110  Value* tmp = graph.outputs().at(i);
1111  // Add VJP inputs only for intermediates that actually required grad.
1112  // Note that we check the contents of the grad_map instead of
 1113  // tmp->requires_grad(), because it's actually a more faithful source.
1114  // tmp->requires_grad() is really an overapproximation (i.e. it can have
1115  // false positives), while the gradients we will emit for this value can get
1116  // DCE-d in the optimization pass (because it has no influence on the real
1117  // f's outputs that we differentiate).
1118  if (rev_info.grad_map.count(tmp) == 0)
1119  continue;
1120  Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
1121  Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
1122  // This is quite weird because we can't first make a sum and then replace
1123  // all uses of tmp_vjp_prev (that would replace its use in the sum too!), so
1124  // we create an incorrect sum that doesn't use prev vjp, replace uses, and
1125  // fix the sum.
1126  Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
1127  new_vjp->node()->moveAfter(tmp_vjp_prev->node());
1128  tmp_vjp_prev->replaceAllUsesWith(new_vjp);
1129  new_vjp->node()->replaceInput(1, tmp_vjp_prev);
1130  grad_desc.df_input_vjps.emplace_back(i);
1131  }
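  // Sketched on a single temporary (value names are invented), the three steps
  // in the loop above rewrite the reverse block as follows:
  //   start:   %dtmp_prev = ...                        (vjps accumulated so far)
  //   step 1:  %new = prim::AutogradAdd(%tmp_vjp_in, %tmp_vjp_in)   (placeholder)
  //   step 2:  every use of %dtmp_prev is redirected to %new
  //   step 3:  %new's second operand is swapped back to %dtmp_prev, giving
  //            %new = prim::AutogradAdd(%tmp_vjp_in, %dtmp_prev)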
1132 
1133  // add the captures as formal arguments to the reverse_block
1134  // afterward inputs: [output vjps][temporary vjps][captures]
1135  // construct a map from captured 'value' to the index in the input list
1136  // used to extract this block into its own function
1137  std::unordered_map<Value*, size_t> capture_to_formal_index;
1138  const auto& add_capture = [&](Value* captured) {
1139  capture_to_formal_index[captured] = reverse_block->inputs().size();
1140  reverse_block->addInput()->copyMetadata(captured);
1141  };
1142  for (auto& offset : grad_desc.df_input_captured_inputs)
1143  add_capture(graph.inputs()[offset]);
1144  for (auto& offset : grad_desc.df_input_captured_outputs)
1145  add_capture(graph.outputs()[offset]);
1146 
1147  grad_desc.df = std::make_shared<Graph>();
1148  grad_desc.df->block()->cloneFrom(reverse_block, [&](Value* v) {
1149  return grad_desc.df->inputs()[capture_to_formal_index.at(v)];
1150  });
1151  // reverse_node was just to hold onto reverse_block in a debuggable way
1152  // we can remove it now.
1153  reverse_block->owningNode()->destroy();
1154 }
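// After lambda lifting, the two graphs fit together roughly like this
// (schematic; see Note [Gradient graphs] in autodiff.h for the full contract):
//   f : (primal inputs) -> (real outputs..., captured temporaries...)
//   df: (output vjps..., temporary vjps..., captured inputs...,
//        captured outputs...) -> (vjps for the primal inputs in df_output_vjps)
// grad_desc.df_input_vjps, df_input_captured_inputs/outputs and df_output_vjps
// record how a caller has to wire the two graphs together at runtime.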
1155 
1156 Gradient differentiate(std::shared_ptr<Graph>& graph) {
1157  Gradient grad_desc;
1158  // Take ownership of the graph
1159  AT_CHECK(
1160  graph.use_count() == 1,
1161  "differentiate will mutate and destroy the graph, so it requires "
 1162  "graph.use_count() == 1, but found ",
1163  graph.use_count());
1164  std::swap(graph, grad_desc.f);
1165  // XXX: Take care when handling outputs - they can be duplicated!
1166 
1167  WithInsertPoint guard(grad_desc.f->block());
1168  // Fills in df_input_vjps and df_output_vjps
1169  auto rev_info = addReverseInline(grad_desc);
1170  Optimize(grad_desc, rev_info);
 1171  // Clean up old nodes which have been replaced by forward graphs from torchscript
1172  EliminateDeadCode(grad_desc.f->block());
1173 
1174  // Fills in f, df, f_real_outputs, df_input_captures,
1175  // modifies df_input_vjps (new vjps are added for temporaries)
1176  lambdaLiftReverse(grad_desc, rev_info);
 1177  // It's possible that we've cloned the same constants many times, so
1178  // de-duplicate them
1179  ConstantPooling(grad_desc.df);
1180  return grad_desc;
1181 }
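// A minimal usage sketch (build_graph() is a hypothetical helper producing a
// uniquely-owned graph whose values have requires_grad set appropriately):
//   std::shared_ptr<Graph> g = build_graph();
//   Gradient grad = differentiate(g);  // g is consumed (swapped into grad.f)
//   // grad.f  runs the forward pass and also returns captured temporaries;
//   // grad.df consumes vjps + captures and returns vjps for the inputs.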
1182 
1183 } // namespace jit
1184 } // namespace torch