Caffe2 - Python API
A deep learning, cross platform ML framework
hipify_python.py
1 #!/usr/bin/env python
2 """ The Python Hipify script.
3 ##
4 # Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
5 # 2017-2018 Advanced Micro Devices, Inc. and
6 # Facebook Inc. All rights reserved.
7 #
8 # Permission is hereby granted, free of charge, to any person obtaining a copy
9 # of this software and associated documentation files (the "Software"), to deal
10 # in the Software without restriction, including without limitation the rights
11 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 # copies of the Software, and to permit persons to whom the Software is
13 # furnished to do so, subject to the following conditions:
14 #
15 # The above copyright notice and this permission notice shall be included in
16 # all copies or substantial portions of the Software.
17 #
18 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 # THE SOFTWARE.
25 """
26 
27 from __future__ import absolute_import, division, print_function
28 import argparse
29 import fnmatch
30 import re
31 import shutil
32 import sys
33 import os
34 import json
35 import subprocess
36 
37 from enum import Enum
38 from pyHIPIFY import constants
39 from pyHIPIFY.cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
40 from pyHIPIFY.cuda_to_hip_mappings import MATH_TRANSPILATIONS
41 
# Hard-coded template maps.
# Each dictionary maps a kernel template type name to the actual type name
# that stands in for it.
PYTORCH_TEMPLATE_MAP = dict(Dtype="scalar_t", T="scalar_t")
CAFFE2_TEMPLATE_MAP = dict()
47 
48 
class InputError(Exception):
    """Exception raised for errors in the input.

    The message is kept on ``self.message`` and rendered by ``str()`` with an
    ``Input error: `` prefix.
    """

    def __init__(self, message):
        super(InputError, self).__init__(message)
        self.message = message

    def __str__(self):
        prefix = "Input error"
        return "{0}: {1}".format(prefix, self.message)
58 
59 
def openf(filename, mode):
    """Open *filename* with *mode* and return the file object.

    On Python 3 the file is opened with ``errors='ignore'``; on any other
    major version the plain two-argument ``open`` is used.
    """
    extra_kwargs = {}
    if sys.version_info[0] == 3:
        extra_kwargs["errors"] = 'ignore'
    return open(filename, mode, **extra_kwargs)
65 
66 
67 # Color coding for printing
class bcolors:
    # ANSI escape sequences used to colour terminal output; ENDC resets.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
77 
78 
class disablefuncmode(Enum):
    """Strategies for disabling a function in the transpiled source.

    Every example below starts FROM the same definition:

        ```ret_type function(arg_type1 arg1, ..., ){
            ...
        }```

    REMOVE - Remove the function entirely (signature included).
        TO:
        ```
        ```

    STUB - Keep the signature and return an empty object based off the type.
        TO:
        ```ret_type function(arg_type1 arg1, ..., ){
            ret_type obj;
            return obj;
        }```

    HCC_MACRO - Wrap the function in !defined(__HIP_PLATFORM_HCC__).
        This macro is defined by HIP if the compiler used is hcc.
        TO:
        ```#if !defined(__HIP_PLATFORM_HCC__)
        ret_type function(arg_type1 arg1, ..., ){
            ...
        }
        #endif
        ```

    DEVICE_MACRO - Wrap the function in !defined(__HIP_DEVICE_COMPILE__).
        This macro is defined by HIP if either hcc or nvcc are used in the
        device path.
        TO:
        ```#if !defined(__HIP_DEVICE_COMPILE__)
        ret_type function(arg_type1 arg1, ..., ){
            ...
        }
        #endif
        ```

    EXCEPTION - Stub the function and throw an exception at runtime.
        TO:
        ```ret_type function(arg_type1 arg1, ..., ){
            throw std::runtime_error("The function function is not implemented.")
        }```

    ASSERT - Stub the function and throw an assert(0).
        TO:
        ```ret_type function(arg_type1 arg1, ..., ){
            assert(0);
        }```

    EMPTYBODY - Stub the function and keep an empty body.
        TO:
        ```ret_type function(arg_type1 arg1, ..., ){
            ;
        }```
    """
    REMOVE = 0
    STUB = 1
    HCC_MACRO = 2
    DEVICE_MACRO = 3
    EXCEPTION = 4
    ASSERT = 5
    EMPTYBODY = 6
206 
207 
def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), out_of_place_only=False):
    """Yield the paths (relative to *root_path*) of the files to process.

    A file is yielded when all of the following hold:
      * its relative path matches at least one ``includes`` glob,
      * it matches none of the ``ignores`` globs,
      * it ends with one of ``extensions``, or its relative path is spelled
        out verbatim in ``includes``,
      * ``is_pytorch_file`` or ``is_caffe2_gpu_file`` accepts it,
      * when ``out_of_place_only`` is set, ``is_out_of_place`` accepts it.
    """
    verbatim_includes = set(includes)

    # Rough heuristic: rather than asking source control which files are
    # tracked, just avoid descending into the biggest top-level time sinks.
    skipped_root_dirs = (".git", "build", "third_party")

    def matches_any(path, patterns):
        for pattern in patterns:
            if fnmatch.fnmatch(path, pattern):
                return True
        return False

    def has_listed_extension(path):
        for suffix in extensions:
            if path.endswith(suffix):
                return True
        return False

    for abs_dirpath, dirs, filenames in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # Pruning `dirs` in place stops os.walk from visiting them.
            for dirname in skipped_root_dirs:
                if dirname in dirs:
                    dirs.remove(dirname)

        for filename in filenames:
            filepath = os.path.join(rel_dirpath, filename)
            if not matches_any(filepath, includes):
                continue
            if matches_any(filepath, ignores):
                continue
            # Extensions are respected UNLESS the whole path was written
            # verbatim in `includes`, in which case it is always accepted.
            if not (has_listed_extension(filepath) or filepath in verbatim_includes):
                continue
            if not (is_pytorch_file(filepath) or is_caffe2_gpu_file(filepath)):
                continue
            if out_of_place_only and not is_out_of_place(filepath):
                continue
            yield filepath
243 
244 
def preprocess(output_directory, all_files, show_detailed=False, show_progress=True, hip_clang_launch=False):
    """Run the preprocessor over every selected file.

    Arguments:
        output_directory - Passed through to ``preprocessor`` for each file.
        all_files - Iterable of file paths to process.
        show_detailed - Show a detailed summary of the transpilation process.
        show_progress - Print "<source> -> <hipified path>" for each file.
        hip_clang_launch - Passed through to ``preprocessor``.
    """
    # Statistics accumulated across all files.
    stats = {"unsupported_calls": [], "kernel_launches": []}

    for source_path in all_files:
        preprocessor(output_directory, source_path, stats, hip_clang_launch)
        if not show_progress:
            continue
        # Show what happened.
        hipified_path = get_hip_file_path(source_path)
        print(source_path, "->", hipified_path)

    success_banner = bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC
    print(success_banner, file=sys.stderr)

    # Show the detailed summary only on request.
    if show_detailed:
        compute_stats(stats)
274 
275 
def compute_stats(stats):
    """Print a summary of the collected preprocessing statistics.

    Arguments:
        stats - Dictionary with an "unsupported_calls" list of
            (cuda_call, filepath) pairs and a "kernel_launches" list.

    The distinct unsupported call names are printed in sorted order so the
    report is identical from run to run (plain set iteration order is not).
    """
    unsupported_calls = sorted({cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]})

    # Print the number of unsupported calls
    print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))

    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))
287 
288 
def add_dim3(kernel_string, cuda_kernel):
    """Wrap the first two kernel launch parameters in ``dim3(...)``.

    Arguments:
        kernel_string - The launch configuration, e.g. "<<<grid, block, 0, s>>>".
        cuda_kernel - The text containing that launch configuration.

    Returns:
        ``cuda_kernel`` with the first and second launch parameters replaced
        by ``dim3(<param>)``.
    """
    count = 0
    closure = 0
    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
    arg_locs = [{} for _ in range(2)]
    arg_locs[count]['start'] = 0
    last_index = len(kernel_string) - 1
    for ind, c in enumerate(kernel_string):
        if count > 1:
            break
        if c == "(":
            closure += 1
        elif c == ")":
            closure -= 1
        # This must be a separate `if`: when the final character is the ")"
        # closing the second parameter (e.g. "grid, dim3(32, 8)") the end of
        # that parameter still has to be recorded.
        if (c == "," or ind == last_index) and closure == 0:
            # A comma is excluded from the argument; a last character is kept.
            arg_locs[count]['end'] = ind + (c != ",")
            count += 1
            if count < 2:
                arg_locs[count]['start'] = ind + 1

    # The raw first argument keeps its trailing comma so that
    # raw-first + raw-second is a contiguous slice of the launch parameters.
    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]

    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

    first_arg_dim3 = "dim3({})".format(first_arg_clean)
    second_arg_dim3 = "dim3({})".format(second_arg_clean)

    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
    return cuda_kernel
322 
323 
RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')


def processKernelLaunches(string, stats):
    """Replace the CUDA style kernel launches with the HIP style kernel launches.

    Arguments:
        string - The source text to rewrite.
        stats - Statistics dictionary; every generated launch is appended to
            ``stats["kernel_launches"]``.

    Returns:
        The source text with each ``name<<<...>>>(`` rewritten into a
        ``hipLaunchKernelGGL(`` call prefix.

    Raises:
        InputError - if a "<<<" has no matching ">>>" after it.
    """
    # Concat the namespace with the kernel names: a namespace qualifier
    # followed by a backslash line continuation is joined back together.
    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)

    def grab_method_and_template(in_kernel):
        """Scan backwards from the launch to locate the kernel name and template."""
        # The positions for relevant kernel components.
        pos = {
            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
            "kernel_name": {"start": -1, "end": -1},
            "template": {"start": -1, "end": -1}
        }

        # Count for balancing template
        count = {"<>": 0}

        # Status for whether we are parsing a certain item.
        START = 0
        AT_TEMPLATE = 1
        AFTER_TEMPLATE = 2
        AT_KERNEL_NAME = 3

        status = START

        # Parse the string character by character, right to left, starting
        # just before the "<<<".
        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
            char = string[i]

            # Handle Templating Arguments
            if status == START or status == AT_TEMPLATE:
                if char == ">":
                    if status == START:
                        status = AT_TEMPLATE
                        pos["template"]["end"] = i
                    count["<>"] += 1

                if char == "<":
                    count["<>"] -= 1
                    if count["<>"] == 0 and (status == AT_TEMPLATE):
                        pos["template"]["start"] = i
                        status = AFTER_TEMPLATE

            # Handle Kernel Name
            if status != AT_TEMPLATE:
                if char.isalnum() or char in {'(', ')', '_', ':', '#'}:
                    if status != AT_KERNEL_NAME:
                        status = AT_KERNEL_NAME
                        pos["kernel_name"]["end"] = i

                    # Case: Kernel name starts the string.
                    if i == 0:
                        pos["kernel_name"]["start"] = 0

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

                else:
                    # Potential ending point if we're already traversing a kernel's name.
                    if status == AT_KERNEL_NAME:
                        pos["kernel_name"]["start"] = i

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

    def find_kernel_bounds(string):
        """Finds the starting and ending points for all kernel launches in the string."""
        kernel_end = 0
        kernel_positions = []

        # Continue until we cannot find any more kernels.
        while True:
            # Get kernel starting position (starting from the previous ending point)
            kernel_start = string.find("<<<", kernel_end)
            if kernel_start == -1:
                break

            # str.find returns -1 on failure, so the result has to be tested
            # BEFORE adding the length of ">>>"; otherwise a missing
            # terminator goes unnoticed and the search can restart on the
            # same "<<<" forever.
            closing = string.find(">>>", kernel_start)
            if closing == -1:
                raise InputError("no kernel end found")

            # Adjust end point past the >>>
            kernel_end = closing + 3

            # Add to list of traversed kernels
            kernel_positions.append({"start": kernel_start, "end": kernel_end,
                                     "group": string[kernel_start: kernel_end]})

        return kernel_positions

    # Grab positional ranges of all kernel launches
    kernel_positions = find_kernel_bounds(string)
    output_string = string

    # Replace each CUDA kernel with a HIP kernel.
    for kernel in kernel_positions:
        # Get kernel components
        params = grab_method_and_template(kernel)

        # Find parenthesis after kernel launch
        parenthesis = string.find("(", kernel["end"])

        # Extract cuda kernel: from the start of the kernel name up to and
        # including the "(" that opens the argument list.
        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
        kernel_string = string[kernel['start']:kernel['end']]
        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))

        # Drop the trailing "(", pad the launch params with ", 0" up to four,
        # then turn the "<<<" / ">>>" delimiters into argument separators.
        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(">>>", ", ")

        # Replace cuda kernel with hip kernel
        output_string = output_string.replace(cuda_kernel, hip_kernel)

        # Update the statistics
        stats["kernel_launches"].append(hip_kernel)

    return output_string
441 
442 
def find_closure_group(input_string, start, group):
    """Locate the first balanced pair of delimiters at or after ``start``.

    ``group`` holds the opening delimiter at index 0 and the closing one at
    index 1, e.g. ["(", ")"] for parentheses or ["{", "}"] for brackets.

    Returns a tuple with the position of the first opening delimiter and the
    position of the closing delimiter that balances it, or (None, None) when
    no balanced group is found.

    Example:
        find_closure_group("(hi)", 0, ["(", ")"])

    Returns:
        0, 3
    """
    opener = group[0]
    closer = group[1]

    depth = 0
    opened = False
    open_pos = -1

    for cursor in range(start, len(input_string)):
        current = input_string[cursor]
        if current == opener:
            if not opened:
                # First opener seen: remember where the group begins.
                opened = True
                open_pos = cursor
            depth += 1
        elif opened and current == closer:
            # Closers seen before any opener are ignored.
            depth -= 1
            if depth == 0:
                return open_pos, cursor

    return None, None
481 
482 
def find_bracket_group(input_string, start):
    """Finds the first balanced "{" ... "}" group at or after ``start``.

    Returns the (open, close) positions reported by ``find_closure_group``.
    """
    return find_closure_group(input_string, start, group=["{", "}"])
486 
487 
def find_parentheses_group(input_string, start):
    """Finds the first balanced "(" ... ")" group at or after ``start``.

    Returns the (open, close) positions reported by ``find_closure_group``.
    """
    return find_closure_group(input_string, start, group=["(", ")"])
491 
492 
RE_ASSERT = re.compile(r"\bassert[ ]*\(")


def disable_asserts(input_string):
    """Remove regular assert calls from the source.

    Every ``assert(...)`` call (the name plus its balanced argument list) is
    deleted, e.g. "assert(....);" -> ";".

    An ``assert(`` whose parentheses never balance is left untouched instead
    of raising.
    """
    output_string = input_string
    for assert_item in RE_ASSERT.finditer(input_string):
        # The match ends just past the "(", so step back onto it.
        _p_start, p_end = find_parentheses_group(input_string, assert_item.end() - 1)
        if p_end is None:
            # No balanced argument list: nothing safe to remove.
            continue
        start = assert_item.start()
        output_string = output_string.replace(input_string[start:p_end + 1], "")
    return output_string
507 
508 
def replace_math_functions(input_string):
    """Rewrite math function calls according to MATH_TRANSPILATIONS.

    FIXME: Temporarily replace std:: invocations of math functions with
    non-std:: versions to prevent linker errors.
    NOTE: This can lead to correctness issues when running tests, since the
    correct version of the math function (exp/expf) might not get called.
    Plan is to remove this function once HIP supports std:: math function
    calls inside device code.
    """
    rewritten = input_string
    for original_name in MATH_TRANSPILATIONS:
        # Only text of the form "name(" is touched.
        call_prefix = r'{}('.format(original_name)
        new_prefix = '{}('.format(MATH_TRANSPILATIONS[original_name])
        rewritten = rewritten.replace(call_prefix, new_prefix)
    return rewritten
519 
520 
RE_SYNCTHREADS = re.compile(r"[:]?[:]?\b(__syncthreads)\b(\w*\()")


def hip_header_magic(input_string):
    """Prepend the HIP runtime header when the text looks like device code.

    If the file makes kernel builtin calls and does not already include a
    HIP runtime header, an #include is added to match the "magic" includes
    provided by NVCC.

    TODO:
        Update logic to ignore cases where the header is included by another file.
    """
    # Nothing to do when one of these headers is already included.
    for header in ("hip/hip_runtime.h", "hip/hip_runtime_api.h"):
        if re.search(r'#include ("{0}"|<{0}>)'.format(header), input_string):
            return input_string

    # Rough logic to detect if we're inside device code.
    looks_like_device_code = (
        "hipLaunchKernelGGL" in input_string
        or "__global__" in input_string
        or "__shared__" in input_string
        or RE_SYNCTHREADS.search(input_string) is not None
    )
    if not looks_like_device_code:
        return input_string

    # Device logic found: provide the necessary header.
    return '#include "hip/hip_runtime.h"\n' + input_string
550 
551 
RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")


def replace_extern_shared(input_string):
    """Rewrite `extern __shared__ type foo[];` declarations to use the HIP_DYNAMIC_SHARED() macro.

    https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, smem)"
    """
    def to_macro(match):
        # Group 1 is the optional token between `extern` and `__shared__`,
        # group 2 the element type, group 3 the variable name.
        qualifier = match.group(1) or ""
        return "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(qualifier, match.group(2), match.group(3))

    return RE_EXTERN_SHARED.sub(to_macro, input_string)
568 
569 
def disable_function(input_string, function, replace_style):
    """Finds and disables a function in a particular file.

    Arguments:
        input_string - The source text to search.
        function - Either a list or a string.
            If a list: the signature of the function to disable,
                e.g. ["bool", "overlappingIndices", "(const Tensor& t)"]
                disables function -> "bool overlappingIndices(const Tensor& t)".
                The components are stripped; the signature is searched for
                first joined without separators and, failing that, with a
                single space between the return type and the name.
            If a string: disables the function by name only,
                e.g. "overlappingIndices"; the signature is detected
                automatically.
        replace_style - The ``disablefuncmode`` member describing how to
            disable the function.

    Returns:
        The rewritten source. The input is returned unchanged when the
        function cannot be found, when its body never closes, or when
        ``replace_style`` is not a known mode.
    """
    if isinstance(function, list):
        # Extract components from function signature.
        return_type = function[0].strip()
        function_name = function[1].strip()
        function_args = function[2].strip()

        # Create function string to search for
        function_string = "{0}{1}{2}".format(return_type, function_name, function_args)
        function_start = input_string.find(function_string)

        # Real source separates the return type from the name, so fall back
        # to the spaced form shown in the example above.
        if function_start == -1 and return_type:
            spaced_string = "{0} {1}{2}".format(return_type, function_name, function_args)
            spaced_start = input_string.find(spaced_string)
            if spaced_start != -1:
                function_string = spaced_string
                function_start = spaced_start
    else:
        # Automatically detect signature.
        the_match = re.search(r"(((.*) (\*)?)({0})(\([^{{)]*\)))\s*{{".format(
            function.replace("(", r"\(").replace(")", r"\)")), input_string)
        if the_match is None:
            return input_string

        return_type = the_match.group(2).strip()

        # Find the starting position for the function
        function_start = the_match.start()
        function_string = the_match.group(1)

    # The function can't be found anymore.
    if function_start == -1:
        return input_string

    # Find the function block: scan from the last character of the signature
    # until the first "{" has been balanced by its "}".
    pos = function_start + len(function_string) - 1
    inside_function = False
    bracket_count = 0
    function_end = -1
    while pos < len(input_string):
        char = input_string[pos]
        if char == "{":
            if not inside_function:
                inside_function = True
                bracket_count = 1
            else:
                bracket_count += 1
        elif char == "}":
            bracket_count -= 1

        if bracket_count == 0 and inside_function:
            function_end = pos
            break

        pos += 1

    # Never found the function end. Corrupted file!
    if function_end == -1:
        return input_string

    # Signature plus body, exactly as it appears in the source.
    function_body = input_string[function_start:function_end + 1]

    # Remove the entire function body
    if replace_style == disablefuncmode.REMOVE:
        replacement = ""

    # Stub the function based off its return type.
    elif replace_style == disablefuncmode.STUB:
        # void return type
        if return_type == "void" or return_type == "static void":
            replacement = "{0}{{\n}}".format(function_string)
        # pointer return type
        elif "*" in return_type:
            replacement = "{0}{{\nreturn {1};\n}}".format(function_string, "NULL")  # nullptr
        else:
            replacement = "{0}{{\n{1} stub_var;\nreturn stub_var;\n}}".format(function_string, return_type)

    # Add HIP Preprocessors.
    elif replace_style == disablefuncmode.HCC_MACRO:
        replacement = "#if !defined(__HIP_PLATFORM_HCC__)\n{0}\n#endif".format(function_body)

    # Add HIP Preprocessors.
    elif replace_style == disablefuncmode.DEVICE_MACRO:
        replacement = "#if !defined(__HIP_DEVICE_COMPILE__)\n{0}\n#endif".format(function_body)

    # Throw an exception at runtime.
    elif replace_style == disablefuncmode.EXCEPTION:
        replacement = "{0}{{\n{1};\n}}".format(
            function_string,
            'throw std::runtime_error("The function {0} is not implemented.")'.format(
                function_string.replace("\n", " ")))

    elif replace_style == disablefuncmode.ASSERT:
        replacement = "{0}{{\n{1};\n}}".format(function_string, 'assert(0)')

    elif replace_style == disablefuncmode.EMPTYBODY:
        replacement = "{0}{{\n;\n}}".format(function_string)

    else:
        # Unknown mode: leave the source alone rather than fail on an
        # unassigned result.
        return input_string

    return input_string.replace(function_body, replacement)
707 
708 
def get_hip_file_path(filepath):
    """
    Returns the new name of the hipified file
    """
    # Some files are HIPified in place; is_out_of_place tells us whether
    # this one is. In-place files keep their path.
    if not is_out_of_place(filepath):
        return filepath

    orig_dirpath, filename = os.path.split(filepath)
    root, ext = os.path.splitext(filename)

    # The HIPified file needs a name that differs from the original so the
    # original is not overwritten. (Additionally, hcc historically had a bug
    # where two files with the same basename would clobber each other.)
    #
    # Naming conventions vary across PyTorch and Caffe2, so the recipe is:
    #
    #   - rewrite "cuda" -> "hip" and "THC" -> "THH" in the directory part;
    #   - rewrite "cuda" -> "hip", "CUDA" -> "HIP" and (except directly
    #     under caffe2/core) "THC" -> "THH" in the file name;
    #   - if the directory part was left unchanged, put the file into a
    #     "hip" sub-directory of its original directory;
    #   - ALWAYS turn '.cu' into '.hip', because those files contain CUDA
    #     kernels that need to be hipified and processed with the hcc
    #     compiler.
    #
    # This isn't set in stone; it may be adjusted to support other naming
    # conventions.
    if ext == '.cu':
        ext = '.hip'

    dirpath = orig_dirpath.replace('cuda', 'hip').replace('THC', 'THH')

    root = root.replace('cuda', 'hip').replace('CUDA', 'HIP')
    # Special case to handle caffe2/core/THCCachingAllocator
    if dirpath != "caffe2/core":
        root = root.replace('THC', 'THH')

    if dirpath == orig_dirpath:
        dirpath = os.path.join(dirpath, 'hip')

    return os.path.join(dirpath, root + ext)
769 
770 
def is_out_of_place(filepath):
    """Return False for paths under torch/ or tools/autograd/templates/, True otherwise."""
    in_place_prefixes = ("torch/", "tools/autograd/templates/")
    return not filepath.startswith(in_place_prefixes)
777 
778 
779 # Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(filepath):
    """Return True for paths under aten/ (except aten/src/ATen/core/), torch/ or tools/autograd/templates/."""
    if filepath.startswith("aten/"):
        # Everything below aten/ counts, apart from the ATen core sources.
        return not filepath.startswith("aten/src/ATen/core/")
    return filepath.startswith(("torch/", "tools/autograd/templates/"))
790 
791 
def is_caffe2_gpu_file(filepath):
    """Decide from the path alone whether a file is treated as Caffe2 GPU code.

    Paths starting with "c10/cuda" always qualify. Otherwise the file name
    must contain "gpu" or end in .cu/.cuh, and must not contain "cudnn".
    """
    if filepath.startswith("c10/cuda"):
        return True
    filename = os.path.basename(filepath)
    ext = os.path.splitext(filename)[1]
    if 'cudnn' in filename:
        return False
    return 'gpu' in filename or ext in ('.cu', '.cuh')
798 
799 
800 # Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
class Trie():
    """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union."""

    def __init__(self):
        # Nested dicts keyed by character; the key '' marks the end of a word.
        self.data = {}

    def add(self, word):
        """Insert ``word`` into the trie."""
        ref = self.data
        for char in word:
            ref = ref.setdefault(char, {})
        ref[''] = 1

    def dump(self):
        """Return the underlying nested-dict representation."""
        return self.data

    def quote(self, char):
        """Escape ``char`` for literal use inside a regular expression."""
        return re.escape(char)

    def _pattern(self, pData):
        """Build the regex fragment for the sub-trie ``pData``.

        Returns None when the sub-trie holds nothing but an end-of-word
        marker, i.e. there is nothing further to match.
        """
        data = pData
        if "" in data and len(data.keys()) == 1:
            return None

        alt = []  # multi-character alternatives
        cc = []   # single characters that end a word (character class)
        q = 0     # set when a word may end at this node (rest is optional)
        for char in sorted(data.keys()):
            child = data[char]
            if isinstance(child, dict):
                # Test the "nothing follows" case explicitly rather than
                # catching the error from concatenating None, so unrelated
                # failures in the recursion are not swallowed.
                recurse = self._pattern(child)
                if recurse is None:
                    cc.append(self.quote(char))
                else:
                    alt.append(self.quote(char) + recurse)
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = "(?:%s)?" % result
        return result

    def pattern(self):
        """Return a regex pattern string matching the words added so far."""
        return self._pattern(self.dump())
860 
861 
# Per-API lookup structures: a trie of source identifiers (turned into a
# regex below) and a map from each identifier to its replacement.
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP = {}
for mapping in CUDA_TO_HIP_MAPPINGS:
    for src, value in mapping.items():
        # First element is the replacement; the rest is metadata.
        dst, meta_data = value[0], value[1:]
        # An entry tagged for one API only is kept out of the other API's tables.
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            PYTORCH_MAP[src] = dst
        if constants.API_PYTORCH not in meta_data:
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst
# Only the PyTorch pattern is anchored on word boundaries.
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern()))

RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh
883 
def preprocessor(output_directory, filepath, stats, hip_clang_launch):
    """ Executes the CUDA -> HIP conversion on the specified file.

    Arguments:
        output_directory - Directory that ``filepath`` is relative to; the
            hipified file is written below it as well.
        filepath - Relative path of the file to convert.
        stats - Statistics dictionary handed to ``processKernelLaunches``.
        hip_clang_launch - When true, kernel launches are left untouched.

    The whole conversion is carried out in memory and the result is written
    only once it has succeeded. Files that are hipified in place share their
    input and output path, so opening the output earlier would truncate the
    source if any conversion step raised.
    """
    fin_path = os.path.join(output_directory, filepath)
    with open(fin_path, 'r') as fin:
        output_source = fin.read()

    # unsupported_calls statistics reporting is broken atm
    if is_pytorch_file(filepath):
        def pt_repl(m):
            return PYTORCH_MAP[m.group(0)]
        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
    else:
        def c2_repl(m):
            return CAFFE2_MAP[m.group(0)]
        output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)

    # Header rewrites: only headers under these prefixes (or THC* headers
    # other than THCP*) are pointed at their hipified path.
    hipified_prefixes = (
        "ATen/cuda",
        "ATen/native/cuda",
        "ATen/native/sparse/cuda",
        "THC/",
        "THCUNN/",
    )

    def mk_repl(templ):
        def repl(m):
            f = m.group(1)
            if f.startswith(hipified_prefixes) or (f.startswith("THC") and not f.startswith("THCP")):
                return templ.format(get_hip_file_path(f))
            return m.group(0)
        return repl
    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"'), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>'), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)

    # CMakeLists.txt rewrites
    if filepath.endswith('CMakeLists.txt'):
        output_source = output_source.replace('CUDA', 'HIP')
        output_source = output_source.replace('THC', 'THH')
        output_source = RE_CU_SUFFIX.sub('.hip', output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if filepath.endswith(".cu") or filepath.endswith(".cuh"):
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    output_source = replace_extern_shared(output_source)

    # Everything succeeded: now it is safe to create/overwrite the output.
    fout_path = os.path.join(output_directory, get_hip_file_path(filepath))
    fout_dir = os.path.dirname(fout_path)
    if not os.path.exists(fout_dir):
        os.makedirs(fout_dir)

    with open(fout_path, 'w') as fout:
        fout.write(output_source)
942 
943 
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    """Replace ``search_string`` with ``replace_string`` inside the file, in place.

    With ``strict`` set, only occurrences delimited by word boundaries are
    replaced; otherwise every occurrence of the substring is replaced.
    """
    with openf(filepath, "r+") as f:
        original = f.read()
        if not strict:
            updated = original.replace(search_string, replace_string)
        else:
            whole_word = r'\b({0})\b'.format(re.escape(search_string))
            # A callable replacement inserts replace_string verbatim
            # (no backslash/group-reference processing).
            updated = re.sub(whole_word, lambda _match: replace_string, original)
        # Rewrite from the start and drop any leftover tail of the old text.
        f.seek(0)
        f.write(updated)
        f.truncate()
954 
955 
def file_add_header(filepath, header):
    """Prepend an ``#include`` line for ``header`` to the file, in place.

    The header is wrapped in double quotes unless it starts with "<" or
    ends with ">".
    """
    with openf(filepath, "r+") as f:
        original = f.read()
        looks_bracketed = header[0] == "<" or header[-1] == ">"
        if not looks_bracketed:
            header = '"{0}"'.format(header)
        include_line = '#include {0} \n'.format(header)
        f.seek(0)
        f.write(include_line + original)
        f.truncate()
965 
966 
def fix_static_global_kernels(in_txt):
    """Static global kernels in HIP results in a compilation error.

    Returns ``in_txt`` with every literal occurrence of ``" __global__ static"``
    (including its leading space) replaced by ``"__global__"``.
    """
    return in_txt.replace(" __global__ static", "__global__")
971 
972 
def disable_unsupported_function_call(function, input_string, replacement):
    """Disables calls to an unsupported HIP function.

    Every whole-word occurrence of ``function`` is located, the text from the
    name through the parenthesis that closes the next opened parenthesis group
    is taken as the call, and that span is replaced with ``replacement``.

    Replacement is done by position on the matched spans only, so identical
    text embedded in a longer identifier (e.g. ``myfoo(1)`` when disabling
    ``foo``) is left alone and the inserted replacement text is never
    re-scanned.  An occurrence whose parentheses never balance is left
    unchanged instead of swallowing the rest of the input.

    Args:
        function: name of the function whose calls are disabled.
        input_string: source text to process.
        replacement: text substituted for each call.

    Returns:
        The processed source text.
    """
    name_pattern = re.compile(r"\b{0}\b".format(re.escape(function)))
    total_length = len(input_string)

    pieces = []
    cursor = 0  # first index of input_string not yet copied to the output

    for call in name_pattern.finditer(input_string):
        start = call.start()
        if start < cursor:
            # This occurrence sits inside a call that was already replaced.
            continue

        # Walk forward to the parenthesis that balances the first "(".
        pos = call.end()
        depth = 0
        started_arguments = False
        closing_index = -1
        while pos < total_length:
            char = input_string[pos]
            if char == "(":
                depth += 1
                started_arguments = True
            elif char == ")" and started_arguments:
                depth -= 1
                if depth == 0:
                    closing_index = pos
                    break
            pos += 1

        if closing_index < 0:
            # Unterminated call: leave the text untouched.
            continue

        pieces.append(input_string[cursor:start])
        pieces.append(replacement)
        cursor = closing_index + 1

    pieces.append(input_string[cursor:])
    return "".join(pieces)
1008 
1009 
# Matches the text "#include " followed by the rest of that line, up to and
# including the terminating newline.  An include on a final line that has no
# trailing newline is therefore not matched.
RE_INCLUDE = re.compile(r"#include .*\n")
1011 
1012 
def disable_module(input_file):
    """Disable a module entirely except for header includes.

    Everything after the last ``#include`` line matched by ``RE_INCLUDE`` is
    wrapped in ``#if !defined(__HIP_PLATFORM_HCC__) ... #endif`` and the file
    is rewritten in place.  When the file has no matching include line the
    whole file is wrapped (previously this raised ``IndexError``).
    """
    with openf(input_file, "r+") as f:
        txt = f.read()

        # Track the end offset of the last include; stays 0 when none exist,
        # which places the guard at the very top of the file.
        end = 0
        for include_match in RE_INCLUDE.finditer(txt):
            end = include_match.end()

        disabled = "{0}#if !defined(__HIP_PLATFORM_HCC__)\n{1}\n#endif".format(txt[0:end], txt[end:])

        # Rewind and overwrite the file, dropping any leftover old content.
        f.seek(0)
        f.write(disabled)
        f.truncate()
1025 
1026 
def extract_arguments(start, string):
    """ Return the list of arguments in the upcoming function parameter closure.

        Scanning begins at index ``start`` of ``string``.  Each argument is
        reported as a dict holding its ``start`` (inclusive) and ``end``
        (exclusive) offsets into ``string``.

        Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    """
    found = []
    paren_depth = 0
    angle_depth = 0
    arg_begin = start + 1

    for index in range(start, len(string)):
        char = string[index]

        # Track nesting of parentheses and angle brackets.
        if char == "(":
            paren_depth += 1
        elif char == ")":
            paren_depth -= 1
        elif char == "<":
            angle_depth += 1
        elif char == ">" and string[index - 1] != "-" and angle_depth > 0:
            # A ">" directly after "-" is the arrow operator, not a closer.
            angle_depth -= 1

        nothing_open = paren_depth == 0 and angle_depth == 0
        if nothing_open:
            # The closure is finished: record the last argument and stop.
            found.append({"start": arg_begin, "end": index})
            break

        at_top_level = paren_depth == 1 and angle_depth == 0
        if at_top_level and char == ",":
            # A comma directly inside the outer parentheses ends an argument.
            found.append({"start": arg_begin, "end": index})
            arg_begin = index + 1

    return found
1071 
1072 
def str2bool(v):
    """Convert a textual flag to a bool, for use as an ArgumentParser ``type``.

    ArgumentParser doesn't support type=bool, so this helper maps the accepted
    spellings (compared case-insensitively) to True / False and raises
    argparse.ArgumentTypeError for anything else.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
1082 
1083 
def hipify(
    project_directory,
    show_detailed=False,
    extensions=(".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    output_directory="",
    includes=(),
    json_settings="",
    out_of_place_only=False,
    ignores=(),
    show_progress=True,
    hip_clang_launch=False,
):
    """Copy a project to an output directory and run the hipify passes on it.

    Steps, in order:
      1. Resolve ``project_directory`` (empty string means the current working
         directory); exit with status 1 if it does not exist.
      2. Default ``output_directory`` to ``<project_directory>_amd`` and copy
         the project there unless that directory already exists.
      3. If ``json_settings`` names a file, apply its "disabled_functions",
         "disabled_modules" and "disable_unsupported_hip_calls" sections to
         files under the output directory.
      4. Collect files with ``matched_files_iter`` and hand them to
         ``preprocess``.

    Args:
        project_directory: directory to convert; "" selects os.getcwd().
        show_detailed: forwarded to ``preprocess``.
        extensions: forwarded to ``matched_files_iter``.
        output_directory: destination directory; "" selects the default.
        includes: forwarded to ``matched_files_iter``.
        json_settings: path of the JSON settings file, or "" for none.
        out_of_place_only: forwarded to ``matched_files_iter``.
        ignores: forwarded to ``matched_files_iter``.
        show_progress: forwarded to ``preprocess``.
        hip_clang_launch: forwarded to ``preprocess``.
    """
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")
        sys.exit(1)

    # If no output directory, provide a default one.
    if not output_directory:
        # str.rstrip returns a new string; the result used to be discarded, so
        # "proj/" produced "proj/_amd" (inside the project) instead of
        # "proj_amd".  Fall back to the original value if stripping leaves
        # nothing (e.g. the path "/").
        base_directory = project_directory.rstrip("/") or project_directory
        output_directory = base_directory + "_amd"

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    # Open JSON file with disable information.
    if json_settings != "":
        with openf(json_settings, "r") as f:
            json_data = json.load(f)

        # Disable functions in certain files according to JSON description
        for disable_info in json_data["disabled_functions"]:
            filepath = os.path.join(output_directory, disable_info["path"])
            functions = disable_info.get("functions", [])
            non_hip_functions = disable_info.get("non_hip_functions", [])
            not_on_device_functions = disable_info.get("non_device_functions", [])

            with openf(filepath, "r+") as f:
                txt = f.read()
                for func in functions:
                    # TODO - Find fix assertions in HIP for device code.
                    txt = disable_function(txt, func, disablefuncmode.ASSERT)

                for func in non_hip_functions:
                    # Disable this function on HIP stack
                    txt = disable_function(txt, func, disablefuncmode.HCC_MACRO)

                for func in not_on_device_functions:
                    # Disable this function when compiling on Device
                    txt = disable_function(txt, func, disablefuncmode.DEVICE_MACRO)

                f.seek(0)
                f.write(txt)
                f.truncate()

        # Disable modules
        for module in json_data["disabled_modules"]:
            disable_module(os.path.join(output_directory, module))

        # Disable unsupported HIP functions
        for disable in json_data["disable_unsupported_hip_calls"]:
            filepath = os.path.join(output_directory, disable["path"])
            functions = disable.get("functions", [])
            # Named to avoid shadowing the module-level `constants` import.
            bounded_constants = disable.get("constants", [])
            s_constants = disable.get("s_constants", [])

            if not os.path.exists(filepath):
                print("\n" + bcolors.WARNING + "JSON Warning: File {0} does not exist.".format(filepath) + bcolors.ENDC)
                continue

            with openf(filepath, "r+") as f:
                txt = f.read()

                # Disable HIP Functions
                for func in functions:
                    txt = disable_unsupported_function_call(func, txt, functions[func])

                # Disable Constants w\ Boundary.
                for const in bounded_constants:
                    txt = re.sub(r"\b{0}\b".format(re.escape(const)), bounded_constants[const], txt)

                # Disable Constants
                for s_const in s_constants:
                    txt = txt.replace(s_const, s_constants[s_const])

                # Save Changes
                f.seek(0)
                f.write(txt)
                f.truncate()

    all_files = list(matched_files_iter(output_directory, includes=includes,
                                        ignores=ignores, extensions=extensions,
                                        out_of_place_only=out_of_place_only))

    # Start Preprocessor
    preprocess(
        output_directory,
        all_files,
        show_detailed=show_detailed,
        show_progress=show_progress,
        hip_clang_launch=hip_clang_launch)
def find_bracket_group(input_string, start)
def replace_extern_shared(input_string)
def preprocess(output_directory, all_files, show_detailed=False, show_progress=True, hip_clang_launch=False)
def preprocessor(output_directory, filepath, stats, hip_clang_launch)
def processKernelLaunches(string, stats)
def disable_function(input_string, function, replace_style)
def find_parentheses_group(input_string, start)
def disable_module(input_file)
def hip_header_magic(input_string)
def fix_static_global_kernels(in_txt)
def get_hip_file_path(filepath)
def find_closure_group(input_string, start, group)
def extract_arguments(start, string)
def disable_asserts(input_string)
def replace_math_functions(input_string)
def disable_unsupported_function_call(function, input_string, replacement)
def add_dim3(kernel_string, cuda_kernel)