2 from copy
import deepcopy
3 from function_wrapper
import TYPE_FORMAL_GENERIC
4 import common_with_cwrap
# Concatenation of the 'floating_point' and 'integral' entries of `type_map`.
all_types = type_map['floating_point'] + type_map['integral']
# Expose the combined list under the 'all' key, next to the other entries.
type_map['all'] = all_types

# Backend names accepted when validating (backend, type) pairs.
all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA']
# Backends used when an option does not carry a 'backends' entry.
default_backends = ['CPU', 'CUDA']
def process_types_and_backends(option):
    """Compute ``option['backend_type_pairs']`` as a sorted list of tuples.

    If ``option`` has no 'backend_type_pairs' entry, the pairs are the cross
    product of its backends and its types:

    * backends come from ``option['backends']`` (default
      ``default_backends``), or from the keys of
      ``option['type_method_definition_dispatch']`` when that is a dict;
    * types come from ``option['types']`` (default ``all_types``).

    Type names that are keys of ``type_map`` are expanded to the types listed
    there.  Some (backend, type) combinations are then removed, see the
    inline comments.  The result is written back to ``option`` in place;
    nothing is returned.

    Raises:
        AssertionError: if a backend is not in ``all_backends`` or a
            non-alias type is not in ``all_types``.
    """
    if 'backend_type_pairs' not in option:
        backends = option.get('backends', default_backends)
        dispatch = option.get('type_method_definition_dispatch')
        if isinstance(dispatch, dict):
            backends = dispatch.keys()
        backends = set(backends)

        types = option.get('types', all_types)

        pairs = [[p, t] for p in backends for t in types]
    else:
        # NOTE(review): this `else:` line was reconstructed; confirm that the
        # explicit pairs are only used when the option already lists them.
        pairs = option['backend_type_pairs']

    # NOTE(review): the header, tuple unpacking, `type_map` guard and the
    # final single-pair return of this helper were reconstructed around the
    # statements that survived; verify against the upstream file.
    def expand(pair):
        """Expand a type alias found in ``type_map`` into concrete pairs."""
        p, t = pair
        assert p in all_backends
        if t in type_map:
            return [(p, tt) for tt in type_map[t]]
        assert t in all_types
        return [(p, t)]

    pairs = set(p for pair in pairs for p in expand(pair))

    # Drop CUDA Half as soon as any argument is a THSTensor*.
    for arg in option.get('arguments', []):
        if arg['type'] == 'THSTensor*':
            pairs.discard(('CUDA', 'Half'))

    # CPU Half is kept only when the option sets a truthy 'cpu_half'.
    if not option.get('cpu_half', False):
        pairs.discard(('CPU', 'Half'))

    # CPU Bool is kept only when the option sets a truthy 'cpu_bool'.
    if not option.get('cpu_bool', False):
        pairs.discard(('CPU', 'Bool'))

    # CUDA Bool is always removed.
    pairs.discard(('CUDA', 'Bool'))

    # Sorted so the stored list has a deterministic order.
    option['backend_type_pairs'] = sorted(pairs)
def exclude(declaration):
    """Return True if ``declaration`` has an 'only_register' key or is named
    'ndimension'; otherwise return False."""
    return ('only_register' in declaration or
            declaration.get('name') == 'ndimension')


def add_variants(option):
    """Set ``option['variants']`` to ``['method']`` unless already present.

    Mutates ``option`` in place and returns None.
    """
    option.setdefault('variants', ['method'])
# handle_outputs_taken_as_arguments(options)
#
# Visible behaviour, per option in `options`:
#   * some arguments get arg['is_nullable'] = True; the surviving predicate
#     checks that the type is 'THIntegerTensor*' or 'THTensor*' and that the
#     'default' entry is None, 'NULL' or 'nullptr';
#   * when any argument has an 'output' key, the option is deep-copied into
#     `allocate_option`, whose arguments get arg['allocate'] = True;
#   * when should_generate_out_variant(option) is truthy, 'method' is removed
#     from option['variants'] (if present), '_out' is appended to
#     option['api_name'] and the option is appended to `new_options`;
#   * `allocate_option` is appended to `new_options`; a further
#     new_options.append(option) follows (presumably the no-output branch).
# should_generate_out_variant requires 'function' in the variants and a mode
# other than 'native', then reports whether api_name does NOT match
# '(^__i|[^_]_$)'.
#
# NOTE(review): this listing has original line numbers fused into the code
# and several lines absent (the `new_options` initialisation, the nullable
# helper's header, the guards before the two `= True` assignments, the
# `else:` and the final return) -- restore them from the upstream file
# before relying on this block.
86 def handle_outputs_taken_as_arguments(options):
90 return (arg[
'type']
in {
'THIntegerTensor*',
'THTensor*'}
and 91 arg.get(
'default',
'')
in {
None,
'NULL',
'nullptr'})
93 def should_generate_out_variant(option):
94 if 'function' in option[
'variants']
and option[
'mode'] !=
'native':
96 return re.search(
'(^__i|[^_]_$)', option[
'api_name'])
is None 99 for option
in options:
100 for arg
in option[
'arguments']:
103 arg[
'is_nullable'] =
True 105 if any(
'output' in arg
for arg
in option[
'arguments']):
106 allocate_option = deepcopy(option)
108 for arg
in allocate_option[
'arguments']:
110 arg[
'allocate'] =
True 115 if should_generate_out_variant(option):
116 if 'method' in option[
'variants']:
117 option[
'variants'].remove(
'method')
118 option[
'api_name'] +=
'_out' 119 new_options.append(option)
121 new_options.append(allocate_option)
123 new_options.append(option)
def sanitize_return(option):
    """Rewrite ``option['return']`` in place as a dict with a 'kind' key.

    * ``'argument N[,M...]'`` becomes
      ``{'kind': 'arguments', 'arguments': [N, M, ...]}``.
    * ``'self'`` becomes ``{'kind': 'arguments', 'arguments': [i]}`` where
      ``i`` is the index of the argument named 'self' (an empty list if
      there is none).
    * Anything else becomes ``{'kind': 'type', 'type': <original value>}``.

    NOTE(review): the three branch headers and the ``break`` were lost from
    the listing this was restored from.  In particular the ``'self'``
    condition is inferred from the branch body -- confirm against the
    upstream file.
    """
    ret = option['return']
    m = re.match(r'argument (\d+(,\d+)*)', ret)
    if m is not None:
        arguments = [int(x) for x in m.group(1).split(',')]
        option['return'] = {'kind': 'arguments', 'arguments': arguments}
    elif ret == 'self':
        option['return'] = {'kind': 'arguments', 'arguments': []}
        for i, x in enumerate(option['arguments']):
            if x['name'] == 'self':
                option['return']['arguments'].append(i)
                break
    else:
        option['return'] = {'kind': 'type', 'type': option['return']}
def set_mode(option):
    """Ensure ``option['mode']`` exists, defaulting to ``'TH'``.

    An existing value is left as it is.  Mutates ``option`` in place and
    returns None.
    """
    option['mode'] = option.get('mode', 'TH')
# discover_zero_dim_tensor_operations(declaration)
#
# Visible behaviour:
#   * the nested signature(option, i, value) joins with '#' the type of each
#     argument, mapped through TYPE_FORMAL_GENERIC (falling back to the raw
#     type); when `i` is given, position `i` is replaced by `value`;
#   * `signature_to_option` maps each option's signature to the option;
#   * for every argument of type 'real', the signature with that position
#     replaced by 'Tensor &' is looked up; on a hit, the matching option gets
#     'zero_dim_dispatch_when_scalar' set to names[i], where `names` is built
#     from that option's argument names.
#
# NOTE(review): this listing has original line numbers fused into the code
# and lines missing (a helper header before the 'ignore_check' return, the
# closing brackets of two list expressions, the `tensor_version = ...`
# binding and any filter on `names`) -- restore from upstream before use.
155 def discover_zero_dim_tensor_operations(declaration):
157 return arg.get(
'ignore_check')
159 def signature(option, i=None, value=None):
160 elements = [TYPE_FORMAL_GENERIC.get(arg[
'type'], arg[
'type'])
161 if i
is None or j != i
else value
162 for j, arg
in enumerate(option[
'arguments'])
164 return '#'.join(elements)
165 signature_to_option = {signature(option): option
166 for option
in declaration[
'options']}
168 for option
in declaration[
'options']:
169 for i, arg
in enumerate(option[
'arguments']):
170 if arg[
'type'] ==
'real':
171 signature_of_tensor_version = signature(option, i,
'Tensor &')
172 if signature_of_tensor_version
in signature_to_option:
174 signature_to_option[signature_of_tensor_version]
175 names = [arg[
'name']
for arg
in tensor_version[
'arguments']
177 tensor_version[
'zero_dim_dispatch_when_scalar'] = names[i]
# discover_sparse_tensor_operations(declaration)
#
# Visible behaviour:
#   * the nested signature(option, i, value) joins with '#' the type of each
#     argument, mapped through TYPE_FORMAL_GENERIC (falling back to the raw
#     type); when `i` is given, position `i` is replaced by `value`;
#   * work is done only if at least one option has a truthy
#     'aten_dense_sparse' entry;
#   * for every 'THSTensor*' argument of such an option, the signature with
#     that position replaced by 'Tensor &' is looked up among all options;
#     on a hit, the matching option gets 'when_sparse_dispatch' set to
#     names[i - (raw_args - filtered_args)], i.e. the index is shifted by the
#     number of arguments dropped when building `names`.
#
# NOTE(review): this listing has original line numbers fused into the code
# and lines missing (a helper header before the 'ignore_check' return, the
# closing brackets of two list expressions, the `tensor_version = ...`
# binding and the filter that makes `names` shorter than the raw argument
# list) -- restore from upstream before use.
186 def discover_sparse_tensor_operations(declaration):
188 return arg.get(
'ignore_check')
190 def signature(option, i=None, value=None):
191 elements = [TYPE_FORMAL_GENERIC.get(arg[
'type'], arg[
'type'])
192 if i
is None or j != i
else value
193 for j, arg
in enumerate(option[
'arguments'])
195 return '#'.join(elements)
198 dense_sparse_options = [option
199 for option
in declaration[
'options']
200 if option.get(
'aten_dense_sparse',
False)]
201 if len(dense_sparse_options) > 0:
202 signature_to_option = {signature(option): option
203 for option
in declaration[
'options']}
205 for option
in declaration[
'options']:
206 for i, arg
in enumerate(option[
'arguments']):
207 if (arg[
'type'] ==
'THSTensor*' and 208 option.get(
'aten_dense_sparse',
False)):
209 signature_of_tensor_version = signature(
210 option, i,
'Tensor &')
211 if signature_of_tensor_version
in signature_to_option:
213 signature_to_option[signature_of_tensor_version]
214 raw_args = len(tensor_version[
'arguments'])
215 names = [arg[
'name']
for arg
in tensor_version[
'arguments']
217 filtered_args = len(names)
218 tensor_version[
'when_sparse_dispatch'] = names[i -
219 (raw_args - filtered_args)]
# is_extended_method(option): starts by testing whether 'method' is one of
# option['variants'].
# NOTE(review): the rest of the body (what is returned in either case) is
# missing from this listing; do not infer it from the name -- restore it
# from the upstream file.
222 def is_extended_method(option):
223 if 'method' in option[
'variants']:
# run(declarations): entry point of the preprocessing pass.
#
# Visible steps:
#   1. drop every declaration for which exclude(d) is truthy;
#   2. per remaining declaration: deep-copy its options, then run
#      discover_zero_dim_tensor_operations and
#      discover_sparse_tensor_operations on it;
#   3. per option: call sanitize_return when option['mode'] != 'native',
#      call process_types_and_backends, and record option['api_name'] in
#      `non_extended_methods` when is_extended_method(option) is falsy;
#   4. replace declaration['options'] with the result of
#      handle_outputs_taken_as_arguments;
#   5. in a final pass, set option['extended_method'] to whether its
#      api_name is absent from `non_extended_methods`.
#
# NOTE(review): this listing has original line numbers fused into the code
# and lines missing.  The call whose arguments appear at original lines
# 236-238 has lost its callee name, original lines 239-242 and 256-263 are
# absent, and the nesting of process_types_and_backends relative to the
# 'native' check cannot be read from here -- restore from upstream.
229 def run(declarations):
230 declarations = [d
for d
in declarations
if not exclude(d)]
231 non_extended_methods = set()
232 for declaration
in declarations:
234 declaration[
'options'] = [deepcopy(o)
for o
in declaration[
'options']]
236 declaration[
'options'],
238 type_to_signature=TYPE_FORMAL_GENERIC,
243 discover_zero_dim_tensor_operations(declaration)
244 discover_sparse_tensor_operations(declaration)
246 for option
in declaration[
'options']:
248 if option[
'mode'] !=
'native':
249 sanitize_return(option)
250 process_types_and_backends(option)
252 if not is_extended_method(option):
253 non_extended_methods.add(option[
'api_name'])
254 declaration[
'options'] = handle_outputs_taken_as_arguments(
255 declaration[
'options'])
264 for declaration
in declarations:
265 for option
in declaration[
'options']:
266 option[
'extended_method'] = option[
'api_name']
not in non_extended_methods
Module caffe2.python.layers.split.
def set_declaration_defaults(declaration)
def sort_by_number_of_options(declaration, reverse=True)
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self)