Caffe2 - Python API
A cross-platform deep learning framework
generate_wrappers.py
1 import os
2 import sys
3 from string import Template, ascii_lowercase
4 from ..cwrap import cwrap
5 from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU
6 from ..shared import import_module
7 
8 from ..shared._utils_internal import get_file_path
9 
# Default locations of the generic C headers that are parsed when
# generate_wrappers() is not given an explicit nn_root.
THNN_H_PATH = get_file_path(
    'torch', 'include', 'THNN', 'generic', 'THNN.h')
THCUNN_H_PATH = get_file_path(
    'torch', 'include', 'THCUNN', 'generic', 'THCUNN.h')

# Module providing parse_header(); loaded through import_module() with an
# explicit file path (presumably so it works before torch itself is
# importable -- TODO confirm).
THNN_UTILS_PATH = get_file_path('torch', '_thnn', 'utils.py')

thnn_utils = import_module('torch._thnn.utils', THNN_UTILS_PATH)
16 
# Opening of one cwrap declaration. wrap_function() substitutes $name and
# $cname, appends one entry per argument and closes the block with ']]'.
FUNCTION_TEMPLATE = Template(
    '[[\n'
    ' name: $name\n'
    ' return: void\n'
    ' cname: $cname\n'
    ' arguments:\n'
)
24 
# Generic-to-concrete C type substitutions shared by every backend type.
COMMON_TRANSFORMS = {
    'THIndex_t': 'int64_t',
    'THCIndex_t': 'int64_t',
    'THInteger_t': 'int',
}
# Additional substitutions for the CPU types ('Float', 'Double').
COMMON_CPU_TRANSFORMS = {
    'THNNState*': 'void*',
    'THIndexTensor*': 'THLongTensor*',
    'THIntegerTensor*': 'THIntTensor*',
}
# Additional substitutions for the CUDA types ('CudaHalf', 'Cuda', 'CudaDouble').
COMMON_GPU_TRANSFORMS = {
    'THCState*': 'void*',
    'THCIndexTensor*': 'THCudaLongTensor*',
}


def _with_common(specific, backend_transforms):
    """Return a new dict: *specific*, then COMMON_TRANSFORMS, then *backend_transforms*.

    Later sources win on key collisions. None of the inputs is modified.
    """
    merged = dict(specific)
    merged.update(COMMON_TRANSFORMS)
    merged.update(backend_transforms)
    return merged


# Per-type substitution tables used by wrap_function(): maps a type prefix
# ('Float', 'Cuda', ...) to {generic header type: concrete C type}.
# Built declaratively (rather than with module-level loops) so that no
# throwaway loop variables leak into the module namespace / `import *`.
TYPE_TRANSFORMS = {
    'Float': _with_common({
        'THTensor*': 'THFloatTensor*',
        'real': 'float',
        'accreal': 'double',
    }, COMMON_CPU_TRANSFORMS),
    'Double': _with_common({
        'THTensor*': 'THDoubleTensor*',
        'real': 'double',
        'accreal': 'double',
    }, COMMON_CPU_TRANSFORMS),
    'CudaHalf': _with_common({
        'THCTensor*': 'THCudaHalfTensor*',
        'real': 'half',
        'accreal': 'float',
    }, COMMON_GPU_TRANSFORMS),
    'Cuda': _with_common({
        'THCTensor*': 'THCudaTensor*',
        'real': 'float',
        'accreal': 'float',
    }, COMMON_GPU_TRANSFORMS),
    'CudaDouble': _with_common({
        'THCTensor*': 'THCudaDoubleTensor*',
        'real': 'double',
        'accreal': 'double',
    }, COMMON_GPU_TRANSFORMS),
}
74 
75 
def wrap_function(name, type, arguments):
    """Render the C prototype and cwrap declaration for one THNN function.

    Args:
        name: function name from the parsed header, without the
            ``THNN_<type>`` prefix.
        type: key into ``TYPE_TRANSFORMS`` (e.g. ``'Float'``, ``'Cuda'``)
            selecting the concrete C types to substitute. The name shadows
            the builtin but is kept for callers passing it by keyword.
        arguments: parsed argument objects exposing ``type``, ``name`` and
            ``is_optional``. Iterated twice, once for the prototype and
            once for the cwrap entries.

    Returns:
        str: a ``TH_API void <cname>(...);`` line, followed by a
        ``[[ ... ]]`` cwrap declaration and two trailing blank lines.
    """
    def concrete(arg):
        # Generic header types (THTensor*, real, ...) become this type's
        # concrete ones; anything without a mapping passes through as-is.
        return TYPE_TRANSFORMS[type].get(arg.type, arg.type)

    cname = 'THNN_' + type + name
    signature = ', '.join(concrete(arg) for arg in arguments)
    parts = [
        'TH_API void ' + cname + '(' + signature + ');\n',
        FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname),
    ]

    item = ' ' * 4 + '- '
    continuation = ' ' * 6
    for arg in arguments:
        if arg.is_optional:
            # Optional arguments use the expanded mapping form so they can
            # carry the ``nullable`` flag.
            parts.append(item + 'type: ' + concrete(arg) + '\n' +
                         continuation + 'name: ' + arg.name + '\n' +
                         continuation + 'nullable: True' + '\n')
        else:
            parts.append(item + concrete(arg) + ' ' + arg.name + '\n')
    parts.append(']]\n\n\n')
    return ''.join(parts)
98 
99 
def generate_wrappers(nn_root=None, install_dir=None, template_path=None):
    """Generate the THNN (CPU) and THCUNN (CUDA) cwrap wrappers.

    Args:
        nn_root: directory containing ``THNN/generic/THNN.h`` and
            ``THCUNN/generic/THCUNN.h``. When falsy, ``None`` is passed on
            and wrap_nn / wrap_cunn fall back to their default header paths.
        install_dir: output directory forwarded to both generators.
        template_path: forwarded unchanged to both generators.
    """
    if nn_root:
        thnn_header = os.path.join(nn_root, 'THNN', 'generic', 'THNN.h')
        thcunn_header = os.path.join(nn_root, 'THCUNN', 'generic', 'THCUNN.h')
    else:
        thnn_header = thcunn_header = None

    wrap_nn(thnn_header, install_dir, template_path)
    wrap_cunn(thcunn_header, install_dir, template_path)
103 
104 
def wrap_nn(thnn_h_path, install_dir, template_path):
    """Write ``THNN.cwrap`` for the Float/Double THNN functions and run cwrap on it.

    Args:
        thnn_h_path: header to parse; falls back to ``THNN_H_PATH`` when falsy.
        install_dir: output directory (default ``'torch/csrc/nn'``); created
            if it does not exist yet.
        template_path: forwarded unchanged to ``cwrap``.

    Raises:
        OSError: if ``install_dir`` cannot be created and is not already an
            existing directory. Errors from opening or writing the output
            file also propagate.
    """
    nn_functions = thnn_utils.parse_header(thnn_h_path or THNN_H_PATH)
    # One declaration per (function, CPU type) pair, in header order.
    wrapper = '#include <TH/TH.h>\n\n\n' + ''.join(
        wrap_function(fn.name, t, fn.arguments)
        for fn in nn_functions
        for t in ('Float', 'Double'))

    install_dir = install_dir or 'torch/csrc/nn'
    try:
        os.makedirs(install_dir)
    except OSError:
        # "Already exists" is the only expected failure. Anything else
        # (permissions, a regular file in the way, ...) must surface here
        # rather than as a confusing open() error below.
        if not os.path.isdir(install_dir):
            raise

    cwrap_file = os.path.join(install_dir, 'THNN.cwrap')
    with open(cwrap_file, 'w') as f:
        f.write(wrapper)
    cwrap(cwrap_file,
          plugins=[NNExtension('torch._C._THNN'), NullableArguments()],
          template_path=template_path)
121 
122 
def wrap_cunn(thcunn_h_path, install_dir, template_path):
    """Write ``THCUNN.cwrap`` for the CUDA THCUNN functions and run cwrap on it.

    Args:
        thcunn_h_path: header to parse; falls back to ``THCUNN_H_PATH`` when
            falsy.
        install_dir: output directory (default ``'torch/csrc/nn'``); created
            if it does not exist yet. Previously this function relied on an
            earlier wrap_nn() call having created it.
        template_path: forwarded unchanged to ``cwrap``.

    Raises:
        OSError: if ``install_dir`` cannot be created and is not already an
            existing directory. Errors from opening or writing the output
            file also propagate.
    """
    cunn_functions = thnn_utils.parse_header(thcunn_h_path or THCUNN_H_PATH)
    # One declaration per (function, CUDA type) pair, in header order.
    wrapper = '#include <TH/TH.h>\n' + '#include <THC/THC.h>\n\n\n' + ''.join(
        wrap_function(fn.name, t, fn.arguments)
        for fn in cunn_functions
        for t in ('CudaHalf', 'Cuda', 'CudaDouble'))

    install_dir = install_dir or 'torch/csrc/nn'
    try:
        os.makedirs(install_dir)
    except OSError:
        # Tolerate only "already exists as a directory"; surface anything else.
        if not os.path.isdir(install_dir):
            raise

    cwrap_file = os.path.join(install_dir, 'THCUNN.cwrap')
    with open(cwrap_file, 'w') as f:
        f.write(wrapper)
    # NOTE(review): AutoGPU(has_self=False) semantics are defined in
    # cwrap.plugins, not visible here.
    cwrap(cwrap_file,
          plugins=[NNExtension('torch._C._THCUNN'), NullableArguments(),
                   AutoGPU(has_self=False)],
          template_path=template_path)