Caffe2 - Python API
A deep learning, cross platform ML framework
common_with_cwrap.py
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives here in one place, but it is currently a copy of the cwrap code.
3 
4 from copy import deepcopy
5 from itertools import product
6 
7 
def parse_arguments(args):
    """Normalize a list of argument declarations into dicts.

    Each element of ``args`` may be:

    * a string ``"<type> <name>"``, which becomes
      ``{'type': <type>, 'name': <name>}``; or
    * a dict. If it has an ``'arg'`` key holding a ``"<type> <name>"``
      string, that key is removed and split into ``'type'`` and ``'name'``.
      Such dicts are modified in place, and the same object is placed in
      the result.

    Strings are split on the first space only, so everything after it
    becomes the name.

    Args:
        args: iterable of str or dict argument declarations.

    Returns:
        list of dict, one per input element, in input order.

    Raises:
        TypeError: if an element is neither a str nor a dict.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            arg_type, _, name = arg.partition(' ')
            new_args.append({'type': arg_type, 'name': name})
        elif isinstance(arg, dict):
            if 'arg' in arg:
                arg['type'], _, arg['name'] = arg.pop('arg').partition(' ')
            new_args.append(arg)
        else:
            # Fail loudly: an ``assert`` would vanish under ``python -O``
            # and the malformed entry would be silently dropped.
            raise TypeError(
                'argument declaration must be a str or dict, got %r' % (arg,))
    return new_args
23 
24 
def set_declaration_defaults(declaration):
    """Fill in missing declaration fields and normalize its options, in place.

    Steps, in order:

    * adds ``schema_string`` (''), ``matches_jit_signature`` (False),
      ``arguments`` ([]), ``return`` ('void'), ``cname`` (the value of
      ``name``), ``backends`` (['CPU', 'CUDA']) and ``api_name`` (the value
      of ``name``) when they are absent; ``name`` is only read when one of
      ``cname``/``api_name`` must be filled;
    * if there is no ``options`` key, moves ``arguments`` into a single
      option so callers can always iterate ``options``;
    * passes every option's ``arguments`` through ``parse_arguments``;
    * copies every top-level key except ``options`` into each option that
      does not already define it (values are shared, not copied).

    Returns None.
    """
    # Factories keep defaults lazy: ``declaration['name']`` is only looked
    # up when the alias actually needs filling.
    default_factories = (
        ('schema_string', lambda: ''),
        ('matches_jit_signature', lambda: False),
        ('arguments', list),
        ('return', lambda: 'void'),
        ('cname', lambda: declaration['name']),
        ('backends', lambda: ['CPU', 'CUDA']),
        ('api_name', lambda: declaration['name']),
    )
    for field, make_default in default_factories:
        if field not in declaration:
            declaration[field] = make_default()

    # Simulate multiple dispatch, even if it's not necessary
    if 'options' not in declaration:
        declaration['options'] = [{'arguments': declaration.pop('arguments')}]

    variants = declaration['options']
    # Parse arguments (some of them can be strings)
    for variant in variants:
        variant['arguments'] = parse_arguments(variant['arguments'])

    # Propagate defaults from declaration to options.
    # TODO(zach): why does cwrap not propagate 'name'? I need it
    # propagated for ATen
    inherited = [(field, value) for field, value in declaration.items()
                 if field != 'options']
    for variant in variants:
        for field, value in inherited:
            variant.setdefault(field, value)
52 
# TODO(zach): added an option to disable keyword-argument handling for C++,
# which cannot support keyword arguments.
55 
56 
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Keep only options whose call signature is distinguishable from earlier ones.

    Options are visited in order. For each one, a signature with every
    argument positional is tried first. If that signature was already
    seen and ``allow_kwarg`` is true, the last 1, 2, ... arguments are
    treated as keyword-only in turn until an unseen signature appears.
    The option is kept with the first unseen signature. Any arguments
    that had to become keyword-only are marked ``'kwarg_only': True``
    in place. Options for which no unseen signature exists are dropped.

    The positional part of a signature joins argument types (remapped
    through ``type_to_signature``) with '#'. It skips arguments whose
    ``'ignore_check'`` is truthy or whose type is ``'CONSTANT'``, and,
    when ``remove_self`` is true, arguments named ``'self'``. The
    keyword part joins ``name#type`` pairs and skips only the
    ignore_check/CONSTANT arguments.

    Args:
        options: list of option dicts, each with an ``'arguments'`` list.
        allow_kwarg: whether trailing arguments may become keyword-only.
        type_to_signature: mapping used to canonicalize argument types.
        remove_self: whether to ignore a positional argument named 'self'.

    Returns:
        list of the kept option dicts (same objects), in input order.
    """
    def is_skipped(arg):
        # Arguments that never take part in a signature.
        return arg.get('ignore_check') or arg['type'] == 'CONSTANT'

    def is_skipped_positional(arg):
        return is_skipped(arg) or (remove_self and arg['name'] == 'self')

    def build_signature(option, n_kwarg_only):
        arguments = option['arguments']
        if not n_kwarg_only:
            return '#'.join(
                type_to_signature.get(a['type'], a['type'])
                for a in arguments if not is_skipped_positional(a))
        positional = '#'.join(
            type_to_signature.get(a['type'], a['type'])
            for a in arguments[:-n_kwarg_only] if not is_skipped_positional(a))
        keyword = '#'.join(
            a['name'] + '#' + a['type']
            for a in arguments[-n_kwarg_only:] if not is_skipped(a))
        return positional + "#-#" + keyword

    seen = set()
    kept = []
    for option in options:
        arguments = option['arguments']
        # Without keyword support only the all-positional signature is tried.
        max_kwarg_only = len(arguments) if allow_kwarg else 0
        for n_kwarg_only in range(max_kwarg_only + 1):
            sig = build_signature(option, n_kwarg_only)
            if sig in seen:
                continue
            seen.add(sig)
            if n_kwarg_only:
                for a in arguments[-n_kwarg_only:]:
                    a['kwarg_only'] = True
            kept.append(option)
            break
    return kept
95 
96 
def enumerate_options_due_to_default(declaration,
                                     allow_kwarg=True, type_to_signature=None,
                                     remove_self=True):
    """Expand every option into variants that pass or omit defaulted arguments.

    For each option in ``declaration['options']``, let k be the number of
    arguments that have a ``'default'`` key. All 2**k include/omit
    combinations are generated, each as a deep copy of the option. In
    every copy:

    * a ``'default'`` of None (PyYAML's reading of NULL) becomes 'NULL';
    * an omitted argument gets ``'declared_type'`` set to its old type,
      ``'type'`` set to 'CONSTANT' and ``'ignore_check'`` set to True;
    * ``'has_full_argument_list'`` is True only for the copy that passes
      every defaulted argument.

    The generated variants are de-duplicated by ``filter_unique_options``,
    and the result replaces ``declaration['options']`` (in place; returns
    None).

    Args:
        declaration: dict holding an ``'options'`` list of option dicts,
            each with an ``'arguments'`` list of argument dicts.
        allow_kwarg: forwarded to filter_unique_options; allows trailing
            arguments to become keyword-only to disambiguate signatures.
        type_to_signature: mapping from argument type to the type used when
            comparing signatures. None (the default) means no remapping.
        remove_self: forwarded to filter_unique_options; ignore a
            positional argument named 'self' when comparing signatures.
    """
    if type_to_signature is None:
        # Must be a mapping: filter_unique_options calls ``.get`` on it, so
        # a list default (and a shared mutable one at that) cannot work.
        type_to_signature = {}

    # Note: nullable tensor arguments (e.g. a THTensor* defaulting to
    # nullptr) are not special-cased; they are enumerated like any other
    # defaulted argument.

    # TODO(zach): in cwrap this is shared among all declarations
    # but seems to assume that all declarations will have the same
    new_options = []
    for option in declaration['options']:
        optional_args = [i for i, arg in enumerate(option['arguments'])
                         if 'default' in arg]
        for permutation in product((True, False), repeat=len(optional_args)):
            option_copy = deepcopy(option)
            # True only for the variant that passes every defaulted argument.
            option_copy['has_full_argument_list'] = all(permutation)
            for i, keep in zip(optional_args, permutation):
                arg = option_copy['arguments'][i]
                # PyYAML interprets NULL as None...
                if arg['default'] is None:
                    arg['default'] = 'NULL'
                if not keep:
                    arg['declared_type'] = arg['type']
                    arg['type'] = 'CONSTANT'
                    arg['ignore_check'] = True
            new_options.append(option_copy)
    declaration['options'] = filter_unique_options(
        new_options, allow_kwarg, type_to_signature, remove_self)
129 
130 
def sort_by_number_of_options(declaration, reverse=True):
    """Sort ``declaration['options']`` in place by their count of checked arguments.

    An argument counts as checked unless its ``'ignore_check'`` value is
    truthy. With the default ``reverse=True``, options with the most
    checked arguments come first. ``list.sort`` is stable, so ties keep
    their current relative order.
    """
    def checked_count(option):
        return len([a for a in option['arguments']
                    if not a.get('ignore_check', False)])

    declaration['options'].sort(reverse=reverse, key=checked_count)
135 
136 
class Function(object):
    """A C function parsed from a header: a name plus its argument list."""

    def __init__(self, name):
        self.name = name
        # Argument objects, in the order they were added.
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg`` (which must be an ``Argument``) to the argument list."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        # Renders as "name(arg1, arg2, ...)".
        inner = ', '.join(a.__repr__() for a in self.arguments)
        return ''.join([self.name, '(', inner, ')'])
149 
150 
class Argument(object):
    """A single C function argument: its type, its name and an optional flag."""

    def __init__(self, _type, name, is_optional):
        # _type: the C type string; name: the parameter name;
        # is_optional: whether the argument was marked optional.
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self):
        # Renders as "<type> <name>".
        return ' '.join((self.type, self.name))
160 
161 
def _parse_header_function_name(code):
    """Return the THNN function name declared by ``code``, or None if it is not a function line."""
    for prefix in ('TH_API void THNN_', 'THC_API void THNN_'):
        if code.startswith(prefix):
            # Slice off the prefix. str.lstrip() would strip a *set of
            # characters* and also eat leading letters of the name
            # (e.g. "THNN_Abs_updateOutput(" -> "bs_updateOutput").
            raw = code[len(prefix):]
            if raw[0] == '(' and raw[-2] == ')':
                # Generic form: "(Name)("
                return raw[1:-2]
            # Plain form: "Name("
            return raw[:-1]
    return None


def parse_header(path):
    """Parse a THNN/THCUNN-style C header into a list of ``Function`` objects.

    Recognized lines:

    * function lines starting with ``TH_API void THNN_`` or
      ``THC_API void THNN_``, in generic (``THNN_(Name)(``) or plain
      (``THNN_Name(``) form;
    * argument lines ``<type> <name>``. A ``*`` on the name is moved onto
      the type. An argument is optional when its trailing ``//`` comment
      contains ``[OPTIONAL]``.

    Empty lines and lines starting with ``#`` are ignored. Trailing ')'
    / ';' and ',' characters are stripped, and comma-separated items on
    one line are handled separately.

    Args:
        path: path of the header file to read.

    Returns:
        list of Function, in the order they appear in the header.

    Raises:
        ValueError: if an argument line appears before any function line.
    """
    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # (code, comment) pairs, one per non-empty comma-separated item.
    items = []
    for line in raw_lines:
        # Skip empty lines and preprocessor directives.
        if not line or line.startswith('#'):
            continue
        code, _, comment = line.partition('//')
        # Remove the trailing ')' / ';' and ',' that close argument lists.
        code = code.strip().rstrip(');').rstrip(',')
        comment = comment.strip()
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                items.append((piece, comment))

    functions = []
    for code, comment in items:
        fn_name = _parse_header_function_name(code)
        if fn_name is not None:
            functions.append(Function(fn_name))
            continue
        arg_type, arg_name = code.split()
        if '*' in arg_name:
            # "THTensor *input" -> type "THTensor*", name "input"
            arg_type += '*'
            arg_name = arg_name[1:]
        if not functions:
            raise ValueError('argument %r appears before any function '
                             'declaration in %s' % (code, path))
        functions[-1].add_argument(
            Argument(arg_type, arg_name, '[OPTIONAL]' in comment))
    return functions
Module caffe2.python.layers.split.