// NOTE(review): this listing looks like a lossy extraction of the source file:
// the original line numbers embedded in each line skip (5 -> 11, 18 -> 21,
// 22 -> 27, ...), so some lines of the real file are not visible here.
// NOTE(review): <torch/csrc/jit/script/script_type_parser.h> is included
// twice (original lines 1 and 4); the second include is redundant.
//
// ident_to_type_lut(): lookup table from a bare type identifier as written in
// a type expression (e.g. "int", "Tensor", "None") to the corresponding
// TypePtr. The table is a function-local static, so it is constructed on the
// first call and shared by all later calls; callers receive a const reference.
1 #include <torch/csrc/jit/script/script_type_parser.h> 2 #include <torch/csrc/jit/ir.h> 3 #include <torch/csrc/jit/script/parser.h> 4 #include <torch/csrc/jit/script/script_type_parser.h> 5 #include <torch/csrc/jit/script/tree_views.h> 11 const std::unordered_map<std::string, TypePtr>& ident_to_type_lut() {
12 static std::unordered_map<std::string, TypePtr> map = {
13 {
"Tensor", TensorType::get()},
14 {
"int", IntType::get()},
15 {
"float", FloatType::get()},
16 {
"bool", BoolType::get()},
17 {
"str", StringType::get()},
18 {
"Device", DeviceObjType::get()},
// NOTE(review): original lines 19-20 are not shown; whether they add entries
// or are only comments/blank lines cannot be confirmed from here.
21 {
"number", NumberType::get()},
22 {
"None", NoneType::get()},
// NOTE(review): original lines 23-26 are not shown. They presumably hold any
// remaining entries, the end of the initializer and the return of `map`.
// subscript_to_type_fns(): lookup table from a type-constructor name to a
// callback that builds a TypePtr from the subscript expression `Name[...]`.
// The table is a function-local static, built on first use.
// NOTE(review): the map keys live on original lines that are not shown
// (30-31, 38-39, 49-50, 60-61, 71-72). The visible callbacks construct
// TupleType, ListType, OptionalType, FutureType and DictType respectively,
// presumably keyed "Tuple", "List", "Optional", "Future", "Dict" (confirm).
27 const std::unordered_map<std::string, std::function<TypePtr(Subscript)>>&
28 subscript_to_type_fns() {
29 static std::unordered_map<std::string, std::function<TypePtr(Subscript)>>
// Tuple-style entry: accepts any number of subscript types, parses each one
// recursively with parseTypeFromExpr, and returns a TupleType of them.
32 [](Subscript subscript) -> TypePtr {
33 std::vector<TypePtr> subscript_expr_types;
34 for (
auto expr : subscript.subscript_exprs()) {
35 subscript_expr_types.push_back(parseTypeFromExpr(expr));
37 return TupleType::create(subscript_expr_types);
// List-style entry: requires exactly one subscript type (otherwise throws
// ErrorReport) and wraps the parsed element type in a ListType.
// NOTE(review): the `elem_type` declaration is on omitted original line 46.
40 [](Subscript subscript) -> TypePtr {
41 if (subscript.subscript_exprs().size() != 1) {
42 throw ErrorReport(subscript)
43 <<
" expected exactly one element type but found " 44 << subscript.subscript_exprs().size();
47 parseTypeFromExpr(*subscript.subscript_exprs().begin());
48 return ListType::create(elem_type);
// Optional-style entry: same one-element check, returns an OptionalType.
51 [](Subscript subscript) -> TypePtr {
52 if (subscript.subscript_exprs().size() != 1) {
53 throw ErrorReport(subscript)
54 <<
" expected exactly one element type but found " 55 << subscript.subscript_exprs().size();
58 parseTypeFromExpr(*subscript.subscript_exprs().begin());
59 return OptionalType::create(elem_type);
// Future-style entry: same one-element check, returns a FutureType.
// NOTE(review): this one-element arity check and its message are repeated
// verbatim in three callbacks; a small shared helper would remove the
// duplication.
62 [](Subscript subscript) -> TypePtr {
63 if (subscript.subscript_exprs().size() != 1) {
64 throw ErrorReport(subscript)
65 <<
" expected exactly one element type but found " 66 << subscript.subscript_exprs().size();
69 parseTypeFromExpr(*subscript.subscript_exprs().begin());
70 return FutureType::create(elem_type);
// Dict-style entry: requires exactly two subscript types. The first is
// parsed as the key type and the second as the value type of a DictType.
// NOTE(review): the `value_type` declaration is on omitted original line 80.
73 [](Subscript subscript) -> TypePtr {
74 if (subscript.subscript_exprs().size() != 2) {
75 throw ErrorReport(subscript)
76 <<
" expected exactly 2 element types but found " 77 << subscript.subscript_exprs().size();
79 auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
81 parseTypeFromExpr(subscript.subscript_exprs()[1]);
82 return DictType::create(key_type, value_type);
// isTorch(): true iff `expr` is a bare variable reference named "torch".
// Used when resolving the qualified type name `torch.Tensor`. The TK_VAR
// test runs first, so Var(expr) is only constructed for variable expressions.
// NOTE(review): the closing brace (original line 90) is not shown here.
88 bool isTorch(
const Expr& expr) {
89 return expr.kind() == TK_VAR && Var(expr).name().name() ==
"torch";
// Body of parseBroadcastList. Its signature (original lines 92-93) is not
// shown. The recursive call and the std::pair<TypePtr, int32_t> returns below
// suggest it returns an optional (TypePtr, length) pair (confirm).
//
// Recognises `BroadcastingListN[T]` and `Optional[BroadcastingListN[T]]`,
// where T must be `int` or `float`. It yields (List[T], N) or
// (Optional[List[T]], N) respectively.
// NOTE(review): the statements guarded by the two ifs below (original lines
// 95 and 98) are not shown; presumably they return "not a broadcast list".
94 if (expr.kind() != TK_SUBSCRIPT)
96 auto subscript = Subscript(expr);
97 if (subscript.value().kind() != TK_VAR)
99 auto var = Var(subscript.value());
100 auto subscript_exprs = subscript.subscript_exprs();
// Optional[BroadcastingListN[T]]: recurse on the inner expression and wrap the
// resulting list type in an OptionalType, keeping the same length.
// NOTE(review): subscript_exprs[0] is read without checking that the
// subscript list is non-empty; confirm the parser guarantees one element.
// The handling of an inner expression that is not a broadcast list is on the
// omitted lines 108-110.
103 if (var.
name().name() ==
"Optional") {
104 auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
105 if (broadcast_list) {
106 TypePtr opt_type = OptionalType::create(broadcast_list->first);
107 return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
111 }
else if (var.
name().name().find(
"BroadcastingList") != 0) {
// NOTE(review): the body for names that do not start with "BroadcastingList"
// is on omitted original lines 112-114.
// From here on the name starts with "BroadcastingList"; exactly one element
// type is required.
115 if (subscript_exprs.size() != 1)
116 throw ErrorReport(subscript.subscript_exprs().range())
117 <<
"BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
// `len` is the identifier suffix after "BroadcastingList" (the N in
// BroadcastingListN).
119 auto typ = subscript_exprs[0];
120 auto len = var.
name().name().substr(strlen(
"BroadcastingList"));
// The element type must be a bare `int` or `float` identifier.
// NOTE(review): both errors below are reported at subscript.value().range()
// (the BroadcastingList name) rather than at the offending element `typ`.
122 if (typ.kind() != TK_VAR)
123 throw ErrorReport(subscript.value().range())
124 <<
"Subscripted type must be a type identifier";
126 auto value_name = Var(typ).name().name();
127 if (value_name !=
"float" && value_name !=
"int")
128 throw ErrorReport(subscript.value().range())
129 <<
"Broadcastable lists only supported for int or float";
// "int" and "float" are both keys of ident_to_type_lut(), so this assert is
// an internal-consistency check.
131 auto elem_ptr = ident_to_type_lut().find(value_name);
132 AT_ASSERT(elem_ptr != ident_to_type_lut().end());
133 TypePtr list_ptr = ListType::create(elem_ptr->second);
// Parse N as a base-10 integer; the whole suffix must be consumed. `end` is
// presumably declared on omitted original line 136.
// NOTE(review): an empty suffix (a bare `BroadcastingList[T]`) passes this
// check with len_v == 0, because strtoull consumes nothing and sets `end`
// to `len_c`, which equals `len_c + len.size()`. A suffix of "0" is accepted
// too. Both contradict the "positive integer" message; consider also
// rejecting len.empty() and len_v == 0.
135 const char* len_c = len.c_str();
137 size_t len_v = strtoull(len_c, &end, 10);
138 if (end != len_c + len.size()) {
139 throw ErrorReport(subscript.subscript_exprs().range())
140 <<
"subscript of Broadcastable list must be a positive integer";
// NOTE(review): len_v (size_t) is narrowed to int32_t here without a range
// check.
142 return std::pair<TypePtr, int32_t>(list_ptr, len_v);
// Body fragment of parseBaseTypeName. Its signature and the case labels are
// on omitted original lines (143-147, 149, 151-155). Callers test and
// dereference its result (`if (auto name = ...)`, `*value_name`), so it
// presumably returns an optional type name (confirm).
// Visible behaviour:
//  - on line 150 (presumably the TK_VAR case) it returns the variable's name;
//  - for a selection whose base satisfies isTorch() and whose selector is
//    "Tensor" (i.e. `torch.Tensor`), it takes the branch on omitted line 159.
148 switch (expr.kind()) {
150 return Var(expr).name().name();
156 auto select = Select(expr);
157 const std::string& name = select.selector().name();
158 if (isTorch(select.value()) && name ==
"Tensor")
// parseTypeFromExpr(): convert a parsed type expression into a TypePtr.
//  - `Name[args]`: Name (via parseBaseTypeName) must be a key of
//    subscript_to_type_fns(); that entry's callback builds the type.
//  - a bare name: looked up in ident_to_type_lut(); otherwise
//    ClassType::get(*name) is tried.
//  - anything else: ErrorReport.
// Throws ErrorReport on any of these:
//  - the subscripted base is not a type identifier;
//  - the type constructor is unknown;
//  - the type name is unknown;
//  - the expression kind is unsupported.
// NOTE(review): several lines of this function are not shown (169, 172, 176,
// 181-182, 184-185, 187, 191). The guards and returns around them are
// inferred from the surrounding lines (confirm).
165 TypePtr parseTypeFromExpr(
const Expr& expr) {
166 if (expr.kind() == TK_SUBSCRIPT) {
167 auto subscript = Subscript(expr);
168 auto value_name = parseBaseTypeName(subscript.value());
// NOTE(review): the condition guarding this throw is on omitted original line
// 169 (presumably "no base type name was found").
170 throw ErrorReport(subscript.value().range())
171 <<
"Subscripted type must be a type identifier";
// NOTE(review): count() followed by at() hashes the key twice; a single
// find() would do one lookup.
173 if (!subscript_to_type_fns().count(*value_name)) {
174 throw ErrorReport(subscript.range())
175 <<
"Unknown type constructor " << *value_name;
177 return subscript_to_type_fns().at(*value_name)(subscript);
178 }
else if (
auto name = parseBaseTypeName(expr)) {
// Built-in identifiers are checked first; the matching return is on omitted
// original lines 181-182.
179 auto itr = ident_to_type_lut().find(*name);
180 if (itr != ident_to_type_lut().end()) {
// Otherwise try ClassType::get(*name); its return is on omitted original
// lines 184-185.
183 if (
auto typePtr = ClassType::get(*name)) {
186 throw ErrorReport(expr) <<
"Unknown type name " << *name;
// Neither a subscript nor an expression parseBaseTypeName could name.
188 throw ErrorReport(expr.range())
189 <<
"Expression of type " << kindToString(expr.kind())
190 <<
" cannot be used in a type expression";
// parseType(): parse `str` as a single expression and convert it to a
// TypePtr with parseTypeFromExpr.
// NOTE(review): the declaration of `p` is not shown. It is on original line
// 194 and is presumably a Parser constructed from `str`. The closing brace is
// not shown either.
193 TypePtr parseType(
const std::string& str) {
195 return parseTypeFromExpr(p.parseExp());
const std::string & name() const noexcept
Returns the name of the Module.