Caffe2 - C++ API
A deep learning, cross-platform ML framework
script_type_parser.cpp
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/script/tree_views.h>

#include <limits>
6 
7 namespace torch {
8 namespace jit {
9 namespace script {
10 
11 const std::unordered_map<std::string, TypePtr>& ident_to_type_lut() {
12  static std::unordered_map<std::string, TypePtr> map = {
13  {"Tensor", TensorType::get()},
14  {"int", IntType::get()},
15  {"float", FloatType::get()},
16  {"bool", BoolType::get()},
17  {"str", StringType::get()},
18  {"Device", DeviceObjType::get()},
19  // technically this is not a python type but we need it when
20  // parsing serialized methods that use implicit converions to Scalar
21  {"number", NumberType::get()},
22  {"None", NoneType::get()},
23  };
24  return map;
25 }
26 
27 const std::unordered_map<std::string, std::function<TypePtr(Subscript)>>&
28 subscript_to_type_fns() {
29  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>>
30  map = {
31  {"Tuple",
32  [](Subscript subscript) -> TypePtr {
33  std::vector<TypePtr> subscript_expr_types;
34  for (auto expr : subscript.subscript_exprs()) {
35  subscript_expr_types.push_back(parseTypeFromExpr(expr));
36  }
37  return TupleType::create(subscript_expr_types);
38  }},
39  {"List",
40  [](Subscript subscript) -> TypePtr {
41  if (subscript.subscript_exprs().size() != 1) {
42  throw ErrorReport(subscript)
43  << " expected exactly one element type but found "
44  << subscript.subscript_exprs().size();
45  }
46  auto elem_type =
47  parseTypeFromExpr(*subscript.subscript_exprs().begin());
48  return ListType::create(elem_type);
49  }},
50  {"Optional",
51  [](Subscript subscript) -> TypePtr {
52  if (subscript.subscript_exprs().size() != 1) {
53  throw ErrorReport(subscript)
54  << " expected exactly one element type but found "
55  << subscript.subscript_exprs().size();
56  }
57  auto elem_type =
58  parseTypeFromExpr(*subscript.subscript_exprs().begin());
59  return OptionalType::create(elem_type);
60  }},
61  {"Future",
62  [](Subscript subscript) -> TypePtr {
63  if (subscript.subscript_exprs().size() != 1) {
64  throw ErrorReport(subscript)
65  << " expected exactly one element type but found "
66  << subscript.subscript_exprs().size();
67  }
68  auto elem_type =
69  parseTypeFromExpr(*subscript.subscript_exprs().begin());
70  return FutureType::create(elem_type);
71  }},
72  {"Dict",
73  [](Subscript subscript) -> TypePtr {
74  if (subscript.subscript_exprs().size() != 2) {
75  throw ErrorReport(subscript)
76  << " expected exactly 2 element types but found "
77  << subscript.subscript_exprs().size();
78  }
79  auto key_type = parseTypeFromExpr(subscript.subscript_exprs()[0]);
80  auto value_type =
81  parseTypeFromExpr(subscript.subscript_exprs()[1]);
82  return DictType::create(key_type, value_type);
83  }},
84  };
85  return map;
86 }
87 
88 bool isTorch(const Expr& expr) {
89  return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
90 }
91 
93  const Expr& expr) {
94  if (expr.kind() != TK_SUBSCRIPT)
95  return c10::nullopt;
96  auto subscript = Subscript(expr);
97  if (subscript.value().kind() != TK_VAR)
98  return c10::nullopt;
99  auto var = Var(subscript.value());
100  auto subscript_exprs = subscript.subscript_exprs();
101 
102  // handle the case where the BroadcastingList is wrapped in a Optional type
103  if (var.name().name() == "Optional") {
104  auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
105  if (broadcast_list) {
106  TypePtr opt_type = OptionalType::create(broadcast_list->first);
107  return std::pair<TypePtr, int32_t>(opt_type, broadcast_list->second);
108  } else {
109  return c10::nullopt;
110  }
111  } else if (var.name().name().find("BroadcastingList") != 0) {
112  return c10::nullopt;
113  }
114 
115  if (subscript_exprs.size() != 1)
116  throw ErrorReport(subscript.subscript_exprs().range())
117  << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
118 
119  auto typ = subscript_exprs[0];
120  auto len = var.name().name().substr(strlen("BroadcastingList"));
121 
122  if (typ.kind() != TK_VAR)
123  throw ErrorReport(subscript.value().range())
124  << "Subscripted type must be a type identifier";
125 
126  auto value_name = Var(typ).name().name();
127  if (value_name != "float" && value_name != "int")
128  throw ErrorReport(subscript.value().range())
129  << "Broadcastable lists only supported for int or float";
130 
131  auto elem_ptr = ident_to_type_lut().find(value_name);
132  AT_ASSERT(elem_ptr != ident_to_type_lut().end());
133  TypePtr list_ptr = ListType::create(elem_ptr->second);
134 
135  const char* len_c = len.c_str();
136  char* end;
137  size_t len_v = strtoull(len_c, &end, 10);
138  if (end != len_c + len.size()) {
139  throw ErrorReport(subscript.subscript_exprs().range())
140  << "subscript of Broadcastable list must be a positive integer";
141  }
142  return std::pair<TypePtr, int32_t>(list_ptr, len_v);
143 }
144 
145 // gets the base type name given namespaces where the types live
146 // turns torch.Tensor -> Tensor, X -> X
147 c10::optional<std::string> parseBaseTypeName(const Expr& expr) {
148  switch (expr.kind()) {
149  case TK_VAR: {
150  return Var(expr).name().name();
151  }
152  case TK_NONE: {
153  return "None";
154  }
155  case '.': {
156  auto select = Select(expr);
157  const std::string& name = select.selector().name();
158  if (isTorch(select.value()) && name == "Tensor")
159  return "Tensor";
160  } break;
161  }
162  return at::nullopt;
163 }
164 
165 TypePtr parseTypeFromExpr(const Expr& expr) {
166  if (expr.kind() == TK_SUBSCRIPT) {
167  auto subscript = Subscript(expr);
168  auto value_name = parseBaseTypeName(subscript.value());
169  if (!value_name) {
170  throw ErrorReport(subscript.value().range())
171  << "Subscripted type must be a type identifier";
172  }
173  if (!subscript_to_type_fns().count(*value_name)) {
174  throw ErrorReport(subscript.range())
175  << "Unknown type constructor " << *value_name;
176  }
177  return subscript_to_type_fns().at(*value_name)(subscript);
178  } else if (auto name = parseBaseTypeName(expr)) {
179  auto itr = ident_to_type_lut().find(*name);
180  if (itr != ident_to_type_lut().end()) {
181  return itr->second;
182  }
183  if (auto typePtr = ClassType::get(*name)) {
184  return typePtr;
185  }
186  throw ErrorReport(expr) << "Unknown type name " << *name;
187  }
188  throw ErrorReport(expr.range())
189  << "Expression of type " << kindToString(expr.kind())
190  << " cannot be used in a type expression";
191 }
192 
193 TypePtr parseType(const std::string& str) {
194  Parser p(str);
195  return parseTypeFromExpr(p.parseExp());
196 }
197 } // namespace script
198 } // namespace jit
199 } // namespace torch
const std::string & name() const noexcept
Returns the name of the Module.
Definition: module.cpp:53
Definition: jit_type.h:17