// NOTE(review): this extract appears to be missing lines. For example, the
// closing `};` of the type_map initializer and the closing braces of the `if`
// blocks below are not present. It also carries stray line-number tokens.
// Confirm against the canonical file before building.
//
// SchemaTypeParser::parseBaseType() begins part-way through the next line,
// after the #include list.
// It takes one identifier token (L.expect(TK_IDENT)) and resolves it to a type:
//  - a name present in `type_map` -> that type, with no alias annotation
//    (c10::nullopt);
//  - otherwise, a name whose first character is lowercase -> a type variable
//    (VarType::create(text)), paired with whatever parseAliasAnnotation()
//    returns;
//  - otherwise -> throws ErrorReport at the token's range with the message
//    "unknown type specifier".
1 #include <torch/csrc/jit/script/schema_type_parser.h> 2 #include <ATen/core/interned_strings.h> 3 #include <torch/csrc/jit/alias_info.h> 4 #include <torch/csrc/jit/ir.h> 5 #include <torch/csrc/jit/script/lexer.h> 6 #include <torch/csrc/jit/script/parse_string_literal.h> 13 TypeAndAlias SchemaTypeParser::parseBaseType() {
// Function-local static: the table is built once, on the first call.
// "ScalarType" and "Layout" are both represented as plain IntType here.
14 static std::unordered_map<std::string, TypePtr> type_map = {
15 {
"Generator", GeneratorType::get()},
16 {
"ScalarType", IntType::get()},
17 {
"Layout", IntType::get()},
18 {
"Device", DeviceObjType::get()},
19 {
"Scalar", NumberType::get()},
20 {
"str", StringType::get()},
21 {
"float", FloatType::get()},
22 {
"int", IntType::get()},
23 {
"bool", BoolType::get()},
25 auto tok = L.expect(TK_IDENT);
26 auto text = tok.text();
27 auto it = type_map.find(text);
28 if (it == type_map.end()) {
// Not a built-in name: a lowercase identifier is treated as a type variable.
// NOTE(review): islower() has undefined behaviour for negative char values.
// This can happen when char is signed and the name contains non-ASCII bytes.
// Consider islower(static_cast<unsigned char>(text[0])).
29 if (text.size() > 0 && islower(text[0])) {
32 return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
34 throw ErrorReport(tok.range) <<
"unknown type specifier";
// Built-in type: no alias annotation is parsed for these names.
36 return TypeAndAlias(it->second, c10::nullopt);
// Alias-annotation parsing. This is presumably the body of
// SchemaTypeParser::parseAliasAnnotation(); confirm. The enclosing signature
// and several control-flow lines are not visible in this extract: the branch
// conditions, closing braces and the final return.
// What the visible lines do to `alias_info`:
//  1. Parse a '|'-separated list of alias-set names into the "before" sets.
//     Each name becomes the qualified symbol "alias::<ident>". A wildcard entry
//     turns alias_info into a wildcard; its guard is not visible, presumably '*'.
//  2. Mark the annotation as a write via setIsWrite(true). The guard is not
//     visible.
//  3. If TK_ARROW follows, parse a '|'-separated list of "after" sets. A '*'
//     there is reported as an error.
//  4. Otherwise, copy the before sets into the after sets, asserting that no
//     after sets exist yet.
//  5. In the separate branch guarded by '!', create a fresh set
//     "alias::$<next_id>" (next_id is post-incremented) and mark it as a write.
// NOTE(review): `sets` is not used in the visible lines; it may be dead code.
46 std::set<Symbol> sets;
50 parseList(TK_NOTHING,
'|', TK_NOTHING, [&] {
// Wildcard entry: the whole annotation becomes a wildcard.
52 alias_info = AliasInfo::createWildcard();
55 }
// Once alias_info is a wildcard, further names are not added.
// NOTE(review): in that case the identifier token is not consumed here either,
// so a list such as `*|a` may fail to parse afterwards. Confirm this is
// intended.
else if (!alias_info.isWildcard()) {
56 alias_info.addBeforeSet(
57 Symbol::fromQualString(
"alias::" + L.expect(TK_IDENT).text()));
61 alias_info.setIsWrite(
true);
// Explicit output sets follow "->".
63 if (L.nextIf(TK_ARROW)) {
65 parseList(TK_NOTHING,
'|', TK_NOTHING, [&] {
66 if (L.cur().kind ==
'*') {
67 L.reportError(
"Wildcard not allowed as part of the output set");
69 alias_info.addAfterSet(
70 Symbol::fromQualString(
"alias::" + L.expect(TK_IDENT).text()));
// No explicit output sets, so the after sets mirror the before sets. This is
// presumably the no-arrow `else` branch; the `else` itself is not visible.
75 AT_ASSERT(alias_info.afterSets().empty());
76 for (
const auto&
set : alias_info.beforeSets()) {
77 alias_info.addAfterSet(
set);
81 }
else if (L.nextIf(
'!')) {
// Bare '!': a fresh, uniquely numbered anonymous alias set, marked as written.
82 alias_info.addBeforeSet(
83 Symbol::fromQualString(
"alias::$" + std::to_string(next_id++)));
84 alias_info.setIsWrite(
true);
// Tail of a function signature whose first line is not visible in this
// extract; presumably SchemaTypeParser::parseTensorDType, to be confirmed.
// It looks `dtype` up in a name -> at::ScalarType table. Callers in this file
// test the result for truthiness and dereference it, so it behaves like an
// optional. The return statements are not visible here.
93 const std::string& dtype) {
// X-macro: each entry of AT_FORALL_SCALAR_TYPES_WITH_COMPLEX becomes
// {"<enumerator name>", at::ScalarType::<enumerator name>}, with the key
// produced by the `#n` stringizing operator.
// The table is a function-local static, built once on the first call.
// NOTE(review): no `#undef DEFINE_SCALAR_TYPE` is visible in this extract.
94 #define DEFINE_SCALAR_TYPE(_1, n, _2) {#n, at::ScalarType::n}, 96 static std::unordered_map<std::string, at::ScalarType> type_map = {
97 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
99 auto type = type_map.find(dtype);
100 if (type != type_map.end()) {
// Parses a "refined" tensor type: a dtype identifier followed by a dimension
// spec. The resulting tensor type always has device at::DeviceType::CPU.
// The identifier must be a name parseTensorDType() accepts, and this is
// asserted below. Callers must therefore check first; parseType() does so
// before calling.
// Several lines are not visible in this extract: the opening delimiter, the
// declaration of `ptr`, the num_dims counting, and the final return.
106 TypePtr SchemaTypeParser::parseRefinedTensor() {
107 auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
108 AT_ASSERT(maybe_dtype);
109 at::ScalarType dtype = *maybe_dtype;
// Rank-only form: a ','-separated list terminated by ')' whose first element
// is '*'. Presumably every element is '*'; the per-element lines are not
// visible. It produces a DimensionedTensorType with num_dims dimensions.
113 if (L.cur().kind ==
'*') {
115 parseList(TK_NOTHING,
',',
')', [&] {
119 ptr = DimensionedTensorType::create(dtype, at::DeviceType::CPU, num_dims);
// Fully-sized form: a ','-separated list of integer sizes terminated by ')'.
// It produces a CompleteTensorType built from dims_ref, presumably a view of
// `dims`.
121 std::vector<int64_t> dims;
122 parseList(TK_NOTHING,
',',
')', [&] {
123 const std::string& num = L.expect(TK_NUMBER).text();
124 std::string::size_type num_len;
// NOTE(review): std::stoi throws std::invalid_argument / std::out_of_range
// instead of producing a parser error. Its int result is also stored in a
// size_t before (presumably) being pushed into the int64_t `dims`.
// Consider std::stoll, or std::from_chars with an explicit error check.
125 size_t dim = std::stoi(num, &num_len);
// Rejects a number token that has trailing characters; the message points to
// stride syntax as the expected cause. The checking macro's name is not
// visible here.
127 num_len == num.size(),
128 "Bad tensor dimension size. Strides not yet supported in parsing",
134 CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref,
false);
// Parses one type expression and returns it together with its alias
// annotation, if any.
// Forms handled in the visible code, in this order:
//  - `(T1, T2, ...)`            -> TupleType
//  - `Future` + element type    -> FutureType
//  - `Tensor`                   -> TensorType, plus an optional alias annotation
//  - `Dict` + key and value     -> DictType, plus an optional alias annotation
//  - a refined tensor dtype name -> parseRefinedTensor(), only when
//                                  complete_tensor_types is set
//  - anything else              -> parseBaseType()
// After that, a postfix `[]` wraps the type in ListType and a postfix `?`
// wraps it in OptionalType. Whether these postfixes can repeat (e.g. `int[]?`)
// depends on surrounding lines not visible in this extract. Several
// declarations, `L.next()`/`L.expect()` calls and closing braces are also not
// visible.
139 std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
143 if (L.cur().kind ==
'(') {
144 std::vector<TypePtr> types;
145 parseList(
'(',
',',
')', [&] {
146 auto r = parseType();
147 types.push_back(std::move(r.first));
// NOTE(review): element alias info is recorded only if alias_info is already
// engaged at this point. Its declaration/initialization is not visible; if it
// starts empty, tuple element aliases are silently dropped. Confirm.
148 if (alias_info && r.second) {
149 alias_info->addContainedType(std::move(*r.second));
152 value = TupleType::create(std::move(types));
153 }
else if (L.cur().kind == TK_IDENT && L.cur().text() ==
"Future") {
156 auto p = parseType();
157 auto subtype = std::move(p.first);
// NOTE(review): subalias is moved out of `p` but not used in the visible
// lines, so the element's alias info does not reach the result.
158 auto subalias = std::move(p.second);
160 value = FutureType::create(subtype);
161 }
else if (L.cur().kind == TK_IDENT && L.cur().text() ==
"Tensor") {
163 value = TensorType::get();
164 alias_info = parseAliasAnnotation();
165 }
else if (L.cur().kind == TK_IDENT && L.cur().text() ==
"Dict") {
// Only `.first` is kept: alias info of the key and value types is discarded.
168 auto key_type = parseType().first;
170 auto value_type = parseType().first;
172 alias_info = parseAliasAnnotation();
173 value = DictType::create(key_type, value_type);
175 complete_tensor_types && L.cur().kind == TK_IDENT &&
176 parseTensorDType(L.cur().text())) {
177 value = parseRefinedTensor();
178 alias_info = parseAliasAnnotation();
180 auto value_alias = parseBaseType();
181 value = value_alias.first;
182 alias_info = value_alias.second;
// Postfix modifiers. `[]` is detected with one token of lookahead.
185 if (L.cur().kind ==
'[' && L.lookahead().kind ==
']') {
188 value = ListType::create(value);
189 auto container = parseAliasAnnotation();
// The list's own annotation becomes the outer alias info; the element's alias
// info is nested inside it when both exist.
// NOTE(review): when the list has no annotation, the assignment below replaces
// the element's alias info with an empty optional. Confirm this is intended.
190 if (container && alias_info) {
191 container->addContainedType(std::move(*alias_info));
193 alias_info = std::move(container);
194 }
else if (L.nextIf(
'?')) {
195 value = OptionalType::create(value);
200 return std::make_pair(std::move(value), std::move(alias_info));
// Generic delimited-list helper. The begin/sep/end parameters are declared on
// lines not visible in this extract, and the function continues past the end
// of it.
// From the visible body:
//  - `begin` and `end` are token kinds; their handling is skipped when they
//    equal TK_NOTHING.
//  - `sep` separates items.
//  - an empty list is recognised by the current token already being `end`.
//  - `callback` presumably parses one item per iteration; the call is not
//    visible.
// NOTE(review): `r` (the range of the starting token) is not used in the
// visible lines.
203 void SchemaTypeParser::parseList(
207 const std::function<
void()>& callback) {
208 auto r = L.cur().range;
209 if (begin != TK_NOTHING)
211 if (L.cur().kind != end) {
214 }
while (L.nextIf(sep));
216 if (end != TK_NOTHING)