1 #include <torch/csrc/jit/passes/lower_grad_of.h> 6 void LowerGradOf(Graph& g) {
7 for (
auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
8 if (it->kind() == prim::GradOf) {
13 WithInsertPoint guard(*it);
14 auto cond = g.insertNode(g.create(prim::AutogradAnyNonZero, it->inputs()))
16 ->setType(IntType::get());
18 g.insertNode(g.create(prim::If, {cond}, it->outputs().size()));
19 if_stat->addBlock()->cloneFrom(
20 it->blocks().at(0), [](Value* v) {
return v; });
21 auto else_block = if_stat->addBlock();
22 auto undef = g.createAutogradZero()
23 ->insertBefore(else_block->return_node())
25 for (
size_t i = 0; i < it->outputs().size(); ++i) {
26 else_block->registerOutput(undef);
27 if_stat->outputs().at(i)->copyMetadata(it->outputs().at(i));
29 it->replaceAllUsesWith(if_stat);