3 #include <torch/csrc/jit/attributes.h> 4 #include <torch/csrc/jit/graph_node_list.h> 5 #include <torch/csrc/jit/named_value.h> 6 #include <torch/csrc/jit/scope.h> 8 #include <torch/csrc/WindowsTorchApiMacro.h> 9 #include <torch/csrc/utils/disallow_copy.h> 10 #include <torch/csrc/utils/object_ptr.h> 12 #include <ATen/ATen.h> 13 #include <ATen/core/function_schema.h> 14 #include <ATen/core/functional.h> 15 #include <ATen/core/interned_strings.h> 16 #include <ATen/core/ivalue.h> 17 #include <ATen/core/jit_type.h> 18 #include <c10/util/ArrayRef.h> 19 #include <c10/util/Exception.h> 23 #include <unordered_set> 29 using ::c10::Argument;
30 using ::c10::FunctionSchema;
33 using ::c10::ivalue::List;
34 using ::c10::ivalue::Shared;
37 using ::c10::ivalue::Future;
38 using ::c10::ivalue::Tuple;
40 using ::c10::ivalue::BoolList;
41 using ::c10::ivalue::DoubleList;
42 using ::c10::ivalue::GenericList;
43 using ::c10::ivalue::IntList;
44 using ::c10::ivalue::TensorList;
46 using ::c10::ivalue::ConstantString;
48 #define C10_USING(T) using ::c10::T; 49 C10_FORALL_TYPES(C10_USING)
52 #define C10_USING(T) using ::c10::T##Ptr; 53 C10_FORALL_TYPES(C10_USING)
60 using ::c10::getTypePtr;
61 using ::c10::MatchTypeReturn;
62 using ::c10::TypeKind;
67 using namespace ::c10::prim;
70 using namespace ::c10::attr;
73 using namespace ::c10::aten;
90 TORCH_API std::ostream& operator<<(std::ostream& out,
const Graph& g);
91 TORCH_API std::ostream& operator<<(std::ostream& out,
const Node& n);
100 Use(
Node* user,
size_t offset) : user(user), offset(offset) {}
104 bool operator==(
const Use& b) {
105 return user == b.user && offset == b.offset;
134 using node_list = std::vector<Node*>;
135 using value_list = std::vector<Value*>;
136 using use_list = std::vector<Use>;
137 using pyobj_list = std::vector<THPObjectPtr>;
138 template <
typename T>
141 using topo_position_t = int64_t;
142 using ValueSet = std::unordered_set<const Value*>;
145 TH_DISALLOW_COPY_AND_ASSIGN(
Value);
155 std::string unique_name_;
159 Value* setType(TypePtr type);
160 void inferTypeFrom(
const at::Tensor& output) {
161 setType(CompleteTensorType::create(output));
163 const TypePtr& type()
const {
164 AT_ASSERT(type_ !=
nullptr);
170 bool isTensor()
const {
171 return type()->kind() == TypeKind::CompleteTensorType;
173 TORCH_API
bool mustBeNone()
const;
174 size_t unique()
const {
177 bool hasUniqueName()
const {
178 return !unique_name_.empty();
180 static bool isValidName(
const std::string& name);
181 TORCH_API
Value* setUniqueName(
const std::string& name);
182 std::string uniqueName()
const {
183 if (hasUniqueName()) {
186 return std::to_string(unique());
188 TORCH_API std::string uniqueNameBase()
const;
192 size_t offset()
const {
195 void setOffset(
size_t offset) {
198 const Node* node()
const {
201 Graph* owningGraph();
202 const Graph* owningGraph()
const;
204 const use_list& uses()
const {
208 bool hasUses()
const {
209 return !uses().empty();
212 TORCH_API
void replaceFirstUseWith(
Value* newValue);
223 TORCH_API
void replaceAllUsesWith(
Value* newValue);
229 TH_DISALLOW_COPY_AND_ASSIGN(
Node);
240 std::vector<Value*> inputs_;
241 std::vector<Value*> outputs_;
243 std::vector<Block*> blocks_;
245 Block* owning_block_;
246 std::shared_ptr<SourceLocation> source_location_;
253 topo_position_t topo_position_ = 0;
266 Node* next_in_graph[2] = {
nullptr,
nullptr};
269 return next_in_graph[kNextDirection];
272 return next_in_graph[kPrevDirection];
274 Node*
const& next()
const {
275 return next_in_graph[kNextDirection];
277 Node*
const& prev()
const {
278 return next_in_graph[kPrevDirection];
284 Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) {
285 source_location_ = std::move(sl);
288 std::shared_ptr<SourceLocation> getSourceLocation()
const {
289 return source_location_;
291 Graph* owningGraph() {
294 const Graph* owningGraph()
const {
297 Block* owningBlock() {
298 return owning_block_;
300 const Block* owningBlock()
const {
301 return owning_block_;
307 scope_ = std::move(scope);
309 std::string scopeName()
const {
313 return scope_->namesFromRoot();
327 return {inputs_.data(), inputs_.size()};
341 return {outputs_.data(), outputs_.size()};
343 Value* output(
size_t i)
const {
344 return outputs_.at(i);
346 bool hasUses()
const {
347 for (
auto o : outputs()) {
348 if (!o->uses().empty()) {
355 TORCH_API
void replaceAllUsesWith(Node* n);
360 AT_ASSERT(inputs_.size() == 1);
361 return inputs_.at(0);
364 AT_ASSERT(outputs_.size() == 1);
365 return outputs_.at(0);
367 const Value* output()
const {
368 AT_ASSERT(outputs_.size() == 1);
369 return outputs_.at(0);
371 const Value* input()
const {
372 AT_ASSERT(inputs_.size() == 1);
373 return inputs_.at(0);
376 Value* input(
size_t i)
const {
377 return inputs_.at(i);
384 template <
typename T>
386 if (
auto v =
get(name)) {
387 return v->template to<T>();
393 bool is_constant(
Symbol name)
const {
394 return static_cast<bool>(
get(name));
396 TORCH_API
bool mustBeNone()
const;
398 TORCH_API
bool isNondeterministic()
const;
399 TORCH_API
bool hasSideEffects()
const;
425 TORCH_API
Value* insertInput(
size_t i,
Value* value);
433 TORCH_API
Value* replaceInput(
size_t i,
Value* newValue);
441 TORCH_API
void replaceInputWith(
Value* from,
Value* to);
443 TORCH_API
Value* addOutput();
445 TORCH_API
Value* insertOutput(
size_t i);
447 TORCH_API
void eraseOutput(
size_t i);
449 TORCH_API
Block* addBlock();
450 TORCH_API
void eraseBlock(
size_t i);
475 return {blocks_.data(), blocks_.size()};
479 TORCH_API
bool isBefore(
const Node* n)
const;
482 TORCH_API
bool isAfter(
const Node* n)
const;
494 TORCH_API Node* insertBefore(Node* n);
506 TORCH_API Node* insertAfter(Node* n);
519 TORCH_API
void moveAfter(Node* n);
532 TORCH_API
void moveBefore(Node* n);
542 TORCH_API
void removeInput(
size_t i);
549 TORCH_API
void removeAllInputs();
557 return iterator().reverse();
563 return iterator().reverse();
574 TORCH_API
void destroy();
582 template <
typename T>
584 if (T::Kind == kind()) {
585 return static_cast<T*
>(
this);
589 template <
typename T>
594 T::Kind.toDisplayString(),
596 kind().toDisplayString());
597 return static_cast<T*
>(
this);
601 TORCH_API
bool matches(
602 const char* signature_literal,
618 std::vector<const Node*>* groups)
const;
620 virtual ~Node() =
default;
623 void copyAttributes(
const Node& rhs) {
625 for (
const AVPtr& i : rhs.values_) {
626 values_.push_back(i->clone());
629 bool hasAttribute(
Symbol name)
const {
630 AT_ASSERT(name.is_attr());
631 return findAttr(name,
false) != values_.end();
633 bool hasAttributeS(
const std::string& name)
const {
634 return hasAttribute(Symbol::attr(name));
636 AttributeKind kindOf(
Symbol name)
const {
637 AT_ASSERT(name.is_attr());
638 return (*findAttr(name,
true))->kind();
640 AttributeKind kindOfS(
const std::string& name)
const {
641 return kindOf(Symbol::attr(name));
643 Node* removeAttribute(
Symbol name) {
644 AT_ASSERT(name.is_attr());
645 values_.erase(findAttr(name,
true));
648 Node* removeAttributeS(
const std::string& name) {
649 return removeAttribute(Symbol::attr(name));
651 bool hasAttributes()
const {
652 return values_.size() > 0;
654 size_t numAttributes()
const {
655 return values_.size();
658 std::vector<Symbol> attributeNames()
const {
659 std::vector<Symbol> names;
660 for (
const AVPtr& a : values_) {
661 names.push_back(a->name);
665 std::vector<const char*> attributeNamesS()
const {
666 std::vector<const char*> names;
667 for (
const AVPtr& a : values_) {
668 names.push_back(a->name.toUnqualString());
673 #define CREATE_ACCESSOR(Kind, method) \ 674 Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \ 675 return setAttr<Kind##Attr>( \ 676 name, std::forward<Kind##Attr::ConstructorType>(v)); \ 678 const Kind##Attr::ValueType& method(Symbol name) const { \ 679 return getAttr<Kind##Attr>(name); \ 682 CREATE_ACCESSOR(Float, f)
683 CREATE_ACCESSOR(Floats, fs)
684 CREATE_ACCESSOR(String, s)
685 CREATE_ACCESSOR(Strings, ss)
686 CREATE_ACCESSOR(Int, i)
687 CREATE_ACCESSOR(Ints, is)
688 CREATE_ACCESSOR(
Graph, g)
689 CREATE_ACCESSOR(Graphs, gs)
691 #undef CREATE_ACCESSOR 695 GraphAttr::ValueType& g(
Symbol name) {
696 return getAttr<GraphAttr>(name);
701 AT_ASSERT(!v.defined() || v.is_variable());
702 return setAttr<TensorAttr>(
703 name, std::forward<TensorAttr::ConstructorType>(v));
706 return getAttr<TensorAttr>(name);
709 Node* ts_(
Symbol name, TensorsAttr::ConstructorType v) {
713 return setAttr<TensorsAttr>(
714 name, std::forward<TensorsAttr::ConstructorType>(v));
716 const TensorsAttr::ValueType& ts(
Symbol name)
const {
717 return getAttr<TensorsAttr>(name);
721 void printAttrValue(std::ostream& out,
const Symbol& name)
const;
722 void printAttributes(std::ostream& out,
bool ignore_subgraph)
const;
724 template <
typename T>
725 Node* setAttr(
Symbol name,
typename T::ConstructorType v) {
726 AT_ASSERT(name.is_attr());
727 auto it = findAttr(name,
false);
728 auto nv = AVPtr(
new T(name, std::forward<typename T::ConstructorType>(v)));
729 if (it == values_.end()) {
730 values_.push_back(std::move(nv));
736 template <
typename T>
737 typename T::ValueType& getAttr(
Symbol name)
const {
738 AT_ASSERT(name.is_attr());
739 auto it = findAttr(name,
true);
740 auto* child =
dynamic_cast<T*
>(it->get());
741 if (child ==
nullptr) {
744 return child->value();
746 using AVPtr = AttributeValue::Ptr;
750 std::vector<AVPtr> values_;
751 std::vector<AVPtr>::iterator findAttr(
Symbol name,
bool required) {
752 AT_ASSERT(name.is_attr());
753 auto it = std::find_if(values_.begin(), values_.end(), [&](
const AVPtr& v) {
754 return v->name == name;
756 if (required && it == values_.end()) {
759 AT_ASSERT(!required || it != values_.end());
762 std::vector<AVPtr>::const_iterator findAttr(
Symbol name,
bool required)
764 AT_ASSERT(name.is_attr());
765 auto it = std::find_if(values_.begin(), values_.end(), [&](
const AVPtr& v) {
766 return v->name == name;
768 if (required && it == values_.end()) {
771 AT_ASSERT(!required || it != values_.end());
775 enum class MoveSide { BEFORE, AFTER };
776 bool isBeforeOrAfter(
const Node* n, MoveSide moveSide)
const;
778 std::pair<Value*, const Argument&> findInput(
Symbol name);
779 void findSchema()
const;
782 TORCH_API use_list::iterator findUseForInput(
size_t i);
787 TORCH_API
Value* dropInput(
size_t i);
789 bool inBlockList()
const {
790 if (next() ==
nullptr) {
791 AT_ASSERT(prev() ==
nullptr);
793 return next() !=
nullptr;
796 TORCH_API
void removeFromList();
797 TORCH_API
void lint()
const;
799 void assignTopoPosition();
807 virtual Node* allocNewInstance(
Graph* g) {
808 return new Node(g, kind());
815 TORCH_API
virtual void cloneFrom(Node* s);
822 TH_DISALLOW_COPY_AND_ASSIGN(
Block);
826 return input_->outputs();
829 const auto& inputs = input_->outputs();
830 return {inputs.data(), inputs.size()};
833 return output_->inputs();
836 return static_cast<const Node*
>(output_)->inputs();
839 return {output_, kNextDirection};
842 return {output_, kNextDirection};
844 Node* return_node() {
847 const Node* return_node()
const {
853 const Node* param_node()
const {
856 Graph* owningGraph() {
859 const Graph* owningGraph()
const {
865 const Node* owningNode()
const {
869 Value* addInput(std::string name =
"") {
870 Value* v = input_->addOutput();
871 v->setUniqueName(std::move(name));
874 Value* insertInput(
size_t i, std::string name =
"") {
875 Value* v = input_->insertOutput(i);
876 v->setUniqueName(std::move(name));
879 void eraseInput(
size_t i) {
880 input_->eraseOutput(i);
882 size_t registerOutput(
Value* v) {
883 output_->addInput(v);
884 return outputs().size() - 1;
886 size_t insertOutput(
size_t i,
Value* n) {
887 output_->insertInput(i, n);
890 void eraseOutput(
size_t i) {
891 output_->removeInput(i);
895 AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
896 n->insertBefore(output_);
900 AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
901 n->insertAfter(output_);
908 TORCH_API
void cloneFrom(Block* src, std::function<
Value*(
Value*)> value_map);
911 void reIndexTopology();
937 TH_DISALLOW_COPY_AND_ASSIGN(
Graph);
947 std::unordered_set<const Node*> all_nodes;
948 std::unordered_set<const Value*> all_values;
949 std::unordered_set<const Block*> all_blocks;
952 std::unordered_map<std::string, Value*> unique_names_;
959 Node* insert_before_;
964 current_scope_(std::move(scope_root)),
965 block_(
new Block(
this,
nullptr)),
966 insert_before_(return_node()) {}
968 Graph() :
Graph(c10::make_intrusive<Scope>()) {}
971 return block_->inputs();
974 const Block& block = *block_;
975 return block.inputs();
978 return block_->outputs();
981 const Block& block = *block_;
982 return block.outputs();
985 return block_->nodes();
988 const Block& block = *block_;
989 return block.nodes();
992 return block_->param_node();
994 const Node* param_node()
const {
995 return block_->param_node();
997 Node* return_node() {
998 return block_->return_node();
1000 const Node* return_node()
const {
1001 return block_->return_node();
1003 const std::unordered_map<std::string, Value*>& uniqueNames()
const {
1004 return unique_names_;
1007 void push_scope(
const std::string& scope_name) {
1008 current_scope_ = current_scope_->push(Symbol::scope(scope_name));
1011 current_scope_ = current_scope_->parent();
1014 return current_scope_;
1016 void set_current_scope(
ScopePtr scope) {
1017 current_scope_ = std::move(scope);
1020 Value* addInput(std::string name =
"") {
1021 return block_->addInput(std::move(name));
1023 Value* insertInput(
size_t i, std::string name =
"") {
1024 return block_->insertInput(i, std::move(name));
1026 void eraseInput(
size_t i) {
1027 block_->eraseInput(i);
1029 size_t registerOutput(
Value* n) {
1030 return block_->registerOutput(n);
1032 void eraseOutput(
size_t i) {
1033 block_->eraseOutput(i);
1036 TORCH_API
Node* create(
NodeKind kind,
size_t num_outputs = 1);
1037 TORCH_API
Node* create(
1040 size_t num_outputs = 1);
1042 TORCH_API
Node* createNone(
1044 TORCH_API
Node* createAutogradZero();
1045 TORCH_API
Node* createFusionGroup();
1046 TORCH_API
Node* createDifferentiableSubgraph();
1047 TORCH_API
Node* createTuple(
1050 TORCH_API
Node* createTupleUnpack(
Value* v);
1051 TORCH_API
Node* createTupleIndex(
Value* tup, int64_t index);
1052 TORCH_API
Node* createTupleSlice(
Value* tup, int64_t beg, int64_t end);
1053 TORCH_API
Node* createList(
1054 const TypePtr& elem_type,
1056 TORCH_API
Node* createListUnpack(
Value* v,
size_t size);
1057 TORCH_API
Node* createDict(
1058 const TypePtr& key_type,
1059 const TypePtr& value_type,
1063 TORCH_API
Node* createNumToTensor(
Value* value);
1064 TORCH_API
Node* createImplicitTensorToNum(
const TypePtr& type,
Value* value);
1065 TORCH_API
Node* createObject(
const ClassTypePtr& type);
1066 TORCH_API
Node* createSetAttr(
1068 const std::string& field,
1070 TORCH_API
Node* createGetAttr(
Value* obj,
const std::string& field);
1071 Node* createPythonOp(
1073 const std::string& cconv,
1074 pyobj_list&& scalar_args);
1079 TORCH_API
Node* createClone(
1081 const std::function<
Value*(
Value*)>& value_map,
1082 bool copy_blocks =
true);
1086 TORCH_API
Value* insertConstant(
1088 const TypePtr& result_type =
nullptr,
1099 TORCH_API
Value* insert(
1106 return block_->appendNode(n);
1110 return block_->prependNode(n);
1118 insert_before_->inBlockList() &&
1119 "insert point node is no longer in a block list");
1120 return n->insertBefore(insert_before_);
1123 void setInsertPoint(
Block* b) {
1124 AT_ASSERT(b->owningGraph() ==
this);
1125 insert_before_ = b->return_node();
1130 void setInsertPoint(
Node* n) {
1131 AT_ASSERT(n->owningGraph() ==
this && n->inBlockList());
1134 Node* insertPoint() {
1135 return insert_before_;
1142 const Block* block()
const {
1147 TORCH_API
void lint()
const;
1149 TORCH_API
void dump()
const;
1153 TORCH_API std::string toString()
const;
1155 friend TORCH_API std::ostream& operator<<(std::ostream& out,
const Graph& g);
1157 TORCH_API std::ostream& prettyPrint(std::ostream& out);
1158 TORCH_API
void dumpPretty();
1160 TORCH_API std::shared_ptr<Graph> copy();
1163 TORCH_API
void freeNode(
Node* n);
1164 TORCH_API
void freeValue(
Value* v);
1165 TORCH_API
void freeBlock(
Block* b);
1176 n->owningGraph()->setInsertPoint(n);
1181 prev_->owningGraph()->setInsertPoint(prev_);
1195 : graph_(&g), prev_scope_(g.current_scope()) {
1196 g.set_current_scope(std::move(scope));
1199 graph_->set_current_scope(prev_scope_);
1207 inline Value::Value(
Node* node_,
size_t offset_)
1210 unique_(node_->graph_->next_unique_++),
1211 type_(TensorType::get()) {
1212 node_->graph_->all_values.emplace(
this);
1215 inline Value* Value::setType(TypePtr type) {
1217 type_ = std::move(type);
1218 for (
Use& use : uses_) {
1219 use.user->schema_ =
nullptr;
1224 inline Graph* Value::owningGraph() {
1225 return node()->owningGraph();
1228 inline const Graph* Value::owningGraph()
const {
1229 return node()->owningGraph();
1237 static constexpr
Symbol Kind = ::c10::prim::PythonOp;
1242 const std::string& cconv,
1243 pyobj_list&& scalar_args) {
1244 this->pyobj = std::move(pyobj);
1245 this->scalar_args = std::move(scalar_args);
1246 this->cconv = cconv;
1259 std::vector<THPObjectPtr> scalar_args;
1260 virtual std::string name()
const = 0;
1261 virtual void writeScalars(std::ostream& out)
const = 0;
1262 void cloneFrom(
Node* other_)
override = 0;
1263 Node* allocNewInstance(
Graph* g)
override = 0;
1271 bool ignore_on_export =
false;
1276 inline Node* Graph::createPythonOp(
1278 const std::string& cconv,
1279 pyobj_list&& scalar_args) {
1280 PythonOp* op = allocPythonOp(
this);
1281 return op->init(std::move(pyobj), cconv, std::move(scalar_args));
1284 TORCH_API
void LintGraph(std::shared_ptr<Graph>& graph);
1290 TORCH_API std::vector<Value*> inlineCallTo(
1294 bool unpack_outputs =
false);
An utility class for setting temporary scopes.
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
An utility class for setting temporary insertion points.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
C10_NODISCARD TensorOptions requires_grad(c10::optional< bool > requires_grad) const noexcept
Sets the requires_grad property of the TensorOptions.