#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <ATen/core/jit_type.h>

#include <torch/csrc/utils/functional.h> // fmap
#include <algorithm> // std::any_of
#include <vector>

namespace torch {
namespace jit {

namespace {

bool getRequiresGrad(Value* value) {
  return value->requires_grad();
}
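
// requires_grad is stored on the tensor's type (DimensionedTensorType at
// this stage of compilation), so updating the flag means installing a copy
// of the type with requires_grad changed rather than mutating the Value.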
void setRequiresGrad(Value* value, bool req_value) {
  if (auto type = value->type()->cast<DimensionedTensorType>()) {
    value->setType(type->withRequiresGrad(req_value));
  }
}
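
// Overloads: apply a vector of per-output flags to a list of values.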
void setRequiresGrad(
    at::ArrayRef<Value*> outputs,
    const std::vector<bool>& values) {
  AT_ASSERT(outputs.size() == values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    setRequiresGrad(outputs[i], values[i]);
  }
}

void setRequiresGrad(Node* node, const std::vector<bool>& values) {
  setRequiresGrad(node->outputs(), values);
}
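
// Element-wise OR of two flag vectors; used to merge the two branches of
// an If node and to accumulate the fixed point of a Loop.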
std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
  AT_ASSERT(a.size() == b.size());
  for (size_t i = 0; i < a.size(); ++i) {
    a[i] = a[i] || b[i];
  }
  return a;
}
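
// Propagation rule for ordinary (non-control-flow) nodes. A few ops have
// fixed behavior: comparisons and detach never require grad, and type_as
// inherits the flag from its first input. Everything else falls through to
// the default rule below.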
void PropagateRequiresGradSimpleNode(Node* node) {
  static const OperatorSet comparison_ops = {
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
  };
  if (comparison_ops.find(node)) {
    return setRequiresGrad(node->output(), false);
  } else if (node->matches(
                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
    return setRequiresGrad(node->output(), node->input(0)->requires_grad());
  } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
    return setRequiresGrad(node->output(), false);
  }
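
  // Default rule: an output requires grad iff some input does, and only
  // floating-point tensor outputs can require grad at all.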
  auto inputs = node->inputs();
  auto outputs = node->outputs();
  bool should_require =
      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
  for (Value* output : outputs) {
    if (auto type = output->type()->cast<DimensionedTensorType>()) {
      setRequiresGrad(
          output, should_require && at::isFloatingType(type->scalarType()));
    }
  }
}

void PropagateRequiresGrad(Block* block);
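
// Control-flow nodes need recursive treatment: If merges the results of
// its two branches, and Loop iterates the body analysis to a fixed point.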
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
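    // Each branch is analyzed independently; an output of the If requires
    // grad if it does in either branch, so the per-branch flags are OR-ed.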
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
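    // Loop-carried values may start requiring grad only after some
    // iteration, so the body is re-analyzed until the input and output
    // flags stop changing. Flags only ever flip from false to true, so
    // this fixed-point iteration terminates.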
    auto body = node->blocks().at(0);
    std::vector<bool> body_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    while (body_inputs_require != body_outputs_require) {
      body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      setRequiresGrad(
          body->param_node()->outputs().slice(1), body_inputs_require);
      PropagateRequiresGrad(body);
      body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    }

    setRequiresGrad(node, body_outputs_require);
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}

void PropagateRequiresGrad(Block* block) {
  for (Node* node : block->nodes()) {
    PropagateRequiresGrad(node);
  }
}

} // namespace
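
// Public entry point: propagate requires_grad through every node in the graph.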
void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
  PropagateRequiresGrad(graph->block());
}

} // namespace jit
} // namespace torch
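
// A minimal usage sketch (hypothetical call site; assumes `graph` already
// carries DimensionedTensorType information, e.g. from shape propagation):
//
//   std::shared_ptr<Graph> graph = ...;
//   PropagateRequiresGrad(graph);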