7 from textwrap 
import dedent
     8 from functools 
import partial
     9 from collections 
import namedtuple
    13 _reserved_prefix = 
'__jit'    14 _reserved_names = {
'print'}
    15 _identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
    18 def is_reserved_name(name):
    19     return name.startswith(_reserved_prefix) 
or name 
in _reserved_names
    23     ast.FunctionDef: 
"function definitions",
    25     ast.Delete: 
"del statements",
    26     ast.ClassDef: 
"class definitions",
    27     ast.With: 
"with statements",
    28     ast.Raise: 
"raise statements",
    29     ast.Assert: 
"assertions",
    30     ast.Import: 
"import statements",
    31     ast.ImportFrom: 
"import statements",
    32     ast.Global: 
"global variables",
    33     ast.Break: 
"break statements",
    34     ast.Continue: 
"continue statements",
    38     ast.FunctionDef: 
"def",
    41     ast.ClassDef: 
"class",
    46     ast.ImportFrom: 
"from",
    49     ast.Continue: 
"continue",
    53     pretty_node_names.update({
    54         ast.Print: 
"print statements",
    55         ast.TryExcept: 
"try blocks",
    56         ast.TryFinally: 
"try blocks",
    57         ast.Exec: 
"exec statements",
    60     node_start_tokens.update({
    63         ast.TryFinally: 
"try",
    67     pretty_node_names.update({
    68         ast.AsyncFunctionDef: 
"async function definitions",
    69         ast.AsyncFor: 
"async for loops",
    70         ast.AsyncWith: 
"async with statements",
    71         ast.Try: 
"try blocks",
    72         ast.Nonlocal: 
"nonlocal variables",
    75     node_start_tokens.update({
    76         ast.AsyncFunctionDef: 
"async def",
    77         ast.AsyncFor: 
"async for",
    78         ast.AsyncWith: 
"async with",
    80         ast.Nonlocal: 
"nonlocal",
    83 if sys.version_info >= (3, 6):
    84     pretty_node_names.update({
    85         ast.AnnAssign: 
"annotated assignments",
    91     def __init__(self, source_range, msg):
    98             result += 
'\n' + self.source_range.highlight()
   106 class UnsupportedNodeError(NotSupportedError):
   107     def __init__(self, ctx, offending_node):
   109         node_type = type(offending_node)
   110         range_len = len(node_start_tokens.get(node_type, 
' '))
   111         source_range = ctx.make_range(offending_node.lineno,
   112                                       offending_node.col_offset,
   113                                       offending_node.col_offset + range_len)
   114         feature_name = pretty_node_names.get(node_type, node_type.__name__)
   115         msg = 
"{} aren't supported".format(feature_name)
   116         super(NotSupportedError, self).__init__(source_range, msg)
   123 def build_stmts(ctx, stmts):
   124     stmts = [build_stmt(ctx, s) 
for s 
in stmts]
   125     return list(filter(
None, stmts))
   128 def _uses_true_division(fn):
   131     if inspect.ismethod(fn):
   132         return _uses_true_division(fn.__func__)
   133     elif inspect.isfunction(fn):
   134         return fn.__globals__.get(
'division') 
is __future__.division
   137             '_uses_true_division: expected function or method, got {}'.format(type(fn)))
   140 def get_jit_class_def(cls, self_name=None):
   142     methods = inspect.getmembers(
   143         cls, predicate=
lambda m: inspect.ismethod(m) 
or inspect.isfunction(m))
   144     method_defs = [get_jit_def(method[1],
   145                    self_name=cls.__name__) 
for method 
in methods]
   147     source = dedent(inspect.getsource(cls))
   148     py_ast = ast.parse(source)
   149     ctx = SourceContext(source, 
False)
   150     return build_class_def(ctx, py_ast.body[0], method_defs)
   153 def get_jit_def(fn, self_name=None):
   154     source = dedent(inspect.getsource(fn))
   155     py_ast = ast.parse(source)
   156     if len(py_ast.body) != 1 
or not isinstance(py_ast.body[0], ast.FunctionDef):
   157         raise RuntimeError(
"expected a single top-level function")
   159     ctx = SourceContext(source, _uses_true_division(fn))
   160     return build_def(ctx, py_ast.body[0], type_line, self_name)
   166     def __init__(self, source, uses_true_division=True):
   167         super(SourceContext, self).__init__(source)
   172     def __call__(self, ctx, node):
   173         method = getattr(self, 
'build_' + node.__class__.__name__, 
None)
   176         return method(ctx, node)
   179 def build_class_def(ctx, py_def, methods):
   180     r = ctx.make_range(py_def.lineno, py_def.col_offset,
   181                        py_def.col_offset + len(
"class"))
   182     return ClassDef(Ident(r, py_def.name), methods)
   185 def build_def(ctx, py_def, type_line, self_name=None):
   187     r = ctx.make_range(py_def.lineno, py_def.col_offset,
   188                        py_def.col_offset + len(
"def"))
   189     param_list = build_param_list(ctx, py_def.args, self_name)
   191     if getattr(py_def, 
'returns', 
None) 
is not None:
   192         return_type = build_expr(ctx, py_def.returns)
   193     decl = Decl(r, param_list, return_type)
   194     is_method = self_name 
is not None   195     if type_line 
is not None:
   196         type_comment_decl = torch._C.parse_type_comment(type_line)
   197         decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
   198     return Def(Ident(r, py_def.name),
   200                build_stmts(ctx, body))
   203 _vararg_kwarg_err = (
"Compiled functions can't take variable number of arguments "   204                      "or use keyword-only arguments with defaults")
   207 def build_param_list(ctx, py_args, self_name):
   208     if py_args.vararg 
is not None or py_args.kwarg 
is not None:
   209         raise ValueError(_vararg_kwarg_err)
   210     if not PY2 
and py_args.kw_defaults:
   211         raise ValueError(_vararg_kwarg_err)
   212     result = [build_param(ctx, arg, self_name, 
False) 
for arg 
in py_args.args]
   214         result += [build_params(ctx, arg, self_name, 
True) 
for arg 
in py_args.kwonlyargs]
   218 def build_param(ctx, py_arg, self_name, kwarg_only):
   221     name = py_arg.id 
if PY2 
else py_arg.arg
   222     r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
   223     if getattr(py_arg, 
'annotation', 
None) 
is not None:
   224         annotation_expr = build_expr(ctx, py_arg.annotation)
   225     elif self_name 
is not None and name == 
'self':
   226         annotation_expr = Var(Ident(r, self_name))
   228         annotation_expr = Var(Ident(r, 
'Tensor'))
   229     return Param(annotation_expr, Ident(r, name), kwarg_only)
   232 def get_default_args(fn):
   234         argspec = inspect.getargspec(fn)
   235         if argspec.defaults 
is not None:
   236             return dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
   240         signature = inspect.signature(fn)
   243             for k, v 
in signature.parameters.items()
   244             if v.default 
is not inspect.Parameter.empty
   257     def build_Expr(ctx, stmt):
   259         if value.__class__.__name__ == 
'Str':
   264             return ExprStmt(build_expr(ctx, value))
   267     def build_Assign(ctx, stmt):
   268         rhs = build_expr(ctx, stmt.value)
   269         if len(stmt.targets) > 1:
   270             start_point = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + 1)
   272                                     "Performing multiple assignments in a single line isn't supported")
   273         lhs = build_expr(ctx, stmt.targets[0])
   274         return Assign(lhs, rhs)
   277     def build_Return(ctx, stmt):
   278         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"return"))
   279         return Return(r, 
None if stmt.value 
is None else build_expr(ctx, stmt.value))
   282     def build_Raise(ctx, stmt):
   283         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"raise"))
   288             expr = build_expr(ctx, stmt.inst) 
if stmt.inst 
else None   290             expr = build_expr(ctx, stmt.exc)
   291         return Raise(r, expr)
   294     def build_Assert(ctx, stmt):
   295         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"assert"))
   296         test = build_expr(ctx, stmt.test)
   297         msg = build_expr(ctx, stmt.msg) 
if stmt.msg 
is not None else None   298         return Assert(r, test, msg)
   301     def build_AugAssign(ctx, stmt):
   302         lhs = build_expr(ctx, stmt.target)
   303         rhs = build_expr(ctx, stmt.value)
   305         if op 
in StmtBuilder.augassign_map:
   306             op_token = StmtBuilder.augassign_map[op]
   309                 find_before(ctx, rhs.range().start, 
'=', offsets=(-1, 0)),
   310                 "unsupported kind of augumented assignment: " + op.__name__)
   311         return AugAssign(lhs, op_token, rhs)
   314     def build_While(ctx, stmt):
   319         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"while"))
   320         return While(r, build_expr(ctx, stmt.test),
   321                      build_stmts(ctx, stmt.body))
   324     def build_For(ctx, stmt):
   325         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"for"))
   327             r, [build_expr(ctx, stmt.target)],
   328             [build_expr(ctx, stmt.iter)], build_stmts(ctx, stmt.body))
   331     def build_If(ctx, stmt):
   332         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"if"))
   333         return If(r, build_expr(ctx, stmt.test),
   334                   build_stmts(ctx, stmt.body),
   335                   build_stmts(ctx, stmt.orelse))
   338     def build_Print(ctx, stmt):
   339         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"print"))
   341             raise NotSupportedError(r, 
"print statements with non-default destinations aren't supported")
   342         args = [build_expr(ctx, val) 
for val 
in stmt.values]
   343         return ExprStmt(Apply(Var(Ident(r, 
"print")), args, []))
   346     def build_Pass(ctx, stmt):
   347         r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len(
"pass"))
   366         binop_map[ast.MatMult] = 
'@'   390     def build_Attribute(ctx, expr):
   392         value = build_expr(ctx, expr.value)
   395         pos = find_after(ctx, value.range().end, 
'.').end  
   396         while source[pos] 
in string.whitespace:  
   399         while source[pos] 
in _identifier_chars:  
   401         name_range = ctx.make_raw_range(start_pos, pos)
   402         return Select(value, Ident(name_range, expr.attr))
   405     def build_Call(ctx, expr):
   406         func = build_expr(ctx, expr.func)
   407         args = [build_expr(ctx, py_arg) 
for py_arg 
in expr.args]
   408         if hasattr(expr, 
'starargs') 
and expr.starargs:
   409             stararg_expr = build_expr(ctx, expr.starargs)
   410             args += [Starred(stararg_expr.range(), stararg_expr)]
   412         for kw 
in expr.keywords:
   413             kw_expr = build_expr(ctx, kw.value)
   415             kwargs.append(
Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
   416         return Apply(func, args, kwargs)
   419     def build_Name(ctx, expr):
   420         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
   421         if expr.id.startswith(_reserved_prefix):
   423                                        "can't start with " + _reserved_prefix)
   424         if expr.id == 
"True":
   425             return TrueLiteral(r)
   426         elif expr.id == 
"False":
   427             return FalseLiteral(r)
   428         elif expr.id == 
"None":
   429             return NoneLiteral(r)
   430         return Var(Ident(r, expr.id))
   433     def build_NameConstant(ctx, expr):
   434         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value)))
   435         if expr.value 
is True:
   436             return TrueLiteral(r)
   437         elif expr.value 
is False:
   438             return FalseLiteral(r)
   439         elif expr.value 
is None:
   440             return NoneLiteral(r)
   442             raise ValueError(
"Name constant value unsupported: " + str(expr.value))
   445     def build_BinOp(ctx, expr):
   446         lhs = build_expr(ctx, expr.left)
   447         rhs = build_expr(ctx, expr.right)
   450         if op == ast.Div 
and not ctx.uses_true_division:
   451             raise RuntimeError(
'Division of ints in JIT script uses Python 3 true '   452                                'division semantics. Please put `from __future__ '   453                                'import division` at the top of your file')
   455         op_token = ExprBuilder.binop_map.get(op)
   457             err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
   458             raise NotSupportedError(err_range, 
"unsupported binary operator: " + op.__name__)
   459         return BinOp(op_token, lhs, rhs)
   462     def build_UnaryOp(ctx, expr):
   463         sub_expr = build_expr(ctx, expr.operand)
   465         op_token = ExprBuilder.unop_map.get(op)
   466         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))
   468             err_range = ctx.make_raw_range(r.start, sub_expr.range().end)
   470         return UnaryOp(r, op_token, sub_expr)
   473     def build_BoolOp(ctx, expr):
   474         if len(expr.values) < 2:
   475             raise AssertionError(
"expected at least 2 values in BoolOp, but got " + str(len(expr.values)))
   476         sub_exprs = [build_expr(ctx, sub_expr) 
for sub_expr 
in expr.values]
   478         op_token = ExprBuilder.boolop_map.get(op)
   480             err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start)
   481             raise NotSupportedError(err_range, 
"unsupported boolean operator: " + op.__name__)
   483         for rhs 
in sub_exprs[1:]:
   484             lhs = BinOp(op_token, lhs, rhs)
   488     def build_IfExp(ctx, expr):
   489         return TernaryIf(build_expr(ctx, expr.test),
   490                          build_expr(ctx, expr.body),
   491                          build_expr(ctx, expr.orelse))
   494     def build_Compare(ctx, expr):
   495         operands = [build_expr(ctx, e) 
for e 
in [expr.left] + list(expr.comparators)]
   497         for lhs, op_, rhs 
in zip(operands, expr.ops, operands[1:]):
   499             op_token = ExprBuilder.cmpop_map.get(op)
   501                 err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
   502                 raise NotSupportedError(err_range, 
"unsupported comparison operator: " + op.__name__)
   503             cmp_expr = BinOp(op_token, lhs, rhs)
   507                 result = BinOp(
'and', result, cmp_expr)
   511     def build_Subscript(ctx, expr):
   512         def build_SliceExpr(ctx, base, slice_expr):
   513             lower = build_expr(ctx, slice_expr.lower) 
if slice_expr.lower 
is not None else None   514             upper = build_expr(ctx, slice_expr.upper) 
if slice_expr.upper 
is not None else None   515             if slice_expr.step 
is not None:
   516                 step = build_expr(ctx, slice_expr.step)
   517                 raise NotSupportedError(step.range(), 
"slices with ranges are not supported yet")
   518             return SliceExpr(base.range(), lower, upper)
   520         def build_Index(ctx, base, index_expr):
   521             if isinstance(index_expr.value, ast.Tuple) 
or \
   522                     isinstance(index_expr.value, ast.List):
   524                                         "slicing multiple dimensions with "   525                                         "sequences not supported yet")
   526             return build_expr(ctx, index_expr.value)
   528         def build_ExtSlice(ctx, base, extslice):
   530             for expr 
in extslice.dims:
   531                 sub_type = type(expr)
   532                 if sub_type 
is ast.Index:
   533                     sub_exprs.append(build_Index(ctx, base, expr))
   534                 elif sub_type 
is ast.Slice:
   535                     sub_exprs.append(build_SliceExpr(ctx, base, expr))
   538                                             "slicing multiple dimensions with "   539                                             "{} not supported".format(sub_type))
   542         base = build_expr(ctx, expr.value)
   543         sub_type = type(expr.slice)
   544         if sub_type 
is ast.Index:
   545             if isinstance(expr.slice.value, ast.Tuple) 
or isinstance(expr.slice.value, ast.List):
   547                 for index_expr 
in expr.slice.value.elts:
   548                     indices.append(build_expr(ctx, index_expr))
   549                 return Subscript(base, indices)
   551                 return Subscript(base, [build_expr(ctx, expr.slice.value)])
   552         elif sub_type 
is ast.Slice:
   553             return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
   554         elif sub_type 
is ast.ExtSlice:
   555             return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
   560     def build_List(ctx, expr):
   561         return ListLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
   562                            [build_expr(ctx, e) 
for e 
in expr.elts])
   565     def build_Tuple(ctx, expr):
   566         return TupleLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
   567                             [build_expr(ctx, e) 
for e 
in expr.elts])
   570     def build_Dict(ctx, expr):
   571         return DictLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
   572                            [build_expr(ctx, e) 
for e 
in expr.keys], [build_expr(ctx, e) 
for e 
in expr.values])
   575     def build_Num(ctx, expr):
   577         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
   578         return Const(r, value)
   581     def build_Str(ctx, expr):
   583         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
   584         return StringLiteral(r, value)
   587     def build_Starred(ctx, expr):
   588         r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
   589         return Starred(r, build_expr(ctx, expr.value))
   595 def find_after(ctx, pos, substr, offsets=(0, 0)):
   596     new_pos = pos + ctx.source[pos:].index(substr)
   597     return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1])
   600 def find_before(ctx, pos, substr, offsets=(0, 0)):
   601     new_pos = ctx.source[:pos].rindex(substr)
   602     return ctx.make_raw_range(new_pos + offsets[0], new_pos + len(substr) + offsets[1])
 
def get_type_line(source)