Caffe2 - Python API
A deep learning, cross-platform ML framework
nvtx.py
1 import os
2 import glob
3 import ctypes
4 import platform
5 
# Cached handle to the nvToolsExt library. Starts out as None and is
# populated lazily by _libnvToolsExt() the first time it is called.
lib = None

# Names exported by `from <module> import *`.
__all__ = ['range_push', 'range_pop', 'mark']
9 
10 
def windows_nvToolsExt_lib():
    """
    Loads the nvToolsExt DLL on Windows.

    The DLL is located on disk with windows_nvToolsExt_path(). It is first
    loaded by its bare name (no directory, no extension), which resolves
    through the DLL search path. If that lookup fails, the DLL is loaded
    from the absolute path that was found on disk, so the library is still
    usable when its install directory is not on the search path.

    Returns:
        The loaded ctypes library object, or None if no nvToolsExt DLL was
        found on disk.

    Raises:
        OSError: if a DLL was found but could not be loaded either way.
    """
    dll_path = windows_nvToolsExt_path()
    if not dll_path:
        return None

    dll_name = os.path.splitext(os.path.basename(dll_path))[0]
    try:
        return ctypes.cdll.LoadLibrary(dll_name)
    except OSError:
        # The bare name could not be resolved through the search path, but
        # the file exists on disk; load it from its location. abspath()
        # also normalizes the mixed '/' and '\\' separators glob can return.
        return ctypes.cdll.LoadLibrary(os.path.abspath(dll_path))
19 
20 
def windows_nvToolsExt_path():
    """
    Locates an nvToolsExt DLL on disk.

    The install root is taken from the NVTOOLSEXT_PATH environment variable
    and defaults to 'C:/Program Files/NVIDIA Corporation/NvToolsExt'. DLLs
    matching 'nvToolsExt*.dll' are searched for under '<root>/bin/x64'.

    Returns:
        The first matching path reported by glob, or an empty string when
        the install root does not exist or holds no matching DLL.
    """
    default_root = 'C:/Program Files/NVIDIA Corporation/NvToolsExt'
    install_root = os.getenv('NVTOOLSEXT_PATH', default_root)

    # Nothing to search if the install root itself is missing.
    if not os.path.exists(install_root):
        return ''

    candidates = glob.glob(install_root + '/bin/x64/nvToolsExt*.dll')
    return candidates[0] if candidates else ''
30 
31 
def _libnvToolsExt():
    """
    Returns the nvToolsExt library handle, loading it on first use.

    The handle is cached in the module-level ``lib``. On non-Windows
    platforms it is obtained with ``ctypes.cdll.LoadLibrary(None)``; on
    Windows it comes from windows_nvToolsExt_lib().

    Returns:
        The ctypes library object, or None when windows_nvToolsExt_lib()
        found no DLL. Callers turn None into a RuntimeError.
    """
    global lib
    if lib is None:
        if platform.system() != 'Windows':
            lib = ctypes.cdll.LoadLibrary(None)
        else:
            lib = windows_nvToolsExt_lib()
        # windows_nvToolsExt_lib() may return None. Only configure the
        # handle when one exists, so callers get None (and raise their
        # RuntimeError) instead of an AttributeError on NoneType here.
        if lib is not None:
            # No return value is expected from nvtxMarkA; stop ctypes from
            # converting one (the default restype is int).
            lib.nvtxMarkA.restype = None
    return lib
41 
42 
def range_push(msg):
    """
    Starts a new range and pushes it onto the stack of nested ranges.
    Returns the zero-based depth of the range that is started.

    Arguments:
        msg (string): ASCII message to associate with range

    Raises:
        RuntimeError: if the nvToolsExt library could not be loaded.
    """
    handle = _libnvToolsExt()
    if handle is None:
        raise RuntimeError('Unable to load nvToolsExt library')
    # The C API takes a char*; encode() rejects non-ASCII text.
    encoded = msg.encode("ascii")
    return handle.nvtxRangePushA(ctypes.c_char_p(encoded))
54 
55 
def range_pop():
    """
    Ends the innermost range by popping it off the stack of nested ranges.
    Returns the zero-based depth of the range that is ended.

    Raises:
        RuntimeError: if the nvToolsExt library could not be loaded.
    """
    handle = _libnvToolsExt()
    if handle is None:
        raise RuntimeError('Unable to load nvToolsExt library')
    return handle.nvtxRangePop()
64 
65 
def mark(msg):
    """
    Records an instantaneous event that occurred at some point.

    Arguments:
        msg (string): ASCII message to associate with the event.

    Raises:
        RuntimeError: if the nvToolsExt library could not be loaded.
    """
    handle = _libnvToolsExt()
    if handle is None:
        raise RuntimeError('Unable to load nvToolsExt library')
    # The C API takes a char*; encode() rejects non-ASCII text.
    encoded = msg.encode("ascii")
    return handle.nvtxMarkA(ctypes.c_char_p(encoded))