Caffe2 - Python API
A deep learning, cross-platform ML framework
frontend.py
1 import __future__
2 import torch
3 import sys
4 import ast
5 import inspect
6 import string
7 from textwrap import dedent
8 from functools import partial
9 from collections import namedtuple
10 from torch._six import PY2
11 from torch._C._jit_tree_views import *
12 
# Variable names in scripted functions may not start with this prefix
# (enforced in ExprBuilder.build_Name).
_reserved_prefix = '__jit'
# Names that is_reserved_name always reports as reserved.
_reserved_names = {'print'}
# Characters that may appear in an ASCII Python identifier. '_' is included
# because identifiers such as `foo_bar` are legal and common.
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_')


def is_reserved_name(name):
    """Return True if `name` starts with the reserved prefix or is a reserved name."""
    return name.startswith(_reserved_prefix) or name in _reserved_names
20 
21 
# Lookup tables used by UnsupportedNodeError, one row per AST node type:
#   (node type, plural name used in error messages,
#    keyword that starts the construct, which sizes the error range)
_node_table = [
    (ast.FunctionDef, "function definitions", "def"),
    (ast.For, "for loops", "for"),
    (ast.Delete, "del statements", "del"),
    (ast.ClassDef, "class definitions", "class"),
    (ast.With, "with statements", "with"),
    (ast.Raise, "raise statements", "raise"),
    (ast.Assert, "assertions", "assert"),
    (ast.Import, "import statements", "import"),
    (ast.ImportFrom, "import statements", "from"),
    (ast.Global, "global variables", "global"),
    (ast.Break, "break statements", "break"),
    (ast.Continue, "continue statements", "continue"),
]
pretty_node_names = {node: pretty for node, pretty, _ in _node_table}
node_start_tokens = {node: token for node, _, token in _node_table}
del _node_table
51 
# Add the statement node types that only exist on the running Python version.
# Rows are (node type, plural name for messages, starting keyword).
if PY2:
    _version_specific_nodes = [
        (ast.Print, "print statements", "print"),
        (ast.TryExcept, "try blocks", "try"),
        (ast.TryFinally, "try blocks", "try"),
        (ast.Exec, "exec statements", "exec"),
    ]
else:
    _version_specific_nodes = [
        (ast.AsyncFunctionDef, "async function definitions", "async def"),
        (ast.AsyncFor, "async for loops", "async for"),
        (ast.AsyncWith, "async with statements", "async with"),
        (ast.Try, "try blocks", "try"),
        (ast.Nonlocal, "nonlocal variables", "nonlocal"),
    ]
pretty_node_names.update((node, pretty) for node, pretty, _ in _version_specific_nodes)
node_start_tokens.update((node, token) for node, _, token in _version_specific_nodes)
del _version_specific_nodes

if sys.version_info >= (3, 6):
    # NB: no specific token for AnnAssign, so UnsupportedNodeError falls back
    # to a one-character error range for it.
    pretty_node_names[ast.AnnAssign] = "annotated assignments"
88 
89 
class FrontendError(Exception):
    """Base error for problems found while translating Python source.

    Args:
        source_range: location of the problem, or None when unknown. When
            present, its ``highlight()`` text is appended to the message.
        msg: human-readable description of the problem.
    """

    def __init__(self, source_range, msg):
        self.source_range = source_range
        self.msg = msg

    def __str__(self):
        if self.source_range is None:
            return self.msg
        location = '\n' + self.source_range.highlight()
        return self.msg + location
100 
101 
103  pass
104 
105 
class UnsupportedNodeError(NotSupportedError):
    """Raised when a builder meets an AST node type it has no handler for.

    The error range starts at the node and spans the keyword from
    ``node_start_tokens`` (one character when no keyword is known). The
    message uses the name from ``pretty_node_names``, falling back to the
    AST class name.
    """

    def __init__(self, ctx, offending_node):
        node_type = type(offending_node)
        # If we don't have a specific token, we default to length of 1
        range_len = len(node_start_tokens.get(node_type, ' '))
        source_range = ctx.make_range(offending_node.lineno,
                                      offending_node.col_offset,
                                      offending_node.col_offset + range_len)
        feature_name = pretty_node_names.get(node_type, node_type.__name__)
        msg = "{} aren't supported".format(feature_name)
        # Name this class (not its base) in super() so that MRO lookup starts
        # right after UnsupportedNodeError; naming NotSupportedError would
        # skip any __init__ that NotSupportedError defines.
        super(UnsupportedNodeError, self).__init__(source_range, msg)
117 
118 
120  pass
121 
122 
def build_stmts(ctx, stmts):
    """Translate each statement node in order, dropping falsy results.

    StmtBuilder.build_Expr returns None for a bare string-literal statement
    (a docstring), and such results are filtered out here.
    """
    translated = (build_stmt(ctx, node) for node in stmts)
    return [tree for tree in translated if tree]
126 
127 
def _uses_true_division(fn):
    """Return whether `/` inside `fn` has Python 3 true-division semantics.

    On Python 3 this is always True. On Python 2, bound methods are first
    unwrapped to their function, whose module globals are then checked for
    the name `division` bound to ``__future__.division`` (what
    ``from __future__ import division`` produces).

    Raises:
        RuntimeError: on Python 2, if `fn` is neither a function nor a method.
    """
    if not PY2:
        return True
    while inspect.ismethod(fn):
        fn = fn.__func__
    if not inspect.isfunction(fn):
        raise RuntimeError(
            '_uses_true_division: expected function or method, got {}'.format(type(fn)))
    return fn.__globals__.get('division') is __future__.division
138 
139 
def get_jit_class_def(cls, self_name=None):
    """Build a ClassDef tree view for `cls`.

    Each member of `cls` that is a function or method is compiled on its
    own with get_jit_def, passing the class name so that a `self` parameter
    is typed as the class. The `self_name` argument is accepted but unused.
    """
    def _is_function_like(member):
        return inspect.ismethod(member) or inspect.isfunction(member)

    # Get defs for each method independently
    method_defs = []
    for _, method in inspect.getmembers(cls, predicate=_is_function_like):
        method_defs.append(get_jit_def(method, self_name=cls.__name__))

    class_source = dedent(inspect.getsource(cls))
    class_module = ast.parse(class_source)
    class_ctx = SourceContext(class_source, False)
    return build_class_def(class_ctx, class_module.body[0], method_defs)
151 
152 
def get_jit_def(fn, self_name=None):
    """Parse the source of `fn` and build a Def tree view for it.

    Args:
        fn: a Python function or method whose source inspect can retrieve.
        self_name: forwarded to build_def; when given, `fn` is treated as a
            method and a `self` parameter is typed with this name.

    Raises:
        RuntimeError: if the dedented source is not exactly one top-level
            function definition.
    """
    source = dedent(inspect.getsource(fn))
    top_level = ast.parse(source).body
    is_single_function = len(top_level) == 1 and isinstance(top_level[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError("expected a single top-level function")
    type_line = torch.jit.annotations.get_type_line(source)
    ctx = SourceContext(source, _uses_true_division(fn))
    return build_def(ctx, top_level[0], type_line, self_name)
161 
162 
class SourceContext(SourceRangeFactory):
    """A SourceRangeFactory that also records the function's division semantics.

    Args:
        source: the source text that ranges are created against.
        uses_true_division: whether `/` follows Python 3 true division;
            ExprBuilder.build_BinOp rejects `/` when this is False.
    """

    def __init__(self, source, uses_true_division=True):
        super(SourceContext, self).__init__(source)
        self.uses_true_division = uses_true_division
169 
170 
class Builder(object):
    """Callable that dispatches an AST node to ``self.build_<NodeClassName>``.

    Subclasses add one ``build_X(ctx, node)`` per supported node class X.
    A node with no matching handler raises UnsupportedNodeError.
    """

    def __call__(self, ctx, node):
        handler = getattr(self, 'build_' + node.__class__.__name__, None)
        if handler is not None:
            return handler(ctx, node)
        raise UnsupportedNodeError(ctx, node)
177 
178 
def build_class_def(ctx, py_def, methods):
    """Build a ClassDef from an ast.ClassDef node and already-built method Defs.

    The class name's range covers the ``class`` keyword at the node's position.
    """
    keyword_start = py_def.col_offset
    name_range = ctx.make_range(py_def.lineno, keyword_start, keyword_start + len("class"))
    return ClassDef(Ident(name_range, py_def.name), methods)
183 
184 
def build_def(ctx, py_def, type_line, self_name=None):
    """Build a Def tree view from an ast.FunctionDef.

    Args:
        ctx: SourceContext for the function's source.
        py_def: the function definition node.
        type_line: a type-comment line for the function, or None.
        self_name: when not None, the function is treated as a method and a
            `self` parameter is annotated with this name.
    """
    def_range = ctx.make_range(py_def.lineno, py_def.col_offset,
                               py_def.col_offset + len("def"))
    params = build_param_list(ctx, py_def.args, self_name)
    returns = getattr(py_def, 'returns', None)
    return_type = build_expr(ctx, returns) if returns is not None else None
    decl = Decl(def_range, params, return_type)
    if type_line is not None:
        # Merge the types from the type comment into the declaration.
        comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, comment_decl, self_name is not None)
    def_name = Ident(def_range, py_def.name)
    return Def(def_name, decl, build_stmts(ctx, py_def.body))
201 
202 
_vararg_kwarg_err = ("Compiled functions can't take variable number of arguments "
                     "or use keyword-only arguments with defaults")


def build_param_list(ctx, py_args, self_name):
    """Translate an ast.arguments node into a list of Param tree views.

    Positional parameters come first. On Python 3 they are followed by the
    keyword-only parameters, which are marked kwarg_only.

    Raises:
        ValueError: if the function takes *args or **kwargs, or (Python 3)
            gives a default value to any keyword-only parameter.
    """
    if py_args.vararg is not None or py_args.kwarg is not None:
        raise ValueError(_vararg_kwarg_err)
    # kw_defaults has one entry per keyword-only parameter, and that entry is
    # None when the parameter has no default. Only reject real defaults; a
    # plain truthiness test would reject every keyword-only parameter.
    if not PY2 and any(default is not None for default in py_args.kw_defaults):
        raise ValueError(_vararg_kwarg_err)
    result = [build_param(ctx, arg, self_name, False) for arg in py_args.args]
    if not PY2:
        result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs]
    return result
216 
217 
def build_param(ctx, py_arg, self_name, kwarg_only):
    """Translate one parameter node into a Param.

    The annotation is, in order of preference: the explicit annotation; the
    name `self_name` when the parameter is called `self` and `self_name` is
    given; otherwise ``Tensor``.
    """
    # NB: In Python3 py_arg is an ast.arg (str `arg`, optional `annotation`)
    # In Python2 py_arg is a Name (Expr subclass)
    if PY2:
        name = py_arg.id
    else:
        name = py_arg.arg
    start = py_arg.col_offset
    r = ctx.make_range(py_arg.lineno, start, start + len(name))
    explicit = getattr(py_arg, 'annotation', None)
    if explicit is not None:
        annotation_expr = build_expr(ctx, explicit)
    else:
        is_typed_self = self_name is not None and name == 'self'
        annotation_expr = Var(Ident(r, self_name if is_typed_self else 'Tensor'))
    return Param(annotation_expr, Ident(r, name), kwarg_only)
230 
231 
def get_default_args(fn):
    """Return a dict mapping each parameter of `fn` that has a default to that default.

    Parameters without a default are omitted.
    """
    if not PY2:
        params = inspect.signature(fn).parameters
        return {name: param.default for name, param in params.items()
                if param.default is not inspect.Parameter.empty}
    argspec = inspect.getargspec(fn)
    defaults = argspec.defaults
    if defaults is None:
        return {}
    return dict(zip(argspec.args[-len(defaults):], defaults))
246 
247 
249  augassign_map = {
250  ast.Add: '+',
251  ast.Sub: '-',
252  ast.Mult: '*',
253  ast.Div: '/',
254  }
255 
256  @staticmethod
257  def build_Expr(ctx, stmt):
258  value = stmt.value
259  if value.__class__.__name__ == 'Str':
260  # If a statement is a string literal expression,
261  # then it is a docstring. Just ignore it.
262  return None
263  else:
264  return ExprStmt(build_expr(ctx, value))
265 
266  @staticmethod
267  def build_Assign(ctx, stmt):
268  rhs = build_expr(ctx, stmt.value)
269  if len(stmt.targets) > 1:
270  start_point = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + 1)
271  raise NotSupportedError(ctx.make_raw_range(start_point.start, rhs.range().end),
272  "Performing multiple assignments in a single line isn't supported")
273  lhs = build_expr(ctx, stmt.targets[0])
274  return Assign(lhs, rhs)
275 
276  @staticmethod
277  def build_Return(ctx, stmt):
278  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
279  return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))
280 
281  @staticmethod
282  def build_Raise(ctx, stmt):
283  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
284  if PY2:
285  if stmt.tback:
286  raise NotSupportedError(r, "tracebacks with exceptions is not supported")
287  # TODO use stmt.type once instantiating exceptions is supported
288  expr = build_expr(ctx, stmt.inst) if stmt.inst else None
289  else:
290  expr = build_expr(ctx, stmt.exc)
291  return Raise(r, expr)
292 
293  @staticmethod
294  def build_Assert(ctx, stmt):
295  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert"))
296  test = build_expr(ctx, stmt.test)
297  msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
298  return Assert(r, test, msg)
299 
300  @staticmethod
301  def build_AugAssign(ctx, stmt):
302  lhs = build_expr(ctx, stmt.target)
303  rhs = build_expr(ctx, stmt.value)
304  op = type(stmt.op)
305  if op in StmtBuilder.augassign_map:
306  op_token = StmtBuilder.augassign_map[op]
307  else:
308  raise NotSupportedError(
309  find_before(ctx, rhs.range().start, '=', offsets=(-1, 0)),
310  "unsupported kind of augumented assignment: " + op.__name__)
311  return AugAssign(lhs, op_token, rhs)
312 
313  @staticmethod
314  def build_While(ctx, stmt):
315  if stmt.orelse:
316  # TODO: try to recover the location of else:? Python doesn't give us useful
317  # annotations in this case
318  raise NotSupportedError(None, "else branches of while loops aren't supported")
319  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
320  return While(r, build_expr(ctx, stmt.test),
321  build_stmts(ctx, stmt.body))
322 
323  @staticmethod
324  def build_For(ctx, stmt):
325  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
326  return For(
327  r, [build_expr(ctx, stmt.target)],
328  [build_expr(ctx, stmt.iter)], build_stmts(ctx, stmt.body))
329 
330  @staticmethod
331  def build_If(ctx, stmt):
332  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
333  return If(r, build_expr(ctx, stmt.test),
334  build_stmts(ctx, stmt.body),
335  build_stmts(ctx, stmt.orelse))
336 
337  @staticmethod
338  def build_Print(ctx, stmt):
339  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
340  if stmt.dest:
341  raise NotSupportedError(r, "print statements with non-default destinations aren't supported")
342  args = [build_expr(ctx, val) for val in stmt.values]
343  return ExprStmt(Apply(Var(Ident(r, "print")), args, []))
344 
345  @staticmethod
346  def build_Pass(ctx, stmt):
347  r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
348  return Pass(r)
349 
350 
352  binop_map = {
353  ast.Add: '+',
354  ast.Sub: '-',
355  ast.Mult: '*',
356  ast.Div: '/',
357  ast.Pow: '**',
358  ast.Mod: '%',
359  ast.FloorDiv: '//',
360  ast.BitAnd: '&',
361  ast.BitXor: '^',
362  ast.BitOr: '|',
363  }
364 
365  if not PY2:
366  binop_map[ast.MatMult] = '@'
367 
368  unop_map = {
369  ast.Not: 'not',
370  ast.USub: '-',
371  }
372 
373  boolop_map = {
374  ast.And: 'and',
375  ast.Or: 'or',
376  }
377 
378  cmpop_map = {
379  ast.Eq: '==',
380  ast.NotEq: '!=',
381  ast.LtE: '<=',
382  ast.Lt: '<',
383  ast.GtE: '>=',
384  ast.Gt: '>',
385  ast.Is: 'is',
386  ast.IsNot: 'is not',
387  }
388 
389  @staticmethod
390  def build_Attribute(ctx, expr):
391  # NB: the only attributes we support are for getting methods
392  value = build_expr(ctx, expr.value)
393  # <sigh> name is just a string, so it's not annotated in any way.
394  source = ctx.source
395  pos = find_after(ctx, value.range().end, '.').end # Start with the dot
396  while source[pos] in string.whitespace: # Skip whitespace
397  pos += 1
398  start_pos = pos
399  while source[pos] in _identifier_chars: # Find the identifier itself
400  pos += 1
401  name_range = ctx.make_raw_range(start_pos, pos)
402  return Select(value, Ident(name_range, expr.attr))
403 
404  @staticmethod
405  def build_Call(ctx, expr):
406  func = build_expr(ctx, expr.func)
407  args = [build_expr(ctx, py_arg) for py_arg in expr.args]
408  if hasattr(expr, 'starargs') and expr.starargs:
409  stararg_expr = build_expr(ctx, expr.starargs)
410  args += [Starred(stararg_expr.range(), stararg_expr)]
411  kwargs = []
412  for kw in expr.keywords:
413  kw_expr = build_expr(ctx, kw.value)
414  # XXX: we could do a better job at figuring out the range for the name here
415  kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
416  return Apply(func, args, kwargs)
417 
418  @staticmethod
419  def build_Name(ctx, expr):
420  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
421  if expr.id.startswith(_reserved_prefix):
422  raise NotSupportedError(r, "names of variables used in JIT-ed functions "
423  "can't start with " + _reserved_prefix)
424  if expr.id == "True":
425  return TrueLiteral(r)
426  elif expr.id == "False":
427  return FalseLiteral(r)
428  elif expr.id == "None":
429  return NoneLiteral(r)
430  return Var(Ident(r, expr.id))
431 
432  @staticmethod
433  def build_NameConstant(ctx, expr):
434  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value)))
435  if expr.value is True:
436  return TrueLiteral(r)
437  elif expr.value is False:
438  return FalseLiteral(r)
439  elif expr.value is None:
440  return NoneLiteral(r)
441  else:
442  raise ValueError("Name constant value unsupported: " + str(expr.value))
443 
444  @staticmethod
445  def build_BinOp(ctx, expr):
446  lhs = build_expr(ctx, expr.left)
447  rhs = build_expr(ctx, expr.right)
448  op = type(expr.op)
449 
450  if op == ast.Div and not ctx.uses_true_division:
451  raise RuntimeError('Division of ints in JIT script uses Python 3 true '
452  'division semantics. Please put `from __future__ '
453  'import division` at the top of your file')
454 
455  op_token = ExprBuilder.binop_map.get(op)
456  if op_token is None:
457  err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
458  raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
459  return BinOp(op_token, lhs, rhs)
460 
461  @staticmethod
462  def build_UnaryOp(ctx, expr):
463  sub_expr = build_expr(ctx, expr.operand)
464  op = type(expr.op)
465  op_token = ExprBuilder.unop_map.get(op)
466  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))
467  if op_token is None:
468  err_range = ctx.make_raw_range(r.start, sub_expr.range().end)
469  raise NotSupportedError(err_range, "unsupported unary operator: " + op.__name__)
470  return UnaryOp(r, op_token, sub_expr)
471 
472  @staticmethod
473  def build_BoolOp(ctx, expr):
474  if len(expr.values) < 2:
475  raise AssertionError("expected at least 2 values in BoolOp, but got " + str(len(expr.values)))
476  sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
477  op = type(expr.op)
478  op_token = ExprBuilder.boolop_map.get(op)
479  if op_token is None:
480  err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start)
481  raise NotSupportedError(err_range, "unsupported boolean operator: " + op.__name__)
482  lhs = sub_exprs[0]
483  for rhs in sub_exprs[1:]:
484  lhs = BinOp(op_token, lhs, rhs)
485  return lhs
486 
487  @staticmethod
488  def build_IfExp(ctx, expr):
489  return TernaryIf(build_expr(ctx, expr.test),
490  build_expr(ctx, expr.body),
491  build_expr(ctx, expr.orelse))
492 
493  @staticmethod
494  def build_Compare(ctx, expr):
495  operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
496  result = None
497  for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
498  op = type(op_)
499  op_token = ExprBuilder.cmpop_map.get(op)
500  if op_token is None:
501  err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
502  raise NotSupportedError(err_range, "unsupported comparison operator: " + op.__name__)
503  cmp_expr = BinOp(op_token, lhs, rhs)
504  if result is None:
505  result = cmp_expr
506  else:
507  result = BinOp('and', result, cmp_expr)
508  return result
509 
510  @staticmethod
511  def build_Subscript(ctx, expr):
512  def build_SliceExpr(ctx, base, slice_expr):
513  lower = build_expr(ctx, slice_expr.lower) if slice_expr.lower is not None else None
514  upper = build_expr(ctx, slice_expr.upper) if slice_expr.upper is not None else None
515  if slice_expr.step is not None:
516  step = build_expr(ctx, slice_expr.step)
517  raise NotSupportedError(step.range(), "slices with ranges are not supported yet")
518  return SliceExpr(base.range(), lower, upper)
519 
520  def build_Index(ctx, base, index_expr):
521  if isinstance(index_expr.value, ast.Tuple) or \
522  isinstance(index_expr.value, ast.List):
523  raise NotSupportedError(base.range(),
524  "slicing multiple dimensions with "
525  "sequences not supported yet")
526  return build_expr(ctx, index_expr.value)
527 
528  def build_ExtSlice(ctx, base, extslice):
529  sub_exprs = []
530  for expr in extslice.dims:
531  sub_type = type(expr)
532  if sub_type is ast.Index:
533  sub_exprs.append(build_Index(ctx, base, expr))
534  elif sub_type is ast.Slice:
535  sub_exprs.append(build_SliceExpr(ctx, base, expr))
536  else:
537  raise NotSupportedError(base.range(),
538  "slicing multiple dimensions with "
539  "{} not supported".format(sub_type))
540  return sub_exprs
541 
542  base = build_expr(ctx, expr.value)
543  sub_type = type(expr.slice)
544  if sub_type is ast.Index:
545  if isinstance(expr.slice.value, ast.Tuple) or isinstance(expr.slice.value, ast.List):
546  indices = []
547  for index_expr in expr.slice.value.elts:
548  indices.append(build_expr(ctx, index_expr))
549  return Subscript(base, indices)
550  else:
551  return Subscript(base, [build_expr(ctx, expr.slice.value)])
552  elif sub_type is ast.Slice:
553  return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
554  elif sub_type is ast.ExtSlice:
555  return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
556  else: # Ellipsis (can only happen in Python 2)
557  raise NotSupportedError(base.range(), "ellipsis is not supported")
558 
559  @staticmethod
560  def build_List(ctx, expr):
561  return ListLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
562  [build_expr(ctx, e) for e in expr.elts])
563 
564  @staticmethod
565  def build_Tuple(ctx, expr):
566  return TupleLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
567  [build_expr(ctx, e) for e in expr.elts])
568 
569  @staticmethod
570  def build_Dict(ctx, expr):
571  return DictLiteral(ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1),
572  [build_expr(ctx, e) for e in expr.keys], [build_expr(ctx, e) for e in expr.values])
573 
574  @staticmethod
575  def build_Num(ctx, expr):
576  value = str(expr.n)
577  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
578  return Const(r, value)
579 
580  @staticmethod
581  def build_Str(ctx, expr):
582  value = str(expr.s)
583  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
584  return StringLiteral(r, value)
585 
586  @staticmethod
587  def build_Starred(ctx, expr):
588  r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
589  return Starred(r, build_expr(ctx, expr.value))
590 
# Module-level translators: build_expr(ctx, node) translates an expression
# node and build_stmt(ctx, node) translates a statement node.
build_expr = ExprBuilder()
build_stmt = StmtBuilder()
593 
594 
def find_after(ctx, pos, substr, offsets=(0, 0)):
    """Return a raw range around the first `substr` at or after index `pos`.

    The range is [match_start + offsets[0], match_end + offsets[1]).

    Raises:
        ValueError: if `substr` does not occur at or after `pos`.
    """
    match_start = pos + ctx.source[pos:].index(substr)
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
598 
599 
def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Return a raw range around the last `substr` lying entirely before index `pos`.

    The range is [match_start + offsets[0], match_end + offsets[1]).

    Raises:
        ValueError: if `substr` does not occur before `pos`.
    """
    preceding = ctx.source[:pos]
    match_start = preceding.rindex(substr)
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
def get_type_line(source)
Definition: annotations.py:112