Caffe2 - C++ API
A deep learning, cross-platform ML framework
specialize_autogradzero.cpp
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/symbolic_variable.h>

namespace torch {
namespace jit {

// Propagate autograd zero information through a gradient graph and
// remove GradOf blocks if present.
// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
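//
// As an illustration (schematic IR, not produced by this file): if %a is
// known Nonzero and %b is statically Zero, then
//   %r : Tensor = prim::AutogradAdd(%a, %b)
// can be replaced by a direct use of %a; likewise a prim::GradOf block whose
// inputs are all Zero can have its outputs replaced by prim::AutogradZero.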
void specializeAutogradZero(Graph& g) {
  enum class State { Nonzero, Zero, Unknown };
  std::unordered_map<Value*, State> state;

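  // Seed the analysis from the graph inputs: inputs statically typed as
  // AutogradZeroTensorType are Zero, ordinary tensors are Nonzero, and
  // everything else is Unknown.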
  for (Value* input : g.inputs()) {
    const auto& tp = input->type();
    if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
      state[input] = State::Zero;
    } else if (tp->isSubtypeOf(TensorType::get())) {
      state[input] = State::Nonzero;
    } else {
      state[input] = State::Unknown;
    }
  }

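  // Walk the graph in a single pass, propagating Zero/Nonzero states and
  // rewriting GradOf and AutogradAdd nodes as they are visited.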
  for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
    auto n = *it;
    switch (n->kind()) {
      case prim::GradOf: {
        auto all_zeros =
            std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
              return state[v] == State::Zero;
            });
        // Property 1: if all the gradInputs to the GradOf are Zero,
        // then the gradOutputs are also Zero and will be represented as
        // AutogradZero nodes.
        if (all_zeros) {
          auto zero = g.createAutogradZero()->insertAfter(n)->output();
          for (auto o : n->outputs()) {
            o->replaceAllUsesWith(zero);
          }
        } else {
          // Property 2: GradOfs are required to correctly handle combinations
          // of Nonzero and Zero inputs. They are expected to produce
          // Nonzero output tensors in this case.

          // Remove the GradOf, splicing its body back into the surrounding
          // block.
          auto body = n->blocks().at(0);
          for (auto input : n->inputs()) {
            // While specializing a GradOf we should always know whether a
            // value is Zero or Nonzero: at the top level a gradient graph is
            // composed only of Linear (GradOf) nodes and AutogradAdds, and
            // Linear nodes appear only in such graphs.
            AT_ASSERT(state[input] != State::Unknown);
          }
          // Hoist the nodes in the GradOf body out of the block so they run
          // before the GradOf node itself.
          for (auto it = body->nodes().begin(); it != body->nodes().end();) {
            auto block_node = *it++;
            block_node->moveBefore(n);
          }

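          // Rewire each GradOf output to the corresponding output of its
          // (now hoisted) body.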
          for (size_t i = 0; i < n->outputs().size(); ++i)
            n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
        }
        it.destroyCurrent();
      } break;
      case prim::AutogradAdd: {
        auto a = n->input(0);
        auto b = n->input(1);
        // if one is Autograd zero, we can just drop the add
        if (state[a] == State::Zero) {
          // Zero + b == b
          n->output()->replaceAllUsesWith(b);
          it.destroyCurrent();
        } else if (state[b] == State::Zero) {
          // a + Zero == a
          n->output()->replaceAllUsesWith(a);
          it.destroyCurrent();
        } else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
          // when both are Nonzero, we can use a normal, optimizable add
          // instruction
          WithInsertPoint guard(n);
          Value* new_add = toVar(a) + toVar(b);
          state[new_add] = State::Nonzero;
          n->output()->replaceAllUsesWith(new_add);
          it.destroyCurrent();
        } else {
          // otherwise we have conditionally-Nonzero things, and we need
          // to actually run an AutogradAdd which will guard for Zeros
          // so we leave the op as is
          state[n->output()] = State::Unknown;
        }
      } break;
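      // An explicit AutogradZero node always produces a statically Zero
      // value.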
      case prim::AutogradZero: {
        state[n->output()] = State::Zero;
      } break;
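      // Any other node is not understood by this pass; conservatively mark
      // all of its outputs Unknown.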
      default:
        for (auto o : n->outputs()) {
          state[o] = State::Unknown;
        }
        break;
    }
  }
}
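
// Minimal usage sketch (hypothetical caller, not part of this file); the pass
// mutates the gradient graph in place:
//
//   std::shared_ptr<Graph> grad_graph = /* gradient graph from autodiff */;
//   specializeAutogradZero(*grad_graph);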

} // namespace jit
} // namespace torch