3 from string
import Template, ascii_lowercase
4 from ..cwrap
import cwrap
5 from ..cwrap.plugins
import NNExtension, NullableArguments, AutoGPU
6 from ..shared
import import_module
8 from ..shared._utils_internal
import get_file_path
# Default header locations.  wrap_nn / wrap_cunn fall back to these when the
# caller does not pass an explicit header path.
THNN_H_PATH = get_file_path(
    'torch', 'include', 'THNN', 'generic', 'THNN.h',
)
THCUNN_H_PATH = get_file_path(
    'torch', 'include', 'THCUNN', 'generic', 'THCUNN.h',
)

# File that backs the `thnn_utils` helper module used for header parsing.
THNN_UTILS_PATH = get_file_path(
    'torch', '_thnn', 'utils.py',
)

# import_module receives the module name and the file path computed above;
# presumably it loads the module from that file -- confirm in ..shared.
thnn_utils = import_module('torch._thnn.utils', THNN_UTILS_PATH)
17 FUNCTION_TEMPLATE = Template(
"""\ 26 'THIndex_t':
'int64_t',
27 'THCIndex_t':
'int64_t',
30 COMMON_CPU_TRANSFORMS = {
31 'THNNState*':
'void*',
32 'THIndexTensor*':
'THLongTensor*',
33 'THIntegerTensor*':
'THIntTensor*',
35 COMMON_GPU_TRANSFORMS = {
37 'THCIndexTensor*':
'THCudaLongTensor*',
42 'THTensor*':
'THFloatTensor*',
47 'THTensor*':
'THDoubleTensor*',
52 'THCTensor*':
'THCudaHalfTensor*',
57 'THCTensor*':
'THCudaTensor*',
62 'THCTensor*':
'THCudaDoubleTensor*',
# Every per-type table first receives the transforms shared by all types ...
for transforms in TYPE_TRANSFORMS.values():
    transforms.update(COMMON_TRANSFORMS)

# ... and then the CPU- or GPU-specific ones.  Because these updates run
# second, they win over the common table if a key ever appears in both.
for t in ('Float', 'Double'):
    TYPE_TRANSFORMS[t].update(COMMON_CPU_TRANSFORMS)
for t in ('CudaHalf', 'Cuda', 'CudaDouble'):
    TYPE_TRANSFORMS[t].update(COMMON_GPU_TRANSFORMS)
def wrap_function(name, type, arguments):
    """Build the declaration text for one function and one tensor type.

    The result consists of a C prototype line
    (``TH_API void THNN_<type><name>(...);``), the FUNCTION_TEMPLATE header
    filled in with the wrapper name and C name, one ``- `` item per argument,
    and a closing ``]]`` followed by blank lines.

    Args:
        name: Function name without the ``THNN_<type>`` prefix.
        type: Key into TYPE_TRANSFORMS (for example ``'Float'``).  The name
            shadows the builtin but is part of the existing signature.
        arguments: Iterable of objects exposing ``type``, ``name`` and
            ``is_optional``.  It is traversed twice, so it must be
            re-iterable (a list, not a one-shot generator).

    Returns:
        The complete declaration as a single string.
    """
    cname = 'THNN_' + type + name

    def c_type(arg):
        # Types without an entry in the per-type table are emitted unchanged.
        return TYPE_TRANSFORMS[type].get(arg.type, arg.type)

    # NOTE(review): the statements that define these two indentation strings
    # are not visible in this view of the file.  The widths below assume
    # argument items sit four columns in, with continuation keys aligned
    # under the text that follows '- ' -- confirm against the original.
    indent = ' ' * 4
    dict_indent = ' ' * 6
    prefix = indent + '- '

    pieces = [
        'TH_API void ' + cname + '(' +
        ', '.join(c_type(arg) for arg in arguments) + ');\n',
        FUNCTION_TEMPLATE.substitute(name=type + name, cname=cname),
    ]
    for arg in arguments:
        if arg.is_optional:
            # Optional arguments use the expanded mapping form so that they
            # can carry the `nullable` flag.
            pieces.append(
                prefix + 'type: ' + c_type(arg) + '\n' +
                dict_indent + 'name: ' + arg.name + '\n' +
                dict_indent + 'nullable: True' + '\n'
            )
        else:
            pieces.append(prefix + c_type(arg) + ' ' + arg.name + '\n')
    pieces.append(']]\n\n\n')
    return ''.join(pieces)


def generate_wrappers(nn_root=None, install_dir=None, template_path=None):
    """Generate the THNN wrappers, then the THCUNN wrappers.

    Args:
        nn_root: Optional directory that contains ``THNN/generic/THNN.h`` and
            ``THCUNN/generic/THCUNN.h``.  When falsy, ``None`` is passed on
            as the header path and wrap_nn / wrap_cunn use their defaults.
        install_dir: Forwarded unchanged to wrap_nn and wrap_cunn.
        template_path: Forwarded unchanged to wrap_nn and wrap_cunn.
    """
    if nn_root:
        thnn_header = os.path.join(nn_root, 'THNN', 'generic', 'THNN.h')
        thcunn_header = os.path.join(nn_root, 'THCUNN', 'generic', 'THCUNN.h')
    else:
        thnn_header = thcunn_header = None

    wrap_nn(thnn_header, install_dir, template_path)
    wrap_cunn(thcunn_header, install_dir, template_path)
def wrap_nn(thnn_h_path, install_dir, template_path):
    """Write ``THNN.cwrap`` for the Float and Double types and run cwrap on it.

    Args:
        thnn_h_path: Header to parse; THNN_H_PATH is used when this is falsy.
        install_dir: Output directory; ``'torch/csrc/nn'`` is used when this
            is falsy.  The directory is created if it does not exist yet.
        template_path: Passed through to cwrap unchanged.

    Raises:
        OSError: If ``install_dir`` does not exist and cannot be created.
    """
    nn_functions = thnn_utils.parse_header(thnn_h_path or THNN_H_PATH)
    # One declaration per (function, type) pair, functions in header order.
    wrapper = '#include <TH/TH.h>\n\n\n' + ''.join(
        wrap_function(fn.name, t, fn.arguments)
        for fn in nn_functions
        for t in ('Float', 'Double')
    )

    install_dir = install_dir or 'torch/csrc/nn'
    try:
        os.makedirs(install_dir)
    except OSError:
        # An already-existing directory is the normal case on a rebuild and
        # must not abort generation; any other failure is re-raised.
        if not os.path.isdir(install_dir):
            raise

    cwrap_path = os.path.join(install_dir, 'THNN.cwrap')
    with open(cwrap_path, 'w') as f:
        # NOTE(review): the body of this `with` is not visible in this view
        # of the file; writing the accumulated text is the evident intent,
        # since cwrap is pointed at this same path right after.
        f.write(wrapper)
    cwrap(cwrap_path,
          plugins=[NNExtension('torch._C._THNN'), NullableArguments()],
          template_path=template_path)
def wrap_cunn(thcunn_h_path, install_dir, template_path):
    """Write ``THCUNN.cwrap`` for the CUDA types and run cwrap on it.

    Declarations are generated for the ``CudaHalf``, ``Cuda`` and
    ``CudaDouble`` types.

    Args:
        thcunn_h_path: Header to parse; THCUNN_H_PATH is used when this is
            falsy.
        install_dir: Output directory; ``'torch/csrc/nn'`` is used when this
            is falsy.  The directory is created if it does not exist yet.
        template_path: Passed through to cwrap unchanged.

    Raises:
        OSError: If ``install_dir`` does not exist and cannot be created.
    """
    cunn_functions = thnn_utils.parse_header(thcunn_h_path or THCUNN_H_PATH)
    # One declaration per (function, type) pair, functions in header order.
    wrapper = '#include <TH/TH.h>\n' + '#include <THC/THC.h>\n\n\n' + ''.join(
        wrap_function(fn.name, t, fn.arguments)
        for fn in cunn_functions
        for t in ('CudaHalf', 'Cuda', 'CudaDouble')
    )

    install_dir = install_dir or 'torch/csrc/nn'
    # Create the output directory here as well, so this function also works
    # when it is called on its own and the directory does not exist yet.
    try:
        os.makedirs(install_dir)
    except OSError:
        # An existing directory is fine; any other failure is re-raised.
        if not os.path.isdir(install_dir):
            raise

    cwrap_path = os.path.join(install_dir, 'THCUNN.cwrap')
    with open(cwrap_path, 'w') as f:
        # NOTE(review): the body of this `with` is not visible in this view
        # of the file; writing the accumulated text is the evident intent,
        # since cwrap is pointed at this same path right after.
        f.write(wrapper)
    cwrap(cwrap_path,
          plugins=[NNExtension('torch._C._THCUNN'),
                   NullableArguments(),
                   AutoGPU(has_self=False)],
          template_path=template_path)