Caffe2 - C++ API
A deep learning, cross platform ML framework
ir.h
1 #pragma once
2 
3 #include <torch/csrc/jit/attributes.h>
4 #include <torch/csrc/jit/graph_node_list.h>
5 #include <torch/csrc/jit/named_value.h>
6 #include <torch/csrc/jit/scope.h>
7 
8 #include <torch/csrc/WindowsTorchApiMacro.h>
9 #include <torch/csrc/utils/disallow_copy.h>
10 #include <torch/csrc/utils/object_ptr.h>
11 
12 #include <ATen/ATen.h>
13 #include <ATen/core/function_schema.h>
14 #include <ATen/core/functional.h>
15 #include <ATen/core/interned_strings.h>
16 #include <ATen/core/ivalue.h>
17 #include <ATen/core/jit_type.h>
18 #include <c10/util/ArrayRef.h>
19 #include <c10/util/Exception.h>
20 
21 #include <functional>
22 #include <iostream>
23 #include <unordered_set>
24 #include <vector>
25 
26 namespace torch {
27 namespace jit {
28 
29 using ::c10::Argument;
30 using ::c10::FunctionSchema;
31 using ::c10::Symbol;
32 
33 using ::c10::ivalue::List;
34 using ::c10::ivalue::Shared;
35 
36 using ::c10::IValue;
37 using ::c10::ivalue::Future;
38 using ::c10::ivalue::Tuple;
39 
40 using ::c10::ivalue::BoolList;
41 using ::c10::ivalue::DoubleList;
42 using ::c10::ivalue::GenericList;
43 using ::c10::ivalue::IntList;
44 using ::c10::ivalue::TensorList;
45 
46 using ::c10::ivalue::ConstantString;
47 
48 #define C10_USING(T) using ::c10::T;
49 C10_FORALL_TYPES(C10_USING)
50 #undef C10_USING
51 
52 #define C10_USING(T) using ::c10::T##Ptr;
53 C10_FORALL_TYPES(C10_USING)
54 #undef C10_USING
55 
56 using ::c10::Type;
57 using ::c10::TypeEnv;
58 using ::c10::TypePtr;
59 
60 using ::c10::getTypePtr;
61 using ::c10::MatchTypeReturn;
62 using ::c10::TypeKind;
63 
64 using ::c10::fmap;
65 
66 namespace prim {
67 using namespace ::c10::prim;
68 }
69 namespace attr {
70 using namespace ::c10::attr;
71 }
72 namespace aten {
73 using namespace ::c10::aten;
74 }
75 
76 // Graph represents one "function" of computation.
77 // It uses a simple ownership model where the graph owns all the nodes inside
78 // it. All references inside the graph are raw pointers. Destroying the Graph
79 // will invalidate any pointers to nodes in the graph.
80 struct Graph;
81 
82 // Node is the base class of the IR graph. It represents one computation
83 // and dependencies on a list of Values. The "prim-ops", so to speak.
84 struct Node;
85 
86 // A Value represents an input or output to a node that is either a
87 // Tensor or an opaque Handle object, as determined by type().
88 struct Value;
89 
90 TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
91 TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
92 
93 // A list of nodes, with inputs and outputs
94 struct Block;
95 
96 // Each use is represented by this type, see Node::uses()
97 // 'user' is the consumer of the value; 'offset' is the index into
98 // 'user's input list where this value can be found.
99 struct Use {
100  Use(Node* user, size_t offset) : user(user), offset(offset) {}
101  Node* user;
102  size_t offset;
103 
104  bool operator==(const Use& b) {
105  return user == b.user && offset == b.offset;
106  }
107 };
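
Each entry in a value's use list pairs the consuming node ('user') with the input slot it occupies ('offset'). A small illustrative sketch (the helper below is hypothetical, not part of this header) of the invariant relating a value to its uses:

#include <torch/csrc/jit/ir.h>

// Hypothetical helper: verifies that every recorded Use points back at 'v'.
// Two uses may share the same 'user' (e.g. %y = Add %x %x), but they always
// differ in 'offset'.
bool usesAreConsistent(const torch::jit::Value* v) {
  for (const torch::jit::Use& u : v->uses()) {
    if (u.user->input(u.offset) != v) {
      return false;
    }
  }
  return true;
}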
108 
109 // Note [User node does not uniquely identify use]
110 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
111 // A while back, we wrote some code manipulating uses that looked like this:
112 //
113 //   for (auto& use : used_val->uses_) {
114 //     if (use.user == this_node) {
115 //       use.offset += 1;
116 //       break;
117 //     }
118 //   }
119 //
120 // This code is trying to find a particular use (our node's use) to update it.
121 // However, it's wrong: there may be *multiple* uses of a value %x in a node,
122 // as might be the case in this IR:
123 //
124 // %y = Add %x %x
125 //
126 // In this case, there are two uses of %x whose user is the node 'Add %x %x'.
127 // So, "use induced by this node" is not a well-formed concept.
128 //
129 // If you are looking for "use induced by an input", it's best to use
130 // findUseForInput() to get it.
131 
132 // the list types are intentionally simple, but we type-def
133 // them here so if we need to change them, refactoring will be easier
134 using node_list = std::vector<Node*>;
135 using value_list = std::vector<Value*>;
136 using use_list = std::vector<Use>;
137 using pyobj_list = std::vector<THPObjectPtr>;
138 template <typename T>
139 using ArrayRef = at::ArrayRef<T>;
140 using NodeKind = Symbol;
141 using topo_position_t = int64_t;
142 using ValueSet = std::unordered_set<const Value*>;
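
NodeKind above is just an interned Symbol, the same type used for operator names and attribute names. A brief illustrative sketch (not part of this header) of how these symbols behave:

#include <torch/csrc/jit/ir.h>

void symbolBasics() {
  using namespace torch::jit;
  NodeKind kind = aten::add;                      // interned "aten::add"
  AT_ASSERT(kind == Symbol::fromQualString("aten::add"));
  AT_ASSERT(std::string(kind.toQualString()) == "aten::add");
  AT_ASSERT(attr::value.is_attr());               // lives in the attr:: namespace
  AT_ASSERT(!aten::add.is_attr());
}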
143 
144 struct Value {
145  TH_DISALLOW_COPY_AND_ASSIGN(Value);
146  Value(Node* node_, size_t offset_);
147 
148  private:
149  friend struct Node;
150  friend struct Graph;
151  Node* node_;
152  size_t offset_;
153  size_t unique_ = 0; // unique id
154  use_list uses_;
155  std::string unique_name_;
156  TypePtr type_;
157 
158  public:
159  Value* setType(TypePtr type);
160  void inferTypeFrom(const at::Tensor& output) {
161  setType(CompleteTensorType::create(output));
162  }
163  const TypePtr& type() const {
164  AT_ASSERT(type_ != nullptr);
165  return type_;
166  }
167  bool requires_grad() const {
168  return type()->requires_grad();
169  }
170  bool isTensor() const {
171  return type()->kind() == TypeKind::CompleteTensorType;
172  }
173  TORCH_API bool mustBeNone() const;
174  size_t unique() const {
175  return unique_;
176  }
177  bool hasUniqueName() const {
178  return !unique_name_.empty();
179  }
180  static bool isValidName(const std::string& name);
181  TORCH_API Value* setUniqueName(const std::string& name);
182  std::string uniqueName() const {
183  if (hasUniqueName()) {
184  return unique_name_;
185  }
186  return std::to_string(unique());
187  }
188  TORCH_API std::string uniqueNameBase() const;
189  Node* node() {
190  return node_;
191  }
192  size_t offset() const {
193  return offset_;
194  }
195  void setOffset(size_t offset) {
196  offset_ = offset;
197  }
198  const Node* node() const {
199  return node_;
200  }
201  Graph* owningGraph();
202  const Graph* owningGraph() const;
203  // TODO: make this more const correct
204  const use_list& uses() const {
205  return uses_;
206  }
207 
208  bool hasUses() const {
209  return !uses().empty();
210  }
211 
212  TORCH_API void replaceFirstUseWith(Value* newValue);
213 
214  // Replaces all uses of this value with 'newValue'.
215  //
216  // Given: %3 = f(%1, %2)
217  // %4 = g(%3)
218  // %5 = h(%3, %3)
219  // Execute: %3.replaceAllUsesWith(%6)
220  // Result: %3 = f(%1, %2)
221  // %4 = g(%6)
222  // %5 = h(%6, %6)
223  TORCH_API void replaceAllUsesWith(Value* newValue);
224 
225  TORCH_API Value* copyMetadata(Value* from);
226 };
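
A brief illustrative sketch (the function name is hypothetical) of the most common Value operations declared above: copying metadata and redirecting every consumer to a replacement value:

#include <torch/csrc/jit/ir.h>

void reroute(torch::jit::Value* old_out, torch::jit::Value* new_out) {
  // Copy type and unique-name metadata onto the replacement value.
  new_out->copyMetadata(old_out);
  // Point every consumer of old_out at new_out instead; afterwards old_out has
  // no uses and its producing node can be destroyed.
  old_out->replaceAllUsesWith(new_out);
  AT_ASSERT(!old_out->hasUses());
}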
227 
228 struct Node {
229  TH_DISALLOW_COPY_AND_ASSIGN(Node);
230  friend struct Graph;
231  friend struct Block;
232  friend struct Value;
233  friend graph_node_list;
234  friend const_graph_node_list;
237 
238  private:
239  const NodeKind kind_;
240  std::vector<Value*> inputs_;
241  std::vector<Value*> outputs_;
242  // subblocks
243  std::vector<Block*> blocks_;
244  Graph* graph_;
245  Block* owning_block_;
246  std::shared_ptr<SourceLocation> source_location_;
247  ScopePtr scope_;
248  // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
249  // This field is effectively a cache that's populated on attribute lookups and
250  // invalidated every time we perform an operation that could potentially
251  // change the schema. note: mutable because schema_ is effectively a cache
252  mutable const FunctionSchema* schema_;
253  topo_position_t topo_position_ = 0;
254 
255  protected:
256  TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
257  public:
258  // Each node except Return/Param
259  // is associated with exactly one place in the node list
260  // of the graph_.
261  // This circular list is a doubly-linked list; the Return node is used as the
262  // sentinel for the beginning and end of the list, so that the list never has
263  // null pointers. next_in_graph[0] is the next pointer and next_in_graph[1] is
264  // the prev pointer; an array is used so the same iterator class works for both
265  // forward and reverse node lists. This list represents a topological sort.
266  Node* next_in_graph[2] = {nullptr, nullptr};
267 
268  Node*& next() {
269  return next_in_graph[kNextDirection];
270  }
271  Node*& prev() {
272  return next_in_graph[kPrevDirection];
273  }
274  Node* const& next() const {
275  return next_in_graph[kNextDirection];
276  }
277  Node* const& prev() const {
278  return next_in_graph[kPrevDirection];
279  }
280 
281  NodeKind kind() const {
282  return kind_;
283  }
284  Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) {
285  source_location_ = std::move(sl);
286  return this;
287  }
288  std::shared_ptr<SourceLocation> getSourceLocation() const {
289  return source_location_;
290  }
291  Graph* owningGraph() {
292  return graph_;
293  }
294  const Graph* owningGraph() const {
295  return graph_;
296  }
297  Block* owningBlock() {
298  return owning_block_;
299  }
300  const Block* owningBlock() const {
301  return owning_block_;
302  }
303  ScopePtr scope() {
304  return scope_;
305  }
306  void setScope(ScopePtr scope) {
307  scope_ = std::move(scope);
308  }
309  std::string scopeName() const {
310  if (!scope_) {
311  return "";
312  }
313  return scope_->namesFromRoot();
314  }
315  // NB: This returns an ArrayRef; that means that it will
316  // get invalidated if you resize inputs (e.g., using addInput)
317  // We can't return a std::vector<Node*>& because there's no
318  // way to soundly cast to std::vector<const Node*> (an insane
319  // implementation of std::vector could make this representationally
320  // different.)
321  at::ArrayRef<Value*> inputs() {
322  return inputs_;
323  }
324  at::ArrayRef<const Value*> inputs() const {
325  // Vectors are not convertible in const-ness of elements, but
326  // raw pointers are.
327  return {inputs_.data(), inputs_.size()};
328  }
329  // NB: This returns an ArrayRef; that means that it will
330  // get invalidated if you resize outputs (e.g., using addOutput)
331  // We can't return a std::vector<Node*>& because there's no
332  // way to soundly cast to std::vector<const Node*> (an insane
333  // implementation of std::vector could make this representationally
334  // different.)
335  at::ArrayRef<Value*> outputs() {
336  return outputs_;
337  }
338  at::ArrayRef<const Value*> outputs() const {
339  // Vectors are not convertible in const-ness of elements, but
340  // raw pointers are.
341  return {outputs_.data(), outputs_.size()};
342  }
343  Value* output(size_t i) const {
344  return outputs_.at(i);
345  }
346  bool hasUses() const {
347  for (auto o : outputs()) {
348  if (!o->uses().empty()) {
349  return true;
350  }
351  }
352  return false;
353  }
354 
355  TORCH_API void replaceAllUsesWith(Node* n);
356 
357  // lots of things like chunk have a single input or single output, so we have
358  // a helper to make accessing it easier
359  Value* input() {
360  AT_ASSERT(inputs_.size() == 1);
361  return inputs_.at(0);
362  }
363  Value* output() {
364  AT_ASSERT(outputs_.size() == 1);
365  return outputs_.at(0);
366  }
367  const Value* output() const {
368  AT_ASSERT(outputs_.size() == 1);
369  return outputs_.at(0);
370  }
371  const Value* input() const {
372  AT_ASSERT(inputs_.size() == 1);
373  return inputs_.at(0);
374  }
375  // Access a particular input. This is a checked index.
376  Value* input(size_t i) const {
377  return inputs_.at(i);
378  }
379 
380  Value* namedInput(Symbol name) const;
381 
382  c10::optional<IValue> get(Symbol name) const;
383 
384  template <typename T>
385  c10::optional<T> get(Symbol name) const {
386  if (auto v = get(name)) {
387  return v->template to<T>();
388  }
389  return c10::nullopt;
390  }
391 
392  // Returns true if the value of input name is statically known
393  bool is_constant(Symbol name) const {
394  return static_cast<bool>(get(name));
395  }
396  TORCH_API bool mustBeNone() const;
397 
398  TORCH_API bool isNondeterministic() const;
399  TORCH_API bool hasSideEffects() const;
400 
401  // Graphs
402 
403  // Note [Topological invariant]
404  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
405  // We always maintain an up-to-date topological ordering of all nodes via
406  // the next()/prev() links. All transformations to graphs must preserve
407  // this topological ordering: for example, it is only valid to 'addInput'
408  // with an input which is topologically before the current node.
409  //
410  // Usually, it is obvious whether or not topological order is maintained;
411  // for example, if you are adding nodes to the end of the topsort, it's
412  // impossible for them to refer to inputs that are not in the topsort.
413  // If it is not obvious, please comment accordingly.
414 
415  // Add 'node' as an input to 'this' at the end of existing
416  // arguments. Returns the added node for ease of chaining.
417  //
418  // Given: %3 = f(%1, %2)
419  // Execute: %3.addInput(%4)
420  // Result: %3 = f(%1, %2, %4)
421  TORCH_API Value* addInput(Value* value);
422 
423  // Add 'value' as an input to 'this' at the specified position in the
424  // arguments. Returns the added value for ease of chaining.
425  TORCH_API Value* insertInput(size_t i, Value* value);
426 
427  // Replace the input of 'this' at position 'i' with
428  // 'newValue', returning the old node.
429  //
430  // Given: %3 = f(%1, %2)
431  // Execute: %3.replaceInput(1, %4)
432  // Result: %3 = f(%1, %4)
433  TORCH_API Value* replaceInput(size_t i, Value* newValue);
434 
435  // Replace all occurrences of 'from' in the inputs of this
436  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
437  //
438  // Given: %3 = f(%1, %2, %1)
439  // Execute: %3.replaceInputWith(%1, %4)
440  // Result: %3 = f(%4, %2, %4)
441  TORCH_API void replaceInputWith(Value* from, Value* to);
442 
443  TORCH_API Value* addOutput();
444 
445  TORCH_API Value* insertOutput(size_t i);
446 
447  TORCH_API void eraseOutput(size_t i);
448 
449  TORCH_API Block* addBlock();
450  TORCH_API void eraseBlock(size_t i);
451 
452  // Each Node can have a list of subblocks. These are used to define structured
453  // nested control flow operators such as If and Loop.
454  // The meaning of a block is specific to the kind of node it is in, but
455  // all blocks share these semantics:
456  // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
457  // a node 'Child', Child can use any value that was in scope for the Parent
458  // node in addition to any values defined before 'Child' in the subblock.
459  // * The list of inputs to the block is in scope for the duration of the
460  // block.
461  // * The outputs of the Parent node are not in scope for the subblocks.
462  // Typically the inputs to a block that represents control flow act as
463  // the equivalent of phi-nodes in standard SSA form,
464  // defining a new Value to represent any term that has multiple
465  // definitions depending on how control flowed. Outputs of the node containing
466  // control flow serve a similar purpose, defining new values for variables
467  // that would have different definitions depending on which way control flowed.
468 
469  at::ArrayRef<Block*> blocks() {
470  return blocks_;
471  }
472  at::ArrayRef<const Block*> blocks() const {
473  // Vectors are not convertible in const-ness of elements, but
474  // raw pointers are.
475  return {blocks_.data(), blocks_.size()};
476  }
477 
478  // Is 'this' before 'n' in the topological order?
479  TORCH_API bool isBefore(const Node* n) const;
480 
481  // Is 'this' after 'n' in the topological order?
482  TORCH_API bool isAfter(const Node* n) const;
483 
484  // Insert unattached 'this' node before 'n' in the topological order.
485  // Returns this (for chaining).
486  //
487  // Given: %3 = f(%1, %2)
488  // %4 = g(%3)
489  // and unattached: %5 = h(%1)
490  // Execute: %5.insertBefore(%4)
491  // Result: %3 = f(%1, %2)
492  // %5 = h(%1)
493  // %4 = g(%3)
494  TORCH_API Node* insertBefore(Node* n);
495 
496  // Insert unattached 'this' node after 'n' in the topological order.
497  // Returns this (for chaining).
498  //
499  // Given: %3 = f(%1, %2)
500  // %4 = g(%3)
501  // and unattached: %5 = h(%1)
502  // Execute: %5.insertAfter(%4)
503  // Result: %3 = f(%1, %2)
504  // %4 = g(%3)
505  // %5 = h(%1)
506  TORCH_API Node* insertAfter(Node* n);
507 
508  // Move 'this' (already in the graph) after 'n' in the topological order.
509  //
510  // NOTE: Does not check that value dependencies are preserved, see
511  // AliasDb::moveAfterTopologicallyValid
512  //
513  // Given: %2 = f(%1)
514  // %3 = g(%1)
515  // Execute: %2.moveAfter(%3)
516  // Result: %3 = g(%1)
517  // %2 = f(%1)
518  //
519  TORCH_API void moveAfter(Node* n);
520 
521  // Move a node 'n' (already in the graph) before 'this' in the topological
522  // order.
523  //
524  // NOTE: Does not check that value dependencies are preserved, see
525  // AliasDb::moveBeforeTopologicallyValid
526  //
527  // Given: %2 = f(%1)
528  // %3 = g(%1)
529  // Execute: %3.moveBefore(%2)
530  // Result: %3 = g(%1)
531  // %2 = f(%1)
532  TORCH_API void moveBefore(Node* n);
533 
534  // Remove the input at 'i' from this node.
535  //
536  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
537  // removeInput.
538  //
539  // Given: %3 = f(%1, %2)
540  // Execute: %3.removeInput(1)
541  // Result: %3 = f(%1)
542  TORCH_API void removeInput(size_t i);
543 
544  // Remove all inputs from a node.
545  //
546  // Given: %3 = f(%1, %2)
547  // Execute: %3.removeAllInputs()
548  // Result: %3 = f()
549  TORCH_API void removeAllInputs();
550 
551  // iterators of the node list starting at this node
552  // useful for resuming a search starting at this node
553  inline graph_node_list_iterator iterator() {
554  return {this, 0};
555  }
556  inline graph_node_list_iterator reverseIterator() {
557  return iterator().reverse();
558  }
559  inline const_graph_node_list_iterator iterator() const {
560  return {this, 0};
561  }
562  inline const_graph_node_list_iterator reverseIterator() const {
563  return iterator().reverse();
564  }
565 
566  // Remove 'this' from the instruction list and deallocate it.
567  //
568  // Invariant: no outputs of 'this' may have any uses.
569  //
570  // Given: %2 = f(%1)
571  // %3 = g(%1)
572  // Execute: %2.destroy()
573  // Result: %3 = g(%1)
574  TORCH_API void destroy();
575 
576  // Dynamically cast this node to the subclass indicated by the
577  // template parameter, returning nullptr if the cast is invalid.
578  //
579  // Example usage: if(auto s = n.cast<Select>()) { ... }
580  //
581  // TODO: Make this const correct
582  template <typename T>
583  T* cast() {
584  if (T::Kind == kind()) {
585  return static_cast<T*>(this);
586  }
587  return nullptr;
588  }
589  template <typename T>
590  T* expect() {
591  AT_CHECK(
592  T::Kind == kind(),
593  "expected a ",
594  T::Kind.toDisplayString(),
595  " but found a ",
596  kind().toDisplayString());
597  return static_cast<T*>(this);
598  }
599 
600  // XXX: this function is meant to be used with string literals only!
601  TORCH_API bool matches(
602  const char* signature_literal,
603  at::ArrayRef<Symbol> const_inputs = {}) const;
604 
605  const FunctionSchema& schema() const {
606  if (!schema_) {
607  findSchema();
608  }
609  return *schema_;
610  }
611  const FunctionSchema* maybeSchema() const;
612 
613  void dump() const;
614 
615  std::ostream& print(
616  std::ostream& out,
617  size_t level,
618  std::vector<const Node*>* groups) const;
619 
620  virtual ~Node() = default;
621 
622  // Methods for accessing attributes
623  void copyAttributes(const Node& rhs) {
624  values_.clear();
625  for (const AVPtr& i : rhs.values_) {
626  values_.push_back(i->clone());
627  }
628  }
629  bool hasAttribute(Symbol name) const {
630  AT_ASSERT(name.is_attr());
631  return findAttr(name, false) != values_.end();
632  }
633  bool hasAttributeS(const std::string& name) const {
634  return hasAttribute(Symbol::attr(name));
635  }
636  AttributeKind kindOf(Symbol name) const {
637  AT_ASSERT(name.is_attr());
638  return (*findAttr(name, true))->kind();
639  }
640  AttributeKind kindOfS(const std::string& name) const {
641  return kindOf(Symbol::attr(name));
642  }
643  Node* removeAttribute(Symbol name) {
644  AT_ASSERT(name.is_attr());
645  values_.erase(findAttr(name, true));
646  return this;
647  }
648  Node* removeAttributeS(const std::string& name) {
649  return removeAttribute(Symbol::attr(name));
650  }
651  bool hasAttributes() const {
652  return values_.size() > 0;
653  }
654  size_t numAttributes() const {
655  return values_.size();
656  }
657  // The names are returned in order, since name actually is the index.
658  std::vector<Symbol> attributeNames() const {
659  std::vector<Symbol> names;
660  for (const AVPtr& a : values_) {
661  names.push_back(a->name);
662  }
663  return names;
664  }
665  std::vector<const char*> attributeNamesS() const {
666  std::vector<const char*> names;
667  for (const AVPtr& a : values_) {
668  names.push_back(a->name.toUnqualString());
669  }
670  return names;
671  }
672 
673 #define CREATE_ACCESSOR(Kind, method) \
674  Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
675  return setAttr<Kind##Attr>( \
676  name, std::forward<Kind##Attr::ConstructorType>(v)); \
677  } \
678  const Kind##Attr::ValueType& method(Symbol name) const { \
679  return getAttr<Kind##Attr>(name); \
680  }
681 
682  CREATE_ACCESSOR(Float, f)
683  CREATE_ACCESSOR(Floats, fs)
684  CREATE_ACCESSOR(String, s)
685  CREATE_ACCESSOR(Strings, ss)
686  CREATE_ACCESSOR(Int, i)
687  CREATE_ACCESSOR(Ints, is)
688  CREATE_ACCESSOR(Graph, g)
689  CREATE_ACCESSOR(Graphs, gs)
690 
691 #undef CREATE_ACCESSOR
692 
693  // Our Graphs are not very const-correct, so we need to allow returning
694  // non-const references too
695  GraphAttr::ValueType& g(Symbol name) {
696  return getAttr<GraphAttr>(name);
697  }
698 
699  // does not use CREATE_ACCESSOR because we need additional asserts
700  Node* t_(Symbol name, TensorAttr::ConstructorType v) {
701  AT_ASSERT(!v.defined() || v.is_variable());
702  return setAttr<TensorAttr>(
703  name, std::forward<TensorAttr::ConstructorType>(v));
704  }
705  const TensorAttr::ValueType& t(Symbol name) const {
706  return getAttr<TensorAttr>(name);
707  }
708 
709  Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
710  for (const at::Tensor& t : v) {
711  AT_ASSERT(!t.defined() || t.is_variable());
712  }
713  return setAttr<TensorsAttr>(
714  name, std::forward<TensorsAttr::ConstructorType>(v));
715  }
716  const TensorsAttr::ValueType& ts(Symbol name) const {
717  return getAttr<TensorsAttr>(name);
718  }
719 
720  private:
721  void printAttrValue(std::ostream& out, const Symbol& name) const;
722  void printAttributes(std::ostream& out, bool ignore_subgraph) const;
723 
724  template <typename T>
725  Node* setAttr(Symbol name, typename T::ConstructorType v) {
726  AT_ASSERT(name.is_attr());
727  auto it = findAttr(name, false);
728  auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
729  if (it == values_.end()) {
730  values_.push_back(std::move(nv));
731  } else {
732  *it = std::move(nv);
733  }
734  return this;
735  }
736  template <typename T>
737  typename T::ValueType& getAttr(Symbol name) const {
738  AT_ASSERT(name.is_attr());
739  auto it = findAttr(name, true);
740  auto* child = dynamic_cast<T*>(it->get());
741  if (child == nullptr) {
742  throw AttributeError(name, true);
743  }
744  return child->value();
745  }
746  using AVPtr = AttributeValue::Ptr;
747  // NB: For determinism, we use a vector rather than a hash map. This does
748  // mean that lookups are O(n), so you shouldn't use Attributes to store
749  // a big pile of messages.
750  std::vector<AVPtr> values_;
751  std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
752  AT_ASSERT(name.is_attr());
753  auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
754  return v->name == name;
755  });
756  if (required && it == values_.end()) {
757  throw AttributeError(name, false);
758  }
759  AT_ASSERT(!required || it != values_.end());
760  return it;
761  }
762  std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required)
763  const {
764  AT_ASSERT(name.is_attr());
765  auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
766  return v->name == name;
767  });
768  if (required && it == values_.end()) {
769  throw AttributeError(name, false);
770  }
771  AT_ASSERT(!required || it != values_.end());
772  return it;
773  }
774 
775  enum class MoveSide { BEFORE, AFTER };
776  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
777 
778  std::pair<Value*, const Argument&> findInput(Symbol name);
779  void findSchema() const;
780  // Look up the iterator in the use list of _input i_ that corresponds to its
781  // use by _this_ node.
782  TORCH_API use_list::iterator findUseForInput(size_t i);
783 
784  // Remove the use of input i. This sets input i to nullptr, but
785  // is only used internally by Node before setting it to a new value
786  // or erasing the entry from the list.
787  TORCH_API Value* dropInput(size_t i);
788 
789  bool inBlockList() const {
790  if (next() == nullptr) {
791  AT_ASSERT(prev() == nullptr);
792  }
793  return next() != nullptr;
794  }
795 
796  TORCH_API void removeFromList();
797  TORCH_API void lint() const;
798 
799  void assignTopoPosition();
800 
801  protected:
802  // subclasses must override
803  // this function is used by createClone to initialize a new version
804  // of a node in another graph. It should allocate a new instance of the same
805  // concrete type as 'this', but in graph 'g' which might be different
806  // than graph_
807  virtual Node* allocNewInstance(Graph* g) {
808  return new Node(g, kind());
809  }
810  // create a copy of all properties of Node s into this.
811  // subclasses should extend if they have additional information to copy.
812  // 'this' will be allocated with s->allocNewInstance(g) so it should have
813  // the same concrete type as 's'
814  //
815  TORCH_API virtual void cloneFrom(Node* s);
816 };
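
An illustrative sketch (hypothetical helper, not part of this header) exercising the Node query and attribute APIs above. It assumes, for illustration, that an integer prim::Constant stores its payload as an Int attribute named attr::value:

#include <torch/csrc/jit/ir.h>

c10::optional<int64_t> constantInt(torch::jit::Node* n) {
  using namespace torch::jit;
  if (n->kind() == prim::Constant && n->hasAttribute(attr::value) &&
      n->kindOf(attr::value) == AttributeKind::i) {
    return n->i(attr::value);  // i() reads an Int attribute; i_() would set it
  }
  return c10::nullopt;
}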
817 
818 struct Block {
819  friend struct Node;
820  friend struct Graph;
821 
822  TH_DISALLOW_COPY_AND_ASSIGN(Block);
823  TORCH_API Block(Graph* graph_, Node* node_);
824 
825  at::ArrayRef<Value*> inputs() {
826  return input_->outputs();
827  }
828  at::ArrayRef<const Value*> inputs() const {
829  const auto& inputs = input_->outputs();
830  return {inputs.data(), inputs.size()};
831  }
832  at::ArrayRef<Value*> outputs() {
833  return output_->inputs();
834  }
835  at::ArrayRef<const Value*> outputs() const {
836  return static_cast<const Node*>(output_)->inputs();
837  }
838  graph_node_list nodes() {
839  return {output_, kNextDirection};
840  }
841  const_graph_node_list nodes() const {
842  return {output_, kNextDirection};
843  }
844  Node* return_node() {
845  return output_;
846  }
847  const Node* return_node() const {
848  return output_;
849  }
850  Node* param_node() {
851  return input_;
852  }
853  const Node* param_node() const {
854  return input_;
855  }
856  Graph* owningGraph() {
857  return graph_;
858  }
859  const Graph* owningGraph() const {
860  return graph_;
861  }
862  Node* owningNode() {
863  return owning_node_;
864  }
865  const Node* owningNode() const {
866  return owning_node_;
867  }
868 
869  Value* addInput(std::string name = "") {
870  Value* v = input_->addOutput();
871  v->setUniqueName(std::move(name));
872  return v;
873  }
874  Value* insertInput(size_t i, std::string name = "") {
875  Value* v = input_->insertOutput(i);
876  v->setUniqueName(std::move(name));
877  return v;
878  }
879  void eraseInput(size_t i) {
880  input_->eraseOutput(i);
881  }
882  size_t registerOutput(Value* v) {
883  output_->addInput(v);
884  return outputs().size() - 1;
885  }
886  size_t insertOutput(size_t i, Value* n) {
887  output_->insertInput(i, n);
888  return i;
889  }
890  void eraseOutput(size_t i) {
891  output_->removeInput(i);
892  }
893 
894  Node* appendNode(Node* n) {
895  AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
896  n->insertBefore(output_);
897  return n;
898  }
899  Node* prependNode(Node* n) {
900  AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
901  n->insertAfter(output_);
902  return n;
903  }
904  // clone all inputs, nodes, and outputs from src and append them
905  // to the inputs, nodes, and outputs of this block
906  // value_map is used whenever a node in src references a free variable
907  // in src to look up its corresponding value
908  TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
909 
910  private:
911  void reIndexTopology();
912 
913  // should only be called in the constructor
914  Node* initOutput(Node* p) {
915  p->next() = p;
916  p->prev() = p;
917  return p;
918  }
919 
920  // get rid of all nodes
921  // destroys in reverse order so that uses internal to this block
922  // do not have to be removed before you can destroy the block
923  void destroy();
924 
925  Graph* const graph_;
926  // holds outputs in a way that can be reflected
927  // as a Use object
928  // also used as the beginning/end of the circular node list to avoid
929  // having corner cases where the list is empty.
930  Node* const output_;
931  Node* const input_;
932  Node* const
933  owning_node_; // either the node that has this block or nullptr for root
934 };
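
A sketch (helper and value names are illustrative, not part of this header) of how blocks express structured control flow: a prim::If node with two subblocks, each registering one output that becomes the If node's result:

#include <torch/csrc/jit/ir.h>

torch::jit::Value* insertIf(torch::jit::Graph& g, torch::jit::Value* cond,
                            torch::jit::Value* a, torch::jit::Value* b) {
  using namespace torch::jit;
  Node* if_node = g.insertNode(g.create(prim::If, {cond}, /*num_outputs=*/1));
  Block* then_block = if_node->addBlock();
  Block* else_block = if_node->addBlock();
  // Values from the enclosing scope are visible inside the subblocks.
  then_block->registerOutput(a);
  else_block->registerOutput(b);
  if_node->output()->setType(a->type());
  return if_node->output();
}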
935 
936 struct Graph {
937  TH_DISALLOW_COPY_AND_ASSIGN(Graph);
938  friend struct Node;
939  friend struct Value;
940  friend struct Block;
941 
942  private:
943  // only used to keep track of allocated nodes
944  // actual representation of Graph is done with
945  // inputs, outputs, nodes
946 
947  std::unordered_set<const Node*> all_nodes;
948  std::unordered_set<const Value*> all_values;
949  std::unordered_set<const Block*> all_blocks;
950  size_t next_unique_;
951 
952  std::unordered_map<std::string, Value*> unique_names_;
953 
954  ScopePtr current_scope_;
955 
956  Block* const block_;
957  // when insertNode() is called, the node is inserted before this node
958  // by default this is set to append to the top level block
959  Node* insert_before_;
960 
961  public:
962  Graph(ScopePtr scope_root)
963  : next_unique_(0),
964  current_scope_(std::move(scope_root)),
965  block_(new Block(this, nullptr)),
966  insert_before_(return_node()) {}
967 
968  Graph() : Graph(c10::make_intrusive<Scope>()) {}
969 
970  at::ArrayRef<Value*> inputs() {
971  return block_->inputs();
972  }
973  at::ArrayRef<const Value*> inputs() const {
974  const Block& block = *block_;
975  return block.inputs();
976  }
977  at::ArrayRef<Value*> outputs() {
978  return block_->outputs();
979  }
980  at::ArrayRef<const Value*> outputs() const {
981  const Block& block = *block_;
982  return block.outputs();
983  }
984  graph_node_list nodes() {
985  return block_->nodes();
986  }
987  const_graph_node_list nodes() const {
988  const Block& block = *block_;
989  return block.nodes();
990  }
991  Node* param_node() {
992  return block_->param_node();
993  }
994  const Node* param_node() const {
995  return block_->param_node();
996  }
997  Node* return_node() {
998  return block_->return_node();
999  }
1000  const Node* return_node() const {
1001  return block_->return_node();
1002  }
1003  const std::unordered_map<std::string, Value*>& uniqueNames() const {
1004  return unique_names_;
1005  }
1006 
1007  void push_scope(const std::string& scope_name) {
1008  current_scope_ = current_scope_->push(Symbol::scope(scope_name));
1009  }
1010  void pop_scope() {
1011  current_scope_ = current_scope_->parent();
1012  }
1013  ScopePtr current_scope() {
1014  return current_scope_;
1015  }
1016  void set_current_scope(ScopePtr scope) {
1017  current_scope_ = std::move(scope);
1018  }
1019 
1020  Value* addInput(std::string name = "") {
1021  return block_->addInput(std::move(name));
1022  }
1023  Value* insertInput(size_t i, std::string name = "") {
1024  return block_->insertInput(i, std::move(name));
1025  }
1026  void eraseInput(size_t i) {
1027  block_->eraseInput(i);
1028  }
1029  size_t registerOutput(Value* n) {
1030  return block_->registerOutput(n);
1031  }
1032  void eraseOutput(size_t i) {
1033  block_->eraseOutput(i);
1034  }
1035 
1036  TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
1037  TORCH_API Node* create(
1038  NodeKind kind,
1039  ArrayRef<Value*> inputs,
1040  size_t num_outputs = 1);
1041 
1042  TORCH_API Node* createNone(
1043  TypePtr typ); // value of None with type Optional[typ]
1044  TORCH_API Node* createAutogradZero();
1045  TORCH_API Node* createFusionGroup();
1046  TORCH_API Node* createDifferentiableSubgraph();
1047  TORCH_API Node* createTuple(
1048  at::ArrayRef<Value*> values,
1049  c10::OptNameList field_names = c10::nullopt);
1050  TORCH_API Node* createTupleUnpack(Value* v);
1051  TORCH_API Node* createTupleIndex(Value* tup, int64_t index);
1052  TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
1053  TORCH_API Node* createList(
1054  const TypePtr& elem_type,
1055  at::ArrayRef<Value*> values);
1056  TORCH_API Node* createListUnpack(Value* v, size_t size);
1057  TORCH_API Node* createDict(
1058  const TypePtr& key_type,
1059  const TypePtr& value_type,
1060  at::ArrayRef<Value*> keys,
1061  at::ArrayRef<Value*> values);
1062  TORCH_API Node* createDictIndex(Value* dict, Value* index);
1063  TORCH_API Node* createNumToTensor(Value* value);
1064  TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
1065  TORCH_API Node* createObject(const ClassTypePtr& type);
1066  TORCH_API Node* createSetAttr(
1067  Value* obj,
1068  const std::string& field,
1069  Value* newValue);
1070  TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
1071  Node* createPythonOp(
1072  THPObjectPtr&& pyobj,
1073  const std::string& cconv,
1074  pyobj_list&& scalar_args);
1075  // clone n, making a new node in _this_ graph.
1076  // use value_map to translate inputs of n to inputs of the cloned node
1077  // if copy_blocks is false, it will not recursively clone the nested blocks
1078  // this node contains.
1079  TORCH_API Node* createClone(
1080  Node* n,
1081  const std::function<Value*(Value*)>& value_map,
1082  bool copy_blocks = true);
1083 
1084  // Insert constant IValue into the graph. If the type cannot be fully deduced
1085  // from the ivalue (e.g. a None that should be typed Optional[t]), use result_type
1086  TORCH_API Value* insertConstant(
1087  IValue val,
1088  const TypePtr& result_type = nullptr,
1089  c10::optional<SourceRange> loc = c10::nullopt,
1090  c10::optional<ScopePtr> scope = c10::nullopt);
1091 
1092  // Schema-driven insert:
1093  // This inserts a node into the graph with inputs determined from args and
1094  // kwargs using Python argument matching rules, and checks that the op matches
1095  // a known schema.
1096  //
1097  // If this call successfully completes, it guarantees the node
1098  // is a correctly-formed invocation of opname.
1099  TORCH_API Value* insert(
1100  Symbol opname,
1101  at::ArrayRef<NamedValue> args,
1102  at::ArrayRef<NamedValue> kwargs = {},
1103  const c10::optional<SourceRange>& range = {});
1104 
1105  Node* appendNode(Node* n) {
1106  return block_->appendNode(n);
1107  }
1108 
1109  Node* prependNode(Node* n) {
1110  return block_->prependNode(n);
1111  }
1112 
1113  // insert before insert_before_ node
1114  // initialized to insert at the end of the top level block
1115  // can be changed with setInsertPoint()
1116  Node* insertNode(Node* n) {
1117  AT_ASSERT(
1118  insert_before_->inBlockList() &&
1119  "insert point node is no longer in a block list");
1120  return n->insertBefore(insert_before_);
1121  }
1122  // set where nodes are inserted to append to the end of this block
1123  void setInsertPoint(Block* b) {
1124  AT_ASSERT(b->owningGraph() == this);
1125  insert_before_ = b->return_node();
1126  }
1127  // set where nodes are inserted to insert _before_ this node
1128  // for implementation simplicity we only support inserting before a node for
1129  // now
1130  void setInsertPoint(Node* n) {
1131  AT_ASSERT(n->owningGraph() == this && n->inBlockList());
1132  insert_before_ = n;
1133  }
1134  Node* insertPoint() {
1135  return insert_before_;
1136  }
1137 
1138  // the top level block
1139  Block* block() {
1140  return block_;
1141  }
1142  const Block* block() const {
1143  return block_;
1144  }
1145 
1146  // Checks well-formedness and invariants of graph
1147  TORCH_API void lint() const;
1148  // for use in debugger
1149  TORCH_API void dump() const;
1150 
1151  TORCH_API ~Graph();
1152 
1153  TORCH_API std::string toString() const;
1154 
1155  friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
1156 
1157  TORCH_API std::ostream& prettyPrint(std::ostream& out);
1158  TORCH_API void dumpPretty();
1159 
1160  TORCH_API std::shared_ptr<Graph> copy();
1161 
1162  private:
1163  TORCH_API void freeNode(Node* n);
1164  TORCH_API void freeValue(Value* v);
1165  TORCH_API void freeBlock(Block* b);
1166 };
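
A minimal end-to-end sketch (function name illustrative, not part of this header) that builds a graph computing x * 2 + y with the schema-driven insert() API declared above:

#include <torch/csrc/jit/ir.h>

#include <memory>

std::shared_ptr<torch::jit::Graph> buildExampleGraph() {
  using namespace torch::jit;
  auto graph = std::make_shared<Graph>();
  Value* x = graph->addInput("x");
  Value* y = graph->addInput("y");
  Value* two = graph->insertConstant(int64_t{2});
  // insert() matches the op against its registered schema and creates the node
  // at the current insertion point (by default the end of the top-level block).
  Value* prod = graph->insert(aten::mul, {x, two});
  Value* sum = graph->insert(aten::add, {prod, y});
  graph->registerOutput(sum);
  graph->lint();  // check well-formedness
  return graph;
}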
1167 
1173 // A utility class for setting temporary insertion points.
1174 struct WithInsertPoint {
1175  WithInsertPoint(Node* n) : prev_(n->owningGraph()->insertPoint()) {
1176  n->owningGraph()->setInsertPoint(n);
1177  }
1178  WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}
1179 
1180  ~WithInsertPoint() {
1181  prev_->owningGraph()->setInsertPoint(prev_);
1182  }
1183 
1184  private:
1185  Node* prev_;
1186 };
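
An illustrative sketch of the RAII guard above: while the guard is alive, insertNode()/insert() append into the chosen block, and the previous insertion point is restored on scope exit (the helper and names are hypothetical):

#include <torch/csrc/jit/ir.h>

void fillThenBlock(torch::jit::Graph& g, torch::jit::Node* if_node,
                   torch::jit::Value* x) {
  using namespace torch::jit;
  Block* then_block = if_node->blocks()[0];
  {
    WithInsertPoint guard(then_block);  // now inserting at the end of then_block
    Value* neg = g.insert(aten::neg, {x});
    then_block->registerOutput(neg);
  }  // previous insertion point restored here
}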
1187 
1192 // A utility class for setting temporary scopes.
1193 struct WithCurrentScope {
1194  WithCurrentScope(Graph& g, ScopePtr scope)
1195  : graph_(&g), prev_scope_(g.current_scope()) {
1196  g.set_current_scope(std::move(scope));
1197  }
1198  ~WithCurrentScope() {
1199  graph_->set_current_scope(prev_scope_);
1200  }
1201 
1202  private:
1203  Graph* graph_;
1204  ScopePtr prev_scope_;
1205 };
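
Similarly, a sketch of the scope guard (the module name is illustrative): nodes created while the guard is alive record the pushed scope, which is what Node::scopeName() later reports:

#include <torch/csrc/jit/ir.h>

void addScopedRelu(torch::jit::Graph& g, torch::jit::Value* x) {
  using namespace torch::jit;
  WithCurrentScope guard(g, g.current_scope()->push(Symbol::scope("MyModule")));
  g.insert(aten::relu, {x});  // this node's scopeName() will include "MyModule"
}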
1206 
1207 inline Value::Value(Node* node_, size_t offset_)
1208  : node_(node_),
1209  offset_(offset_),
1210  unique_(node_->graph_->next_unique_++),
1211  type_(TensorType::get()) {
1212  node_->graph_->all_values.emplace(this);
1213 }
1214 
1215 inline Value* Value::setType(TypePtr type) {
1216  AT_ASSERT(type);
1217  type_ = std::move(type);
1218  for (Use& use : uses_) {
1219  use.user->schema_ = nullptr;
1220  }
1221  return this;
1222 }
1223 
1224 inline Graph* Value::owningGraph() {
1225  return node()->owningGraph();
1226 }
1227 
1228 inline const Graph* Value::owningGraph() const {
1229  return node()->owningGraph();
1230 }
1231 
1232 /************* All nodes not required to be defined before Graph **************/
1233 
1234 // execute a Python function, used for Ops we can't optimize but that we want to
1235 // optimize around
1236 struct PythonOp : public Node {
1237  static constexpr Symbol Kind = ::c10::prim::PythonOp;
1238 
1239  PythonOp(Graph* graph) : Node(graph, ::c10::prim::PythonOp) {}
1240  PythonOp* init(
1241  THPObjectPtr&& pyobj,
1242  const std::string& cconv,
1243  pyobj_list&& scalar_args) {
1244  this->pyobj = std::move(pyobj);
1245  this->scalar_args = std::move(scalar_args);
1246  this->cconv = cconv;
1247  return this;
1248  }
1249  // The Python object which contains the implementation of this function.
1250  // This is either a class (non-legacy) or an object (legacy). See
1251  // TraceInterpreterState for execution semantics.
1252  THPObjectPtr pyobj;
1253  // The calling convention for the Python function.
1254  // 'c' -- constant argument
1255  // 'd' -- dynamic argument
1256  std::string cconv;
1257  // Scalar arguments to the Python function. Not necessarily passed to
1258  // the function in this order; see cconv for the correct order.
1259  std::vector<THPObjectPtr> scalar_args;
1260  virtual std::string name() const = 0;
1261  virtual void writeScalars(std::ostream& out) const = 0;
1262  void cloneFrom(Node* other_) override = 0;
1263  Node* allocNewInstance(Graph* g) override = 0;
1264  // recover the autograd.Function instance, if this PythonOp's function
1265  // was originally SomeFunction.apply
1266  // used in ONNX for discovering symbolics
1267  virtual c10::optional<THPObjectPtr> autogradFunction() const = 0;
1268 
1269  // should this Python function be skipped over when exported (i.e. for
1270  // debugging functions that only run in Python)
1271  bool ignore_on_export = false;
1272 };
1273 // patched in when python bindings are loaded
1274 TORCH_API PythonOp* allocPythonOp(Graph* g);
1275 TORCH_API void setAllocPythonOp(PythonOp* (*v)(Graph* g));
1276 inline Node* Graph::createPythonOp(
1277  THPObjectPtr&& pyobj,
1278  const std::string& cconv,
1279  pyobj_list&& scalar_args) {
1280  PythonOp* op = allocPythonOp(this);
1281  return op->init(std::move(pyobj), cconv, std::move(scalar_args));
1282 }
1283 
1284 TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
1285 
1286 TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
1287 // unpack_outputs - if true, and the callee returns a single tuple value, then
1288 // insert a tuple unpack node
1289 // and return the resulting values
1290 TORCH_API std::vector<Value*> inlineCallTo(
1291  Graph& g,
1292  Graph& callee,
1293  ArrayRef<Value*> inputs,
1294  bool unpack_outputs = false);
1295 
1296 } // namespace jit
1297 } // namespace torch
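
Finally, an illustrative sketch (not a pass shipped with this header) of how the pieces compose into a simple whole-graph cleanup: collect nodes whose outputs are unused and that have no side effects, then destroy them:

#include <torch/csrc/jit/ir.h>

#include <vector>

void naiveDeadNodeSweep(torch::jit::Graph& graph) {
  using namespace torch::jit;
  std::vector<Node*> dead;
  for (Node* n : graph.nodes()) {
    // hasUses() checks every output; nodes with side effects must be kept even
    // if nothing reads their outputs.
    if (!n->hasUses() && !n->hasSideEffects()) {
      dead.push_back(n);
    }
  }
  for (Node* n : dead) {
    n->destroy();  // legal: none of its outputs have uses (checked above)
  }
}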