Caffe2 - C++ API
A deep-learning, cross-platform ML framework
lower_grad_of.cpp
1 #include <torch/csrc/jit/passes/lower_grad_of.h>
2 
3 namespace torch {
4 namespace jit {
5 
6 void LowerGradOf(Graph& g) {
7  for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
8  if (it->kind() == prim::GradOf) {
9  // if any_defined(inputs):
10  // outputs = <original_computation>
11  // else:
12  // outputs = autograd zero tensors
13  WithInsertPoint guard(*it);
14  auto cond = g.insertNode(g.create(prim::AutogradAnyNonZero, it->inputs()))
15  ->output()
16  ->setType(IntType::get());
17  auto if_stat =
18  g.insertNode(g.create(prim::If, {cond}, it->outputs().size()));
19  if_stat->addBlock()->cloneFrom(
20  it->blocks().at(0), [](Value* v) { return v; });
21  auto else_block = if_stat->addBlock();
22  auto undef = g.createAutogradZero()
23  ->insertBefore(else_block->return_node())
24  ->output();
25  for (size_t i = 0; i < it->outputs().size(); ++i) {
26  else_block->registerOutput(undef);
27  if_stat->outputs().at(i)->copyMetadata(it->outputs().at(i));
28  }
29  it->replaceAllUsesWith(if_stat);
30  it.destroyCurrent();
31  }
32  }
33 }
34 
35 } // namespace jit
36 } // namespace torch
Definition: jit_type.h:17