1 #include <torch/csrc/jit/passes/specialize_autogradzero.h> 2 #include <torch/csrc/jit/symbolic_variable.h> 13 void specializeAutogradZero(Graph& g) {
14 enum class State { Nonzero, Zero, Unknown };
15 std::unordered_map<Value*, State> state;
17 for (Value* input : g.inputs()) {
18 const auto& tp = input->type();
19 if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
20 state[input] = State::Zero;
21 }
else if (tp->isSubtypeOf(TensorType::get())) {
22 state[input] = State::Nonzero;
24 state[input] = State::Unknown;
28 for (
auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
33 std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
34 return state[v] == State::Zero;
40 auto zero = g.createAutogradZero()->insertAfter(n)->output();
41 for (
auto o : n->outputs()) {
42 o->replaceAllUsesWith(zero);
51 auto body = n->blocks().at(0);
52 for (
auto input : n->inputs()) {
57 AT_ASSERT(state[input] != State::Unknown);
60 for (
auto it = body->nodes().begin(); it != body->nodes().end();) {
61 auto block_node = *it++;
62 block_node->moveBefore(n);
65 for (
size_t i = 0; i < n->outputs().size(); ++i)
66 n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
70 case prim::AutogradAdd: {
74 if (state[a] == State::Zero) {
76 n->output()->replaceAllUsesWith(b);
78 }
else if (state[b] == State::Zero) {
80 n->output()->replaceAllUsesWith(a);
82 }
else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
85 WithInsertPoint guard(n);
86 Value* new_add = toVar(a) + toVar(b);
87 state[new_add] = State::Nonzero;
88 n->output()->replaceAllUsesWith(new_add);
94 state[n->output()] = State::Unknown;
97 case prim::AutogradZero: {
98 state[n->output()] = State::Zero;
101 for (
auto o : n->outputs()) {
102 state[o] = State::Unknown;