2 """ The Python Hipify script. 4 # Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved. 5 # 2017-2018 Advanced Micro Devices, Inc. and 6 # Facebook Inc. All rights reserved. 8 # Permission is hereby granted, free of charge, to any person obtaining a copy 9 # of this software and associated documentation files (the "Software"), to deal 10 # in the Software without restriction, including without limitation the rights 11 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 # copies of the Software, and to permit persons to whom the Software is 13 # furnished to do so, subject to the following conditions: 15 # The above copyright notice and this permission notice shall be included in 16 # all copies or substantial portions of the Software. 18 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 from __future__
import absolute_import, division, print_function
38 from pyHIPIFY
import constants
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
# Both PyTorch template parameter names resolve to the same concrete type name.
PYTORCH_TEMPLATE_MAP = {
    "Dtype": "scalar_t",
    "T": "scalar_t",
}

# The Caffe2 counterpart of the table above; it carries no entries.
CAFFE2_TEMPLATE_MAP = {}
52 def __init__(self, message):
53 super(InputError, self).__init__(message)
57 return "{}: {}".format(
"Input error", self.
message)
def openf(filename, mode):
    """Open ``filename`` in ``mode``, ignoring undecodable bytes in text mode.

    On Python 3 a text-mode handle is opened with ``errors='ignore'`` so that
    bytes which cannot be decoded are dropped instead of raising
    ``UnicodeDecodeError``.  Binary modes are opened without ``errors``
    because ``open`` rejects that argument for binary files.  On Python 2 the
    builtin ``open`` has no ``errors`` parameter, so it is called as-is.

    Args:
        filename: Path of the file to open.
        mode: Mode string, passed to ``open`` unchanged (e.g. ``"r"``, ``"r+"``).

    Returns:
        The file object returned by ``open``.
    """
    # `errors` is only meaningful (and only accepted) for text-mode handles.
    if sys.version_info[0] >= 3 and "b" not in mode:
        return open(filename, mode, errors='ignore')
    return open(filename, mode)
80 """ How to disable functions 81 REMOVE - Remove the function entirely (includes the signature). 84 ```ret_type function(arg_type1 arg1, ..., ){ 94 STUB - Stub the function and return an empty object based off the type. 97 ```ret_type function(arg_type1 arg1, ..., ){ 104 ```ret_type function(arg_type1 arg1, ..., ){ 110 HCC_MACRO - Add !defined(__HIP_PLATFORM_HCC__) preprocessors around the function. 111 This macro is defined by HIP if the compiler used is hcc. 114 ```ret_type function(arg_type1 arg1, ..., ){ 121 ```#if !defined(__HIP_PLATFORM_HCC__) 122 ret_type function(arg_type1 arg1, ..., ){ 131 DEVICE_MACRO - Add !defined(__HIP_DEVICE_COMPILE__) preprocessors around the function. 132 This macro is defined by HIP if either hcc or nvcc are used in the device path. 135 ```ret_type function(arg_type1 arg1, ..., ){ 142 ```#if !defined(__HIP_DEVICE_COMPILE__) 143 ret_type function(arg_type1 arg1, ..., ){ 152 EXCEPTION - Stub the function and throw an exception at runtime. 155 ```ret_type function(arg_type1 arg1, ..., ){ 162 ```ret_type function(arg_type1 arg1, ..., ){ 163 throw std::runtime_error("The function function is not implemented.") 167 ASSERT - Stub the function and throw an assert(0). 170 ```ret_type function(arg_type1 arg1, ..., ){ 177 ```ret_type function(arg_type1 arg1, ..., ){ 182 EMPTYBODY - Stub the function and keep an empty body. 185 ```ret_type function(arg_type1 arg1, ..., ){ 192 ```ret_type function(arg_type1 arg1, ..., ){ 208 def matched_files_iter(root_path, includes=(
'*',), ignores=(), extensions=(), out_of_place_only=
False):
209 def _fnmatch(filepath, patterns):
210 return any(fnmatch.fnmatch(filepath, pattern)
for pattern
in patterns)
212 def match_extensions(filename):
213 """Helper method to see if filename ends with certain extension""" 214 return any(filename.endswith(e)
for e
in extensions)
216 exact_matches = set(includes)
223 for (abs_dirpath, dirs, filenames)
in os.walk(root_path, topdown=
True):
224 rel_dirpath = os.path.relpath(abs_dirpath, root_path)
225 if rel_dirpath ==
'.':
231 if "third_party" in dirs:
232 dirs.remove(
"third_party")
233 for filename
in filenames:
234 filepath = os.path.join(rel_dirpath, filename)
237 if _fnmatch(filepath, includes)
and (
not _fnmatch(filepath, ignores))
and (match_extensions(filepath)
or filepath
in exact_matches):
238 if not is_pytorch_file(filepath)
and not is_caffe2_gpu_file(filepath):
240 if out_of_place_only
and not is_out_of_place(filepath):
250 hip_clang_launch=
False):
252 Call preprocessor on selected files. 255 show_detailed - Show a detailed summary of the transpilation process. 259 stats = {
"unsupported_calls": [],
"kernel_launches": []}
261 for filepath
in all_files:
262 preprocessor(output_directory, filepath, stats, hip_clang_launch)
269 print(bcolors.OKGREEN +
"Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
def compute_stats(stats):
    """Print a summary of the transpilation statistics to stdout.

    Args:
        stats: Dictionary with two entries:
            ``"unsupported_calls"`` - iterable of ``(cuda_call, filepath)``
            pairs; only the distinct ``cuda_call`` values are reported.
            ``"kernel_launches"`` - sized collection of replaced launches;
            only its length is reported.
    """
    # De-duplicate the call names, then sort them so the printed list is
    # reproducible from run to run (set iteration order is not guaranteed).
    unsupported_calls = sorted({cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]})

    # Print the number of unsupported calls
    print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))

    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))
290 '''adds dim3() to the second and third arguments in the kernel launch''' 293 kernel_string = kernel_string.replace(
"<<<",
"").replace(
">>>",
"")
294 arg_locs = [{}
for _
in range(2)]
295 arg_locs[count][
'start'] = 0
296 for ind, c
in enumerate(kernel_string):
303 elif (c ==
"," or ind == len(kernel_string) - 1)
and closure == 0:
304 arg_locs[count][
'end'] = ind + (c !=
",")
307 arg_locs[count][
'start'] = ind + 1
309 first_arg_raw = kernel_string[arg_locs[0][
'start']:arg_locs[0][
'end'] + 1]
310 second_arg_raw = kernel_string[arg_locs[1][
'start']:arg_locs[1][
'end']]
312 first_arg_clean = kernel_string[arg_locs[0][
'start']:arg_locs[0][
'end']].replace(
"\n",
"").strip(
" ")
313 second_arg_clean = kernel_string[arg_locs[1][
'start']:arg_locs[1][
'end']].replace(
"\n",
"").strip(
" ")
315 first_arg_dim3 =
"dim3({})".format(first_arg_clean)
316 second_arg_dim3 =
"dim3({})".format(second_arg_clean)
318 first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
319 second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
320 cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
# Matches a `detail::` qualifier (the trailing "l" is optional in the pattern)
# that is preceded by spaces and followed by spaces, a backslash line
# continuation, a newline and further spaces.
#   group 1 - the leading run of spaces
#   group 2 - the qualifier text itself
RE_KERNEL_LAUNCH = re.compile(
    r'([ ]+)(detail?)::[ ]+\\\n[ ]+'
)
328 """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" 330 string = RE_KERNEL_LAUNCH.sub(
lambda inp:
"{0}{1}::".format(inp.group(1), inp.group(2)), string)
332 def grab_method_and_template(in_kernel):
335 "kernel_launch": {
"start": in_kernel[
"start"],
"end": in_kernel[
"end"]},
336 "kernel_name": {
"start": -1,
"end": -1},
337 "template": {
"start": -1,
"end": -1}
352 for i
in range(pos[
"kernel_launch"][
"start"] - 1, -1, -1):
356 if status == START
or status == AT_TEMPLATE:
360 pos[
"template"][
"end"] = i
365 if count[
"<>"] == 0
and (status == AT_TEMPLATE):
366 pos[
"template"][
"start"] = i
367 status = AFTER_TEMPLATE
370 if status != AT_TEMPLATE:
371 if string[i].isalnum()
or string[i]
in {
'(',
')',
'_',
':',
'#'}:
372 if status != AT_KERNEL_NAME:
373 status = AT_KERNEL_NAME
374 pos[
"kernel_name"][
"end"] = i
378 pos[
"kernel_name"][
"start"] = 0
381 return [(pos[
"kernel_name"]), (pos[
"template"]), (pos[
"kernel_launch"])]
385 if status == AT_KERNEL_NAME:
386 pos[
"kernel_name"][
"start"] = i
389 return [(pos[
"kernel_name"]), (pos[
"template"]), (pos[
"kernel_launch"])]
391 def find_kernel_bounds(string):
392 """Finds the starting and ending points for all kernel launches in the string.""" 394 kernel_positions = []
397 while string.find(
"<<<", kernel_end) != -1:
399 kernel_start = string.find(
"<<<", kernel_end)
402 kernel_end = string.find(
">>>", kernel_start) + 3
407 kernel_positions.append({
"start": kernel_start,
"end": kernel_end,
408 "group": string[kernel_start: kernel_end]})
410 return kernel_positions
413 get_kernel_positions = [k
for k
in find_kernel_bounds(string)]
414 output_string = string
417 for kernel
in get_kernel_positions:
419 params = grab_method_and_template(kernel)
422 parenthesis = string.find(
"(", kernel[
"end"])
425 cuda_kernel = string[params[0][
"start"]:parenthesis + 1]
426 kernel_string = string[kernel[
'start']:kernel[
'end']]
427 cuda_kernel_dim3 =
add_dim3(kernel_string, cuda_kernel)
429 num_klp = len(
extract_arguments(0, kernel[
"group"].replace(
"<<<",
"(").replace(
">>>",
")")))
431 hip_kernel =
"hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
432 ">>>",
", 0" * (4 - num_klp) +
">>>").replace(
"<<<",
", ").replace(
">>>",
", ")
435 output_string = output_string.replace(cuda_kernel, hip_kernel)
438 stats[
"kernel_launches"].append(hip_kernel)
444 """Generalization for finding a balancing closure group 446 if group = ["(", ")"], then finds the first balanced parantheses. 447 if group = ["{", "}"], then finds the first balanced bracket. 449 Given an input string, a starting position in the input string, and the group type, 450 find_closure_group returns the positions of group[0] and group[1] as a tuple. 453 find_closure_group("(hi)", 0, ["(", ")"]) 459 inside_parenthesis =
False 462 p_start, p_end = -1, -1
464 while pos < len(input_string):
465 if input_string[pos] == group[0]:
466 if inside_parenthesis
is False:
467 inside_parenthesis =
True 472 elif input_string[pos] == group[1]
and inside_parenthesis:
477 return p_start, p_end
484 """Finds the first balanced parantheses.""" 489 """Finds the first balanced bracket.""" 493 RE_ASSERT = re.compile(
r"\bassert[ ]*\(")
497 """ Disables regular assert statements 498 e.g. "assert(....)" -> "/*assert(....)*/" 500 output_string = input_string
501 asserts = list(RE_ASSERT.finditer(input_string))
502 for assert_item
in asserts:
504 start = assert_item.start()
505 output_string = output_string.replace(input_string[start:p_end + 1],
"")
510 """ FIXME: Temporarily replace std:: invocations of math functions with non-std:: versions to prevent linker errors 511 NOTE: This can lead to correctness issues when running tests, since the correct version of the math function (exp/expf) might not get called. 512 Plan is to remove this function once HIP supports std:: math function calls inside device code 514 output_string = input_string
515 for func
in MATH_TRANSPILATIONS:
516 output_string = output_string.replace(
r'{}('.format(func),
'{}('.format(MATH_TRANSPILATIONS[func]))
# Matches a `__syncthreads(` occurrence, optionally prefixed by up to two
# colons.  Group 1 is the identifier, group 2 the text up to and including
# the opening parenthesis.
RE_SYNCTHREADS = re.compile(
    r"[:]?[:]?\b(__syncthreads)\b(\w*\()"
)
525 """If the file makes kernel builtin calls and does not include the cuda_runtime.h header, 526 then automatically add an #include to match the "magic" includes provided by NVCC. 528 Update logic to ignore cases where the cuda_runtime.h is included by another file. 532 output_string = input_string
535 headers = [
"hip/hip_runtime.h",
"hip/hip_runtime_api.h"]
536 if any(re.search(
r'#include ("{0}"|<{0}>)'.format(ext), output_string)
for ext
in headers):
540 hasDeviceLogic =
"hipLaunchKernelGGL" in output_string
541 hasDeviceLogic +=
"__global__" in output_string
542 hasDeviceLogic +=
"__shared__" in output_string
543 hasDeviceLogic += RE_SYNCTHREADS.search(output_string)
is not None 547 output_string =
'#include "hip/hip_runtime.h"\n' + input_string
# Matches an unsized `extern ... __shared__ <type> <name>[];` declaration.
#   group 1 - optional token between `extern` and `__shared__` (None if absent)
#   group 2 - the element type (may contain spaces, `:` and `<>`)
#   group 3 - the array name
RE_EXTERN_SHARED = re.compile(
    r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;"
)
556 """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead. 557 https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__ 559 "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)" 560 "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)" 562 output_string = input_string
563 output_string = RE_EXTERN_SHARED.sub(
564 lambda inp:
"HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
565 inp.group(1)
or "", inp.group(2), inp.group(3)), output_string)
571 """ Finds and disables a function in a particular file. 573 If type(function) == List 574 function - The signature of the function to disable. 575 e.g. ["bool", "overlappingIndices", "(const Tensor& t)"] 576 disables function -> "bool overlappingIndices(const Tensor& t)" 578 If type(function) == String 579 function - Disables the function by name only. 580 e.g. "overlappingIndices" 582 replace_style - The style to use when stubbing functions. 586 "function_start": -1,
597 if type(function) == list:
600 "return_type": function[0].strip(),
601 "function_name": function[1].strip(),
602 "function_args": function[2].strip()
606 function_string =
"{0}{1}{2}".format(
607 func_info[
"return_type"],
608 func_info[
"function_name"],
609 func_info[
"function_args"]
613 info[
"function_start"] = input_string.find(function_string)
616 the_match = re.search(
r"(((.*) (\*)?)({0})(\([^{{)]*\)))\s*{{".format(
617 function.replace(
"(",
r"\(").replace(
")",
r"\)")), input_string)
618 if the_match
is None:
622 "return_type": the_match.group(2).strip(),
623 "function_name": the_match.group(5).strip(),
624 "function_args": the_match.group(6).strip(),
628 info[
"function_start"] = the_match.start()
629 function_string = the_match.group(1)
632 if info[
"function_start"] == -1:
636 pos = info[
"function_start"] + len(function_string) - 1
637 while pos < len(input_string)
and STATE != BRACKET_COMPLETE:
638 if input_string[pos] ==
"{":
639 if STATE != INSIDE_FUNCTION:
640 STATE = INSIDE_FUNCTION
641 info[
"bracket_count"] = 1
643 info[
"bracket_count"] += 1
644 elif input_string[pos] ==
"}":
645 info[
"bracket_count"] -= 1
647 if info[
"bracket_count"] == 0
and STATE == INSIDE_FUNCTION:
648 STATE = BRACKET_COMPLETE
649 info[
"function_end"] = pos
654 if STATE != BRACKET_COMPLETE:
658 function_body = input_string[info[
"function_start"]:info[
"function_end"] + 1]
661 if replace_style == disablefuncmode.REMOVE:
662 output_string = input_string.replace(function_body,
"")
665 elif replace_style == disablefuncmode.STUB:
667 if func_info[
"return_type"] ==
"void" or func_info[
"return_type"] ==
"static void":
668 stub =
"{0}{{\n}}".format(function_string)
670 elif "*" in func_info[
"return_type"]:
671 stub =
"{0}{{\nreturn {1};\n}}".format(function_string,
"NULL")
673 stub =
"{0}{{\n{1} stub_var;\nreturn stub_var;\n}}".format(function_string, func_info[
"return_type"])
675 output_string = input_string.replace(function_body, stub)
678 elif replace_style == disablefuncmode.HCC_MACRO:
679 output_string = input_string.replace(
681 "#if !defined(__HIP_PLATFORM_HCC__)\n{0}\n#endif".format(function_body))
684 elif replace_style == disablefuncmode.DEVICE_MACRO:
685 output_string = input_string.replace(
687 "#if !defined(__HIP_DEVICE_COMPILE__)\n{0}\n#endif".format(function_body))
690 elif replace_style == disablefuncmode.EXCEPTION:
691 stub =
"{0}{{\n{1};\n}}".format(
693 'throw std::runtime_error("The function {0} is not implemented.")'.format(
694 function_string.replace(
"\n",
" ")))
695 output_string = input_string.replace(function_body, stub)
697 elif replace_style == disablefuncmode.ASSERT:
698 stub =
"{0}{{\n{1};\n}}".format(
701 output_string = input_string.replace(function_body, stub)
703 elif replace_style == disablefuncmode.EMPTYBODY:
704 stub =
"{0}{{\n;\n}}".format(function_string)
705 output_string = input_string.replace(function_body, stub)
711 Returns the new name of the hipified file 715 if not is_out_of_place(filepath):
718 dirpath, filename = os.path.split(filepath)
719 root, ext = os.path.splitext(filename)
754 orig_dirpath = dirpath
756 dirpath = dirpath.replace(
'cuda',
'hip')
757 dirpath = dirpath.replace(
'THC',
'THH')
759 root = root.replace(
'cuda',
'hip')
760 root = root.replace(
'CUDA',
'HIP')
762 if dirpath !=
"caffe2/core":
763 root = root.replace(
'THC',
'THH')
765 if dirpath == orig_dirpath:
766 dirpath = os.path.join(dirpath,
'hip')
768 return os.path.join(dirpath, root + ext)
771 def is_out_of_place(filepath):
772 if filepath.startswith(
"torch/"):
774 if filepath.startswith(
"tools/autograd/templates/"):
780 def is_pytorch_file(filepath):
781 if filepath.startswith(
"aten/"):
782 if filepath.startswith(
"aten/src/ATen/core/"):
785 if filepath.startswith(
"torch/"):
787 if filepath.startswith(
"tools/autograd/templates/"):
792 def is_caffe2_gpu_file(filepath):
793 if filepath.startswith(
"c10/cuda"):
795 filename = os.path.basename(filepath)
796 _, ext = os.path.splitext(filename)
797 return (
'gpu' in filename
or ext
in [
'.cu',
'.cuh'])
and (
'cudnn' not in filename)
802 """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern. 803 The corresponding Regex should match much faster than a simple Regex union.""" 811 ref[char] = char
in ref
and ref[char]
or {}
818 def quote(self, char):
819 return re.escape(char)
821 def _pattern(self, pData):
823 if "" in data
and len(data.keys()) == 1:
829 for char
in sorted(data.keys()):
830 if isinstance(data[char], dict):
833 alt.append(self.
quote(char) + recurse)
835 cc.append(self.
quote(char))
838 cconly =
not len(alt) > 0
844 alt.append(
'[' +
''.join(cc) +
']')
849 result =
"(?:" +
"|".join(alt) +
")" 855 result =
"(?:%s)?" % result
864 PYTORCH_TRIE =
Trie()
866 for mapping
in CUDA_TO_HIP_MAPPINGS:
867 for src, value
in mapping.items():
869 meta_data = value[1:]
870 if constants.API_CAFFE2
not in meta_data:
871 PYTORCH_TRIE.add(src)
872 PYTORCH_MAP[src] = dst
873 if constants.API_PYTORCH
not in meta_data:
875 CAFFE2_MAP[src] = dst
876 RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
877 RE_PYTORCH_PREPROCESSOR = re.compile(
r'\b{0}\b'.format(PYTORCH_TRIE.pattern()))
# `#include "path"` - group 1 is the quoted path.
RE_QUOTE_HEADER = re.compile(
    r'#include "([^"]+)"'
)

# `#include <path>` - group 1 is the bracketed path.
RE_ANGLE_HEADER = re.compile(
    r'#include <([^>]+)>'
)

# `#define THC_GENERIC_FILE "path"` - group 1 is the quoted path.
RE_THC_GENERIC_FILE = re.compile(
    r'#define THC_GENERIC_FILE "([^"]+)"'
)

# A literal `.cu` that ends at a word boundary (so `.cuh` is not matched).
RE_CU_SUFFIX = re.compile(
    r'\.cu\b'
)
885 """ Executes the CUDA -> HIP conversion on the specified file. """ 886 fin_path = os.path.join(output_directory, filepath)
887 with open(fin_path,
'r') as fin: 888 output_source = fin.read() 891 if not os.path.exists(os.path.dirname(fout_path)):
892 os.makedirs(os.path.dirname(fout_path))
894 with open(fout_path,
'w')
as fout:
896 if is_pytorch_file(filepath):
898 return PYTORCH_MAP[m.group(0)]
899 output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
902 return CAFFE2_MAP[m.group(0)]
903 output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
909 if f.startswith(
"ATen/cuda")
or f.startswith(
"ATen/native/cuda")
or f.startswith(
"ATen/native/sparse/cuda")
or f.startswith(
"THC/")
or f.startswith(
"THCUNN/")
or (f.startswith(
"THC")
and not f.startswith(
"THCP")):
913 output_source = RE_QUOTE_HEADER.sub(mk_repl(
'#include "{0}"'), output_source)
914 output_source = RE_ANGLE_HEADER.sub(mk_repl(
'#include <{0}>'), output_source)
915 output_source = RE_THC_GENERIC_FILE.sub(mk_repl(
'#define THC_GENERIC_FILE "{0}"'), output_source)
918 if filepath.endswith(
'CMakeLists.txt'):
919 output_source = output_source.replace(
'CUDA',
'HIP')
920 output_source = output_source.replace(
'THC',
'THH')
921 output_source = RE_CU_SUFFIX.sub(
'.hip', output_source)
924 if not hip_clang_launch:
932 if filepath.endswith(
".cu")
or filepath.endswith(
".cuh"):
941 fout.write(output_source)
944 def file_specific_replacement(filepath, search_string, replace_string, strict=False):
945 with openf(filepath,
"r+")
as f:
948 contents = re.sub(
r'\b({0})\b'.format(re.escape(search_string)),
lambda x: replace_string, contents)
950 contents = contents.replace(search_string, replace_string)
956 def file_add_header(filepath, header):
957 with openf(filepath,
"r+")
as f:
959 if header[0] !=
"<" and header[-1] !=
">":
960 header =
'"{0}"'.format(header)
961 contents = (
'#include {0} \n'.format(header)) + contents
968 """Static global kernels in HIP results in a compilation error.""" 969 in_txt = in_txt.replace(
" __global__ static",
"__global__")
974 """Disables calls to an unsupported HIP function""" 976 output_string = input_string
979 calls = re.finditer(
r"\b{0}\b".format(re.escape(function)), input_string)
987 started_arguments =
False 989 while pos < len(input_string):
990 if input_string[pos] ==
"(":
991 if started_arguments
is False:
992 started_arguments =
True 996 elif input_string[pos] ==
")" and started_arguments:
999 if bracket_count == 0
and started_arguments:
1004 function_call = input_string[start:pos + 1]
1005 output_string = output_string.replace(function_call, replacement)
1007 return output_string
# Matches one `#include ...` directive through the end of its line; the
# trailing newline is part of the match, so a final line without one is
# not matched.
RE_INCLUDE = re.compile(
    r"#include .*\n"
)
1014 """Disable a module entirely except for header includes.""" 1015 with openf(input_file,
"r+")
as f:
1017 last = list(RE_INCLUDE.finditer(txt))[-1]
1020 disabled =
"{0}#if !defined(__HIP_PLATFORM_HCC__)\n{1}\n#endif".format(txt[0:end], txt[end:])
1028 """ Return the list of arguments in the upcoming function parameter closure. 1030 string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))' 1032 '[{'start': 1, 'end': 7}, 1033 {'start': 8, 'end': 16}, 1034 {'start': 17, 'end': 19}, 1035 {'start': 20, 'end': 53}]' 1043 current_position = start
1044 argument_start_pos = current_position + 1
1047 while current_position < len(string):
1048 if string[current_position] ==
"(":
1050 elif string[current_position] ==
")":
1052 elif string[current_position] ==
"<":
1054 elif string[current_position] ==
">" and string[current_position - 1] !=
"-" and closures[
"<"] > 0:
1058 if closures[
"("] == 0
and closures[
"<"] == 0:
1060 arguments.append({
"start": argument_start_pos,
"end": current_position})
1064 if closures[
"("] == 1
and closures[
"<"] == 0
and string[current_position] ==
",":
1065 arguments.append({
"start": argument_start_pos,
"end": current_position})
1066 argument_start_pos = current_position + 1
1068 current_position += 1
1074 """ArgumentParser doesn't support type=bool. Thus, this helper method will convert 1075 from possible string types to True / False.""" 1076 if v.lower()
in (
'yes',
'true',
't',
'y',
'1'):
1078 elif v.lower()
in (
'no',
'false',
'f',
'n',
'0'):
1081 raise argparse.ArgumentTypeError(
'Boolean value expected.')
1086 show_detailed=
False,
1087 extensions=(
".cu",
".cuh",
".c",
".cc",
".cpp",
".h",
".in",
".hpp"),
1088 output_directory=
"",
1091 out_of_place_only=
False,
1094 hip_clang_launch=
False,
1096 if project_directory ==
"":
1097 project_directory = os.getcwd()
1100 if not os.path.exists(project_directory):
1101 print(
"The project folder specified does not exist.")
1105 if not output_directory:
1106 project_directory.rstrip(
"/")
1107 output_directory = project_directory +
"_amd" 1110 if not os.path.exists(output_directory):
1111 shutil.copytree(project_directory, output_directory)
1114 if json_settings !=
"":
1115 with openf(json_settings,
"r") as f: 1116 json_data = json.load(f) 1119 for disable_info
in json_data[
"disabled_functions"]:
1120 filepath = os.path.join(output_directory, disable_info[
"path"])
1121 if "functions" in disable_info:
1122 functions = disable_info[
"functions"]
1124 functions = disable_info.get(
"functions", [])
1126 if "non_hip_functions" in disable_info:
1127 non_hip_functions = disable_info[
"non_hip_functions"]
1129 non_hip_functions = disable_info.get(
"non_hip_functions", [])
1131 if "non_device_functions" in disable_info:
1132 not_on_device_functions = disable_info[
"non_device_functions"]
1134 not_on_device_functions = disable_info.get(
"non_device_functions", [])
1136 with openf(filepath,
"r+")
as f:
1138 for func
in functions:
1142 for func
in non_hip_functions:
1146 for func
in not_on_device_functions:
1155 disable_modules = json_data[
"disabled_modules"]
1156 for module
in disable_modules:
1160 for disable
in json_data[
"disable_unsupported_hip_calls"]:
1161 filepath = os.path.join(output_directory, disable[
"path"])
1162 if "functions" in disable:
1163 functions = disable[
"functions"]
1165 functions = disable.get(
"functions", [])
1167 if "constants" in disable:
1168 constants = disable[
"constants"]
1170 constants = disable.get(
"constants", [])
1172 if "s_constants" in disable:
1173 s_constants = disable[
"s_constants"]
1175 s_constants = disable.get(
"s_constants", [])
1177 if not os.path.exists(filepath):
1178 print(
"\n" + bcolors.WARNING +
"JSON Warning: File {0} does not exist.".format(filepath) + bcolors.ENDC)
1181 with openf(filepath,
"r+")
as f:
1185 for func
in functions:
1189 for const
in constants:
1190 txt = re.sub(
r"\b{0}\b".format(re.escape(const)), constants[const], txt)
1193 for s_const
in s_constants:
1194 txt = txt.replace(s_const, s_constants[s_const])
1201 all_files = list(matched_files_iter(output_directory, includes=includes,
1202 ignores=ignores, extensions=extensions,
1203 out_of_place_only=out_of_place_only))
1209 show_detailed=show_detailed,
1210 show_progress=show_progress,
1211 hip_clang_launch=hip_clang_launch)
def find_bracket_group(input_string, start)
def replace_extern_shared(input_string)
def preprocess(output_directory, all_files, show_detailed=False, show_progress=True, hip_clang_launch=False)
def preprocessor(output_directory, filepath, stats, hip_clang_launch)
def processKernelLaunches(string, stats)
def disable_function(input_string, function, replace_style)
def find_parentheses_group(input_string, start)
def disable_module(input_file)
def _pattern(self, pData)
def hip_header_magic(input_string)
def fix_static_global_kernels(in_txt)
def get_hip_file_path(filepath)
def find_closure_group(input_string, start, group)
def extract_arguments(start, string)
def disable_asserts(input_string)
def replace_math_functions(input_string)
def disable_unsupported_function_call(function, input_string, replacement)
def add_dim3(kernel_string, cuda_kernel)