Caffe2 - C++ API
A deep learning, cross platform ML framework
ir_views.h
1 #include <torch/csrc/jit/ir.h>
2 
3 namespace torch {
4 namespace jit {
5 
6 struct IfView {
7  explicit IfView(Node* node) : node_(node) {
8  AT_ASSERT(node->kind() == ::c10::prim::If);
9  }
10  Value* cond() const {
11  return node_->input(0);
12  }
13  Block* thenBlock() const {
14  return node_->blocks().at(0);
15  }
16  Block* elseBlock() const {
17  return node_->blocks().at(1);
18  }
19  ArrayRef<Value*> thenOutputs() const {
20  return thenBlock()->outputs();
21  }
22  ArrayRef<Value*> elseOutputs() const {
23  return elseBlock()->outputs();
24  }
25  ArrayRef<Value*> outputs() const {
26  return node_->outputs();
27  }
28  Node* node() const {
29  return node_;
30  }
31  operator Node*() const {
32  return node_;
33  }
34 
35  private:
36  Node* node_;
37 };
38 
39 struct LoopView {
40  explicit LoopView(Node* node) : node_(node) {
41  AT_ASSERT(
42  node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop);
43  }
44  Block* bodyBlock() const {
45  return node_->blocks().at(0);
46  }
47  Value* cond() const {
48  return node_->input(0);
49  }
50  Value* maxTripCount() const {
51  return node_->input(0);
52  }
53  Value* inputCond() const {
54  return node_->input(1);
55  }
56  Value* nextCond() const {
57  return bodyBlock()->outputs().at(0);
58  }
59  Value* currentTripCount() const {
60  return bodyBlock()->inputs().at(0);
61  }
62  ArrayRef<Value*> carriedInputs() const {
63  // skip trip count and cond
64  return node_->inputs().slice(2);
65  }
66  ArrayRef<Value*> carriedOutputs() const {
67  return node_->outputs();
68  }
69  ArrayRef<Value*> bodyCarriedInputs() const {
70  // skip trip count and cond
71  return bodyBlock()->inputs().slice(1);
72  }
73  ArrayRef<Value*> bodyCarriedOutputs() const {
74  return bodyBlock()->outputs().slice(1);
75  }
76  Node* node() const {
77  return node_;
78  }
79  operator Node*() const {
80  return node_;
81  }
82 
83  private:
84  Node* node_;
85 };
86 
87 } // namespace jit
88 } // namespace torch
AT_CPP14_CONSTEXPR ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array...
Definition: ArrayRef.h:161
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41