Caffe2 — C++ API: a deep-learning, cross-platform ML framework.
File: requires_grad_analysis.cpp
1 #include <torch/csrc/jit/argument_spec.h>
2 #include <torch/csrc/jit/ir.h>
3 #include <torch/csrc/jit/operator.h>
4 #include <ATen/core/jit_type.h>
5 
6 #include <vector>
7 
8 namespace torch {
9 namespace jit {
10 
11 namespace {
12 
13 bool getRequiresGrad(Value* value) {
14  return value->requires_grad();
15 }
16 
17 void setRequiresGrad(Value* value, bool req_value) {
18  if (auto type = value->type()->cast<DimensionedTensorType>()) {
19  value->setType(type->withRequiresGrad(req_value));
20  }
21 }
22 
23 void setRequiresGrad(
24  at::ArrayRef<Value*> outputs,
25  const std::vector<bool>& values) {
26  AT_ASSERT(outputs.size() == values.size());
27  for (size_t i = 0; i < values.size(); ++i) {
28  setRequiresGrad(outputs[i], values[i]);
29  }
30 }
31 
32 void setRequiresGrad(Node* node, const std::vector<bool>& values) {
33  setRequiresGrad(node->outputs(), values);
34 }
35 
36 std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
37  AT_ASSERT(a.size() == b.size());
38  for (size_t i = 0; i < a.size(); ++i) {
39  a[i] = a[i] || b[i];
40  }
41  return a;
42 }
43 
44 void PropagateRequiresGradSimpleNode(Node* node) {
45  static const OperatorSet comparison_ops = {
46  "aten::lt(Tensor self, Tensor other) -> Tensor",
47  "aten::le(Tensor self, Tensor other) -> Tensor",
48  "aten::gt(Tensor self, Tensor other) -> Tensor",
49  "aten::ge(Tensor self, Tensor other) -> Tensor",
50  "aten::eq(Tensor self, Tensor other) -> Tensor",
51  "aten::ne(Tensor self, Tensor other) -> Tensor",
52  "aten::lt(Tensor self, Scalar other) -> Tensor",
53  "aten::le(Tensor self, Scalar other) -> Tensor",
54  "aten::gt(Tensor self, Scalar other) -> Tensor",
55  "aten::ge(Tensor self, Scalar other) -> Tensor",
56  "aten::eq(Tensor self, Scalar other) -> Tensor",
57  "aten::ne(Tensor self, Scalar other) -> Tensor",
58  };
59 
60  if (comparison_ops.find(node)) {
61  return setRequiresGrad(node->output(), false);
62  } else if (node->matches(
63  "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
64  return setRequiresGrad(node->output(), node->input(0)->requires_grad());
65  } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
66  return setRequiresGrad(node->output(), false);
67  }
68 
69  auto inputs = node->inputs();
70  auto outputs = node->outputs();
71  bool should_require =
72  std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
73  for (Value* output : outputs) {
74  if (auto type = output->type()->cast<DimensionedTensorType>()) {
75  setRequiresGrad(
76  output, should_require && at::isFloatingType(type->scalarType()));
77  }
78  }
79 }
80 
void PropagateRequiresGrad(Block* block);

// Propagate requires_grad through one node, recursing into control-flow
// sub-blocks for prim::If and prim::Loop; everything else is delegated to
// PropagateRequiresGradSimpleNode.
void PropagateRequiresGrad(Node* node) {
  if (node->kind() == prim::If) {
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    // Analyze each branch independently first.
    PropagateRequiresGrad(true_block);
    PropagateRequiresGrad(false_block);

    // An If output requires grad if it does in either branch.
    auto outputs_require = bitwiseOr(
        fmap(true_block->outputs(), getRequiresGrad),
        fmap(false_block->outputs(), getRequiresGrad));
    setRequiresGrad(node, outputs_require);
  } else if (node->kind() == prim::Loop) {
    auto body = node->blocks().at(0);
    // NOTE(review): slice(2) on the node inputs and slice(1) on the body
    // params/returns skip the Loop's trip-count/condition slots, leaving only
    // the loop-carried values — confirm against the prim::Loop schema.
    std::vector<bool> body_inputs_require =
        fmap(node->inputs().slice(2), getRequiresGrad);
    std::vector<bool> body_outputs_require(node->outputs().size(), false);

    // Fixed-point iteration: flags only ever flip false -> true (bitwiseOr
    // below), so the set of requiring values grows monotonically and the loop
    // terminates once a body pass produces no new flags.
    while (body_inputs_require != body_outputs_require) {
      body_inputs_require =
          bitwiseOr(body_inputs_require, body_outputs_require);
      // Seed the body's loop-carried parameters with the current flags, then
      // re-run propagation over the body and read back what its outputs need.
      setRequiresGrad(
          body->param_node()->outputs().slice(1), body_inputs_require);
      PropagateRequiresGrad(body);
      body_outputs_require =
          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
    }

    // The converged flags of the loop-carried values become the node outputs'.
    setRequiresGrad(node, body_outputs_require);
  } else {
    PropagateRequiresGradSimpleNode(node);
  }
}
117 
118 void PropagateRequiresGrad(Block* block) {
119  for (Node* node : block->nodes()) {
120  PropagateRequiresGrad(node);
121  }
122 }
123 
124 } // anonymous namespace
125 
126 void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
127  PropagateRequiresGrad(graph->block());
128 }
129 
130 } // namespace jit
131 } // namespace torch
Cross-references (Doxygen-generated):
- `constexpr size_t ArrayRef::size() const` — get the array size. Definition: ArrayRef.h:138.
- `ArrayRef` — represents a constant reference to an array (zero or more elements stored consecutively in memory). Definition: ArrayRef.h:41.
- Related type definitions: jit_type.h:17.