Caffe2 - C++ API
A deep learning, cross-platform ML framework
schema_type_parser.cpp
1 #include <torch/csrc/jit/script/schema_type_parser.h>
2 #include <ATen/core/interned_strings.h>
3 #include <torch/csrc/jit/alias_info.h>
4 #include <torch/csrc/jit/ir.h>
5 #include <torch/csrc/jit/script/lexer.h>
6 #include <torch/csrc/jit/script/parse_string_literal.h>
7 #include <string>
8 
9 namespace torch {
10 namespace jit {
11 namespace script {
12 
13 TypeAndAlias SchemaTypeParser::parseBaseType() {
14  static std::unordered_map<std::string, TypePtr> type_map = {
15  {"Generator", GeneratorType::get()},
16  {"ScalarType", IntType::get()},
17  {"Layout", IntType::get()},
18  {"Device", DeviceObjType::get()},
19  {"Scalar", NumberType::get()},
20  {"str", StringType::get()},
21  {"float", FloatType::get()},
22  {"int", IntType::get()},
23  {"bool", BoolType::get()},
24  };
25  auto tok = L.expect(TK_IDENT);
26  auto text = tok.text();
27  auto it = type_map.find(text);
28  if (it == type_map.end()) {
29  if (text.size() > 0 && islower(text[0])) {
30  // lower case identifiers that are not otherwise valid types
31  // are treated as type variables
32  return TypeAndAlias(VarType::create(text), parseAliasAnnotation());
33  }
34  throw ErrorReport(tok.range) << "unknown type specifier";
35  }
36  return TypeAndAlias(it->second, c10::nullopt);
37 }
38 
39 // Examples:
40 // Tensor(a) // Tensor is in set a
41 // Tensor(a!) // it is also written to
42 // Tensor! // shorthand for Tensor(fresh_identifier!)
43 // Tensor(a! -> a|b) // Tensor is in set a, written to,
44 // and after the write is in set a AND b.
45 c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
46  std::set<Symbol> sets;
47  AliasInfo alias_info;
48  if (L.nextIf('(')) {
49  // optional 'alias set annotation'
50  parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
51  if (L.nextIf('*')) {
52  alias_info = AliasInfo::createWildcard();
53 
54  // If we found a wildcard, ignore all subsequent annotations
55  } else if (!alias_info.isWildcard()) {
56  alias_info.addBeforeSet(
57  Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
58  }
59  });
60  if (L.nextIf('!')) {
61  alias_info.setIsWrite(true);
62  }
63  if (L.nextIf(TK_ARROW)) {
64  // optional 'alias set annotation'
65  parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
66  if (L.cur().kind == '*') {
67  L.reportError("Wildcard not allowed as part of the output set");
68  }
69  alias_info.addAfterSet(
70  Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
71  });
72  } else {
73  // We didn't encounter an ->, so assume the "after set" is identical
74  // to the "before set"
75  AT_ASSERT(alias_info.afterSets().empty());
76  for (const auto& set : alias_info.beforeSets()) {
77  alias_info.addAfterSet(set);
78  }
79  }
80  L.expect(')');
81  } else if (L.nextIf('!')) {
82  alias_info.addBeforeSet(
83  Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
84  alias_info.setIsWrite(true);
85  } else {
86  return c10::nullopt;
87  }
88 
89  return alias_info;
90 }
91 
92 c10::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
93  const std::string& dtype) {
94 #define DEFINE_SCALAR_TYPE(_1, n, _2) {#n, at::ScalarType::n},
95 
96  static std::unordered_map<std::string, at::ScalarType> type_map = {
97  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
98 
99  auto type = type_map.find(dtype);
100  if (type != type_map.end()) {
101  return type->second;
102  }
103  return c10::nullopt;
104 }
105 
106 TypePtr SchemaTypeParser::parseRefinedTensor() {
107  auto maybe_dtype = parseTensorDType(L.expect(TK_IDENT).text());
108  AT_ASSERT(maybe_dtype);
109  at::ScalarType dtype = *maybe_dtype;
110  TypePtr ptr;
111  L.expect('(');
112  TypePtr tensor_type;
113  if (L.cur().kind == '*') {
114  size_t num_dims = 0;
115  parseList(TK_NOTHING, ',', ')', [&] {
116  L.expect('*');
117  num_dims++;
118  });
119  ptr = DimensionedTensorType::create(dtype, at::DeviceType::CPU, num_dims);
120  } else {
121  std::vector<int64_t> dims;
122  parseList(TK_NOTHING, ',', ')', [&] {
123  const std::string& num = L.expect(TK_NUMBER).text();
124  std::string::size_type num_len;
125  size_t dim = std::stoi(num, &num_len);
126  AT_ASSERTM(
127  num_len == num.size(),
128  "Bad tensor dimension size. Strides not yet supported in parsing",
129  num);
130  dims.push_back(dim);
131  });
132  at::IntArrayRef dims_ref(dims);
133  ptr =
134  CompleteTensorType::create(dtype, at::DeviceType::CPU, dims_ref, false);
135  }
136  return ptr;
137 }
138 
139 std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
140  TypePtr value;
141  c10::optional<AliasInfo> alias_info;
142  // Tuple type
143  if (L.cur().kind == '(') {
144  std::vector<TypePtr> types;
145  parseList('(', ',', ')', [&] {
146  auto r = parseType();
147  types.push_back(std::move(r.first));
148  if (alias_info && r.second) {
149  alias_info->addContainedType(std::move(*r.second));
150  }
151  });
152  value = TupleType::create(std::move(types));
153  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
154  L.next(); // Future
155  L.expect('(');
156  auto p = parseType();
157  auto subtype = std::move(p.first);
158  auto subalias = std::move(p.second);
159  L.expect(')');
160  value = FutureType::create(subtype);
161  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
162  L.next();
163  value = TensorType::get();
164  alias_info = parseAliasAnnotation();
165  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
166  L.next();
167  L.expect('(');
168  auto key_type = parseType().first;
169  L.expect(',');
170  auto value_type = parseType().first;
171  L.expect(')');
172  alias_info = parseAliasAnnotation();
173  value = DictType::create(key_type, value_type);
174  } else if (
175  complete_tensor_types && L.cur().kind == TK_IDENT &&
176  parseTensorDType(L.cur().text())) {
177  value = parseRefinedTensor();
178  alias_info = parseAliasAnnotation();
179  } else {
180  auto value_alias = parseBaseType();
181  value = value_alias.first;
182  alias_info = value_alias.second;
183  }
184  while (true) {
185  if (L.cur().kind == '[' && L.lookahead().kind == ']') {
186  L.next(); // [
187  L.next(); // ]
188  value = ListType::create(value);
189  auto container = parseAliasAnnotation();
190  if (container && alias_info) {
191  container->addContainedType(std::move(*alias_info));
192  }
193  alias_info = std::move(container);
194  } else if (L.nextIf('?')) {
195  value = OptionalType::create(value);
196  } else {
197  break;
198  }
199  }
200  return std::make_pair(std::move(value), std::move(alias_info));
201 }
202 
203 void SchemaTypeParser::parseList(
204  int begin,
205  int sep,
206  int end,
207  const std::function<void()>& callback) {
208  auto r = L.cur().range;
209  if (begin != TK_NOTHING)
210  L.expect(begin);
211  if (L.cur().kind != end) {
212  do {
213  callback();
214  } while (L.nextIf(sep));
215  }
216  if (end != TK_NOTHING)
217  L.expect(end);
218 }
219 } // namespace script
220 } // namespace jit
221 } // namespace torch
Definition: jit_type.h:17