7 from textwrap
import dedent
8 from functools
import partial
9 from collections
import namedtuple
# Identifiers beginning with this prefix are rejected in user code
# (see the "can't start with" check on variable names).
_reserved_prefix = '__jit'
# Exact names treated as reserved in addition to the prefix rule.
_reserved_names = {'print'}
# Characters accepted while scanning an identifier in raw source text.
# '_' is a legal identifier character; without it a scan over a name such
# as `my_attr` stops after `my` and yields a truncated source range.
_identifier_chars = set(string.ascii_letters + string.digits + '_')
def is_reserved_name(name):
    """Return True if `name` is reserved.

    A name is reserved when it starts with `_reserved_prefix` or is a
    member of `_reserved_names`.
    """
    if name.startswith(_reserved_prefix):
        return True
    return name in _reserved_names
23 ast.FunctionDef:
"function definitions",
25 ast.Delete:
"del statements",
26 ast.ClassDef:
"class definitions",
27 ast.With:
"with statements",
28 ast.Raise:
"raise statements",
29 ast.Assert:
"assertions",
30 ast.Import:
"import statements",
31 ast.ImportFrom:
"import statements",
32 ast.Global:
"global variables",
33 ast.Break:
"break statements",
34 ast.Continue:
"continue statements",
38 ast.FunctionDef:
"def",
41 ast.ClassDef:
"class",
46 ast.ImportFrom:
"from",
49 ast.Continue:
"continue",
53 pretty_node_names.update({
54 ast.Print:
"print statements",
55 ast.TryExcept:
"try blocks",
56 ast.TryFinally:
"try blocks",
57 ast.Exec:
"exec statements",
60 node_start_tokens.update({
63 ast.TryFinally:
"try",
67 pretty_node_names.update({
68 ast.AsyncFunctionDef:
"async function definitions",
69 ast.AsyncFor:
"async for loops",
70 ast.AsyncWith:
"async with statements",
71 ast.Try:
"try blocks",
72 ast.Nonlocal:
"nonlocal variables",
75 node_start_tokens.update({
76 ast.AsyncFunctionDef:
"async def",
77 ast.AsyncFor:
"async for",
78 ast.AsyncWith:
"async with",
80 ast.Nonlocal:
"nonlocal",
83 if sys.version_info >= (3, 6):
84 pretty_node_names.update({
85 ast.AnnAssign:
"annotated assignments",
91 def __init__(self, source_range, msg):
98 result +=
'\n' + self.source_range.highlight()
class UnsupportedNodeError(NotSupportedError):
    """Error describing an AST node whose construct is not supported."""

    def __init__(self, ctx, offending_node):
        """Build the source range and message for `offending_node`.

        Args:
            ctx: object providing ``make_range(lineno, start_col, end_col)``.
            offending_node: AST node exposing ``lineno`` and ``col_offset``.
        """
        node_type = type(offending_node)
        # Highlight the node's start token when one is registered; the ' '
        # default makes the fallback range exactly one character long.
        range_len = len(node_start_tokens.get(node_type, ' '))
        source_range = ctx.make_range(offending_node.lineno,
                                      offending_node.col_offset,
                                      offending_node.col_offset + range_len)
        # Prefer the human-readable plural name, else the raw class name.
        feature_name = pretty_node_names.get(node_type, node_type.__name__)
        msg = "{} aren't supported".format(feature_name)
        # super() must name *this* class; naming the parent starts the MRO
        # lookup after the parent and would skip any initializer it defines.
        super(UnsupportedNodeError, self).__init__(source_range, msg)
def build_stmts(ctx, stmts):
    """Convert each statement with `build_stmt` and return the results.

    Results that are falsy (for example None) are left out of the
    returned list.
    """
    built = []
    for py_stmt in stmts:
        node = build_stmt(ctx, py_stmt)
        if node:
            built.append(node)
    return built
128 def _uses_true_division(fn):
131 if inspect.ismethod(fn):
132 return _uses_true_division(fn.__func__)
133 elif inspect.isfunction(fn):
134 return fn.__globals__.get(
'division')
is __future__.division
137 '_uses_true_division: expected function or method, got {}'.format(type(fn)))
def get_jit_class_def(cls, self_name=None):
    """Build a class definition for `cls` from its Python source.

    Every member of `cls` that is a function or method is converted with
    `get_jit_def`, passing the class name as `self_name`. The dedented
    class source is then parsed and its first top-level node is handed,
    together with those definitions, to `build_class_def`.

    The `self_name` argument is accepted but not read here.
    """
    def _is_function_like(member):
        return inspect.ismethod(member) or inspect.isfunction(member)

    method_defs = []
    for _member_name, member in inspect.getmembers(cls, predicate=_is_function_like):
        method_defs.append(get_jit_def(member, self_name=cls.__name__))

    source = dedent(inspect.getsource(cls))
    py_ast = ast.parse(source)
    # The second argument is the context's `uses_true_division` flag.
    ctx = SourceContext(source, False)
    class_node = py_ast.body[0]
    return build_class_def(ctx, class_node, method_defs)
153 def get_jit_def(fn, self_name=None):
154 source = dedent(inspect.getsource(fn))
155 py_ast = ast.parse(source)
156 if len(py_ast.body) != 1
or not isinstance(py_ast.body[0], ast.FunctionDef):
157 raise RuntimeError(
"expected a single top-level function")
159 ctx = SourceContext(source, _uses_true_division(fn))
160 return build_def(ctx, py_ast.body[0], type_line, self_name)
166 def __init__(self, source, uses_true_division=True):
167 super(SourceContext, self).__init__(source)
172 def __call__(self, ctx, node):
173 method = getattr(self,
'build_' + node.__class__.__name__,
None)
176 return method(ctx, node)
def build_class_def(ctx, py_def, methods):
    """Wrap a parsed class node and its method definitions in a `ClassDef`.

    The identifier's range starts at the node's column offset and is as
    long as the text "class".
    """
    keyword_len = len("class")
    r = ctx.make_range(py_def.lineno,
                       py_def.col_offset,
                       py_def.col_offset + keyword_len)
    class_ident = Ident(r, py_def.name)
    return ClassDef(class_ident, methods)
185 def build_def(ctx, py_def, type_line, self_name=None):
187 r = ctx.make_range(py_def.lineno, py_def.col_offset,
188 py_def.col_offset + len(
"def"))
189 param_list = build_param_list(ctx, py_def.args, self_name)
191 if getattr(py_def,
'returns',
None)
is not None:
192 return_type = build_expr(ctx, py_def.returns)
193 decl = Decl(r, param_list, return_type)
194 is_method = self_name
is not None 195 if type_line
is not None:
196 type_comment_decl = torch._C.parse_type_comment(type_line)
197 decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
198 return Def(Ident(r, py_def.name),
200 build_stmts(ctx, body))
# Message used when a signature has *args/**kwargs or keyword-only defaults.
_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments or use "
    "keyword-only arguments with defaults"
)
def build_param_list(ctx, py_args, self_name):
    """Convert a function's argument node into a list of parameters.

    Args:
        ctx: source context forwarded to `build_param`.
        py_args: argument node exposing `vararg`, `kwarg`, `args` and, on
            Python 3, `kw_defaults` and `kwonlyargs`.
        self_name: forwarded unchanged to `build_param`.

    Returns:
        Parameters for the positional arguments, followed on Python 3 by
        the keyword-only arguments (built with `kwarg_only=True`).

    Raises:
        ValueError: if the signature has `*args` or `**kwargs`, or gives a
            keyword-only argument a default value.
    """
    if py_args.vararg is not None or py_args.kwarg is not None:
        raise ValueError(_vararg_kwarg_err)
    # `kw_defaults` has one entry per keyword-only argument and stores None
    # for an argument without a default, so the truthiness of the list only
    # says whether keyword-only arguments exist. Look at the entries.
    if not PY2 and any(default is not None for default in py_args.kw_defaults):
        raise ValueError(_vararg_kwarg_err)
    result = [build_param(ctx, arg, self_name, False) for arg in py_args.args]
    if not PY2:
        # Python 2 argument nodes have no `kwonlyargs` attribute.
        result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs]
    return result
def build_param(ctx, py_arg, self_name, kwarg_only):
    """Build a `Param` for a single function argument.

    The argument name is read from `py_arg.id` on Python 2 and from
    `py_arg.arg` otherwise. The parameter's type expression is chosen as:
      * the argument's own annotation, when it has one;
      * `self_name`, when that is given and the argument is named 'self';
      * 'Tensor' in every other case.
    """
    if PY2:
        name = py_arg.id
    else:
        name = py_arg.arg
    r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))

    has_annotation = getattr(py_arg, 'annotation', None) is not None
    if has_annotation:
        annotation_expr = build_expr(ctx, py_arg.annotation)
    elif self_name is None or name != 'self':
        annotation_expr = Var(Ident(r, 'Tensor'))
    else:
        annotation_expr = Var(Ident(r, self_name))
    return Param(annotation_expr, Ident(r, name), kwarg_only)
232 def get_default_args(fn):
234 argspec = inspect.getargspec(fn)
235 if argspec.defaults
is not None:
236 return dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
240 signature = inspect.signature(fn)
243 for k, v
in signature.parameters.items()
244 if v.default
is not inspect.Parameter.empty
257 def build_Expr(ctx, stmt):
259 if value.__class__.__name__ ==
'Str':
264 return ExprStmt(build_expr(ctx, value))
267 def build_Assign(ctx, stmt):
268 rhs = build_expr(ctx, stmt.value)
269 if len(stmt.targets) > 1:
270 start_point = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + 1)
272 "Performing multiple assignments in a single line isn't supported")
273 lhs = build_expr(ctx, stmt.targets[0])
274 return Assign(lhs, rhs)
def build_Return(ctx, stmt):
    """Build a `Return` node from a return statement.

    The range starts at the statement's column offset and is as long as
    the text "return". The returned value is converted with `build_expr`,
    or passed as None when the statement has no value.
    """
    r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
    if stmt.value is None:
        value = None
    else:
        value = build_expr(ctx, stmt.value)
    return Return(r, value)
282 def build_Raise(ctx, stmt):
283 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"raise"))
288 expr = build_expr(ctx, stmt.inst)
if stmt.inst
else None 290 expr = build_expr(ctx, stmt.exc)
291 return Raise(r, expr)
def build_Assert(ctx, stmt):
    """Build an `Assert` node from an assert statement.

    The range starts at the statement's column offset and is as long as
    the text "assert". The message expression is converted only when the
    statement has one; otherwise None is passed.
    """
    r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert"))
    test = build_expr(ctx, stmt.test)
    if stmt.msg is None:
        msg = None
    else:
        msg = build_expr(ctx, stmt.msg)
    return Assert(r, test, msg)
301 def build_AugAssign(ctx, stmt):
302 lhs = build_expr(ctx, stmt.target)
303 rhs = build_expr(ctx, stmt.value)
305 if op
in StmtBuilder.augassign_map:
306 op_token = StmtBuilder.augassign_map[op]
309 find_before(ctx, rhs.range().start,
'=', offsets=(-1, 0)),
310 "unsupported kind of augumented assignment: " + op.__name__)
311 return AugAssign(lhs, op_token, rhs)
314 def build_While(ctx, stmt):
319 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"while"))
320 return While(r, build_expr(ctx, stmt.test),
321 build_stmts(ctx, stmt.body))
324 def build_For(ctx, stmt):
325 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"for"))
327 r, [build_expr(ctx, stmt.target)],
328 [build_expr(ctx, stmt.iter)], build_stmts(ctx, stmt.body))
def build_If(ctx, stmt):
    """Build an `If` node from an if statement.

    The range starts at the statement's column offset and is as long as
    the text "if". The condition is converted with `build_expr`; the body
    and the else branch are converted with `build_stmts`.
    """
    r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
    condition = build_expr(ctx, stmt.test)
    true_branch = build_stmts(ctx, stmt.body)
    false_branch = build_stmts(ctx, stmt.orelse)
    return If(r, condition, true_branch, false_branch)
338 def build_Print(ctx, stmt):
339 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"print"))
341 raise NotSupportedError(r,
"print statements with non-default destinations aren't supported")
342 args = [build_expr(ctx, val)
for val
in stmt.values]
343 return ExprStmt(Apply(Var(Ident(r,
"print")), args, []))
346 def build_Pass(ctx, stmt):
347 r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"pass"))
366 binop_map[ast.MatMult] =
'@' 390 def build_Attribute(ctx, expr):
392 value = build_expr(ctx, expr.value)
395 pos = find_after(ctx, value.range().end,
'.').end
396 while source[pos]
in string.whitespace:
399 while source[pos]
in _identifier_chars:
401 name_range = ctx.make_raw_range(start_pos, pos)
402 return Select(value, Ident(name_range, expr.attr))
405 def build_Call(ctx, expr):
406 func = build_expr(ctx, expr.func)
407 args = [build_expr(ctx, py_arg)
for py_arg
in expr.args]
408 if hasattr(expr,
'starargs')
and expr.starargs:
409 stararg_expr = build_expr(ctx, expr.starargs)
410 args += [Starred(stararg_expr.range(), stararg_expr)]
412 for kw
in expr.keywords:
413 kw_expr = build_expr(ctx, kw.value)
415 kwargs.append(
Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
416 return Apply(func, args, kwargs)
419 def build_Name(ctx, expr):
420 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
421 if expr.id.startswith(_reserved_prefix):
423 "can't start with " + _reserved_prefix)
424 if expr.id ==
"True":
425 return TrueLiteral(r)
426 elif expr.id ==
"False":
427 return FalseLiteral(r)
428 elif expr.id ==
"None":
429 return NoneLiteral(r)
430 return Var(Ident(r, expr.id))
def build_NameConstant(ctx, expr):
    """Build the literal node for a True / False / None constant.

    The range starts at the node's column offset and is as long as
    `str(expr.value)`.

    Raises:
        ValueError: if the value is not one of True, False or None.
    """
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value)))
    # Identity comparison, so values that merely equal True/False (such as
    # 1 or 0) are not matched.
    literal_for = ((True, TrueLiteral), (False, FalseLiteral), (None, NoneLiteral))
    for singleton, make_literal in literal_for:
        if expr.value is singleton:
            return make_literal(r)
    raise ValueError("Name constant value unsupported: " + str(expr.value))
445 def build_BinOp(ctx, expr):
446 lhs = build_expr(ctx, expr.left)
447 rhs = build_expr(ctx, expr.right)
450 if op == ast.Div
and not ctx.uses_true_division:
451 raise RuntimeError(
'Division of ints in JIT script uses Python 3 true ' 452 'division semantics. Please put `from __future__ ' 453 'import division` at the top of your file')
455 op_token = ExprBuilder.binop_map.get(op)
457 err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
458 raise NotSupportedError(err_range,
"unsupported binary operator: " + op.__name__)
459 return BinOp(op_token, lhs, rhs)
462 def build_UnaryOp(ctx, expr):
463 sub_expr = build_expr(ctx, expr.operand)
465 op_token = ExprBuilder.unop_map.get(op)
466 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))
468 err_range = ctx.make_raw_range(r.start, sub_expr.range().end)
470 return UnaryOp(r, op_token, sub_expr)
473 def build_BoolOp(ctx, expr):
474 if len(expr.values) < 2:
475 raise AssertionError(
"expected at least 2 values in BoolOp, but got " + str(len(expr.values)))
476 sub_exprs = [build_expr(ctx, sub_expr)
for sub_expr
in expr.values]
478 op_token = ExprBuilder.boolop_map.get(op)
480 err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start)
481 raise NotSupportedError(err_range,
"unsupported boolean operator: " + op.__name__)
483 for rhs
in sub_exprs[1:]:
484 lhs = BinOp(op_token, lhs, rhs)
def build_IfExp(ctx, expr):
    """Build a `TernaryIf` from a conditional expression.

    The condition, the value when true and the value when false are each
    converted with `build_expr`, in that order.
    """
    condition = build_expr(ctx, expr.test)
    when_true = build_expr(ctx, expr.body)
    when_false = build_expr(ctx, expr.orelse)
    return TernaryIf(condition, when_true, when_false)
494 def build_Compare(ctx, expr):
495 operands = [build_expr(ctx, e)
for e
in [expr.left] + list(expr.comparators)]
497 for lhs, op_, rhs
in zip(operands, expr.ops, operands[1:]):
499 op_token = ExprBuilder.cmpop_map.get(op)
501 err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
502 raise NotSupportedError(err_range,
"unsupported comparison operator: " + op.__name__)
503 cmp_expr = BinOp(op_token, lhs, rhs)
507 result = BinOp(
'and', result, cmp_expr)
511 def build_Subscript(ctx, expr):
512 def build_SliceExpr(ctx, base, slice_expr):
513 lower = build_expr(ctx, slice_expr.lower)
if slice_expr.lower
is not None else None 514 upper = build_expr(ctx, slice_expr.upper)
if slice_expr.upper
is not None else None 515 if slice_expr.step
is not None:
516 step = build_expr(ctx, slice_expr.step)
517 raise NotSupportedError(step.range(),
"slices with ranges are not supported yet")
518 return SliceExpr(base.range(), lower, upper)
520 def build_Index(ctx, base, index_expr):
521 if isinstance(index_expr.value, ast.Tuple)
or \
522 isinstance(index_expr.value, ast.List):
524 "slicing multiple dimensions with " 525 "sequences not supported yet")
526 return build_expr(ctx, index_expr.value)
528 def build_ExtSlice(ctx, base, extslice):
530 for expr
in extslice.dims:
531 sub_type = type(expr)
532 if sub_type
is ast.Index:
533 sub_exprs.append(build_Index(ctx, base, expr))
534 elif sub_type
is ast.Slice:
535 sub_exprs.append(build_SliceExpr(ctx, base, expr))
538 "slicing multiple dimensions with " 539 "{} not supported".format(sub_type))
542 base = build_expr(ctx, expr.value)
543 sub_type = type(expr.slice)
544 if sub_type
is ast.Index:
545 if isinstance(expr.slice.value, ast.Tuple)
or isinstance(expr.slice.value, ast.List):
547 for index_expr
in expr.slice.value.elts:
548 indices.append(build_expr(ctx, index_expr))
549 return Subscript(base, indices)
551 return Subscript(base, [build_expr(ctx, expr.slice.value)])
552 elif sub_type
is ast.Slice:
553 return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
554 elif sub_type
is ast.ExtSlice:
555 return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
def build_List(ctx, expr):
    """Build a `ListLiteral` from a list display.

    The range is one column wide, starting at the node's column offset.
    Each element is converted with `build_expr`.
    """
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    elements = []
    for element in expr.elts:
        elements.append(build_expr(ctx, element))
    return ListLiteral(r, elements)
def build_Tuple(ctx, expr):
    """Build a `TupleLiteral` from a tuple display.

    The range is one column wide, starting at the node's column offset.
    Each element is converted with `build_expr`.
    """
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    elements = []
    for element in expr.elts:
        elements.append(build_expr(ctx, element))
    return TupleLiteral(r, elements)
def build_Dict(ctx, expr):
    """Build a `DictLiteral` from a dict display.

    The range is one column wide, starting at the node's column offset.
    All keys are converted with `build_expr` first, then all values.
    """
    r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    keys = []
    for key in expr.keys:
        keys.append(build_expr(ctx, key))
    values = []
    for value in expr.values:
        values.append(build_expr(ctx, value))
    return DictLiteral(r, keys, values)
575 def build_Num(ctx, expr):
577 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
578 return Const(r, value)
581 def build_Str(ctx, expr):
583 r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
584 return StringLiteral(r, value)
def build_Starred(ctx, expr):
    """Build a `Starred` node wrapping the converted inner expression.

    The range is one column wide, starting at the node's column offset.
    """
    start_col = expr.col_offset
    r = ctx.make_range(expr.lineno, start_col, start_col + 1)
    inner = build_expr(ctx, expr.value)
    return Starred(r, inner)
def find_after(ctx, pos, substr, offsets=(0, 0)):
    """Return the raw range of the first `substr` at or after `pos`.

    Args:
        ctx: object with a `source` string and `make_raw_range(start, end)`.
        pos: index into `ctx.source` where the search begins.
        substr: text to look for.
        offsets: pair added to the start and end of the match respectively.

    Raises:
        ValueError: if `substr` does not occur in `ctx.source[pos:]`.
    """
    tail = ctx.source[pos:]
    # index() is relative to the tail, so shift it back by `pos`.
    match_start = tail.index(substr) + pos
    range_start = match_start + offsets[0]
    range_end = match_start + len(substr) + offsets[1]
    return ctx.make_raw_range(range_start, range_end)
def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Return the raw range of the last `substr` that ends at or before `pos`.

    Args:
        ctx: object with a `source` string and `make_raw_range(start, end)`.
        pos: index into `ctx.source` that bounds the search on the right.
        substr: text to look for.
        offsets: pair added to the start and end of the match respectively.

    Raises:
        ValueError: if `substr` does not occur in `ctx.source[:pos]`.
    """
    # Searching with bounds (0, pos) looks at the same text as the slice
    # source[:pos] and returns an index into the full source.
    match_start = ctx.source.rindex(substr, 0, pos)
    range_start = match_start + offsets[0]
    range_end = match_start + len(substr) + offsets[1]
    return ctx.make_raw_range(range_start, range_end)
def get_type_line(source)