Caffe2 - C++ API
A deep learning, cross-platform ML framework
final_returns.cpp
1 #include <torch/csrc/jit/ir.h>
2 #include <torch/csrc/jit/script/final_returns.h>
3 
4 namespace torch {
5 namespace jit {
6 namespace script {
7 
8 struct ReturnInfo {
9  bool returns_; // true - all paths through stmts_ always return
10  // false - all paths through stmts_ do not return
11  List<Stmt> stmts_;
12 };
13 
14 void checkNoReturn(const TreeRef& ref) {
15  if (ref->kind() == TK_RETURN)
16  throw ErrorReport(ref) << "return is not allowed from a loop.";
17  for (const TreeRef& child : ref->trees()) {
18  checkNoReturn(child);
19  }
20 }
21 
22 // transform stmts so that its last action is to return or report that it
23 // never returns.
24 // return_none - if true, add an implicit `return None` to the end of the block
25 // this handles the case where the return is implicit at the end of the
26 // function.
27 ReturnInfo makeReturnsFinal(
28  const SourceRange& range,
30  bool return_none);
31 ReturnInfo makeReturnsFinal(const List<Stmt>& stmts, bool return_none) {
32  return makeReturnsFinal(stmts.range(), stmts.get()->trees(), return_none);
33 }
34 ReturnInfo makeReturnsFinal(
35  const SourceRange& range,
37  bool return_none) {
38  std::vector<TreeRef> changed;
39  changed.reserve(stmts.size());
40  for (size_t i = 0; i < stmts.size(); ++i) {
41  const TreeRef& stmt = stmts[i];
42  switch (stmt->kind()) {
43  case TK_IF: {
44  auto if_stmt = If(stmt);
45  auto true_final = makeReturnsFinal(if_stmt.trueBranch(), false);
46  // (3) early return an if statement without an else block:
47  if (true_final.returns_ && if_stmt.falseBranch().size() == 0) {
48  auto rest_final =
49  makeReturnsFinal(range, stmts.slice(i + 1), return_none);
50  if (!rest_final.returns_) {
51  throw ErrorReport(if_stmt)
52  << "This if statement performs an early return, but the block of code that follows it does not return."
53  << " Early returns are only allowed when the block following them also returns.";
54  }
55  changed.emplace_back(
56  if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
57  return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
58  }
59 
60  auto false_final = makeReturnsFinal(if_stmt.falseBranch(), false);
61  // (1) neither branch returns just keep processing the block
62  if (!true_final.returns_ && !false_final.returns_) {
63  changed.emplace_back(if_stmt);
64  break;
65  }
66  // (2) all branches return
67  if (true_final.returns_ && false_final.returns_) {
68  changed.emplace_back(
69  if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
70  return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
71  }
72  throw ErrorReport(if_stmt)
73  << "This if statement contains some paths that return and some paths that do not. "
74  << "If statements must either entirely return or never return.";
75  } break;
76  case TK_WHILE:
77  case TK_FOR:
78  changed.emplace_back(stmt);
79  checkNoReturn(stmt);
80  break;
81  case TK_RETURN:
82  changed.emplace_back(stmt);
83  // ignore the rest the the block, it is dead.
84  return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
85  default:
86  changed.emplace_back(stmt);
87  break;
88  }
89  }
90  if (return_none) {
91  // add an implicit return none node
92  changed.emplace_back(
93  Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
94  }
95  // we reach the end of the block, no returns have happened
96  // unless we just inserted a return_none implicit return.
97  return {return_none, List<Stmt>::unsafeCreate(range, std::move(changed))};
98 }
99 
100 List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts) {
101  return makeReturnsFinal(stmts, true).stmts_;
102 }
103 
104 } // namespace script
105 } // namespace jit
106 } // namespace torch
AT_CPP14_CONSTEXPR ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array...
Definition: ArrayRef.h:161
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41