import itertools
import torch

from collections import defaultdict


class range(object):
    """Context manager that annotates a block of code with a profiler range."""
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        torch.autograd._push_range(self.name)

    def __exit__(self, *args):
        torch.autograd._pop_range()
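
# A minimal usage sketch (`model` and `x` are hypothetical): wrapping a region
# in `range` makes it appear as a named entry while the profiler is enabled.
#
#   with torch.autograd.profiler.profile() as prof:
#       with torch.autograd.profiler.range('forward'):
#           y = model(x)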


class EventList(list):
    """A list of Events (for pretty printing)"""
    def __init__(self, *args, **kwargs):
        super(EventList, self).__init__(*args, **kwargs)

    def __str__(self):
        return self.table()

    def table(self, sort_by=None):
        """Prints an EventList as a nicely formatted table.

        Arguments:
            sort_by (str, optional): Attribute used to sort entries. By default
                they are printed in the same order as they were registered.
                Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
                ``cuda_time_total``, ``count``.

        Returns:
            A string containing the table.
        """
        return build_table(self, sort_by)
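
    # Example (a sketch; `events` stands for any populated EventList):
    #
    #   print(events.table(sort_by='cpu_time_total'))
    #
    # `sorted` is used internally, so entries come out in ascending order of
    # the chosen attribute.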

    def export_chrome_trace(self, path):
        """Exports an EventList as a Chrome tracing tools file.

        The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.

        Arguments:
            path (str): Path where the trace will be written.
        """
        import json
        with open(path, 'w') as f:
            chrome_events = []
            next_id = 0
            for evt in self:
                # One complete ('X') event per CPU-side function call.
                chrome_events.append(dict(
                    name=evt.name, ph='X',
                    ts=evt.cpu_interval.start,
                    dur=evt.cpu_interval.elapsed_us(),
                    tid=evt.thread, pid='CPU functions', args={}))
                for k in evt.kernels:
                    # 's'/'f' pairs draw a flow arrow from the CPU launch to
                    # the CUDA kernel; the final 'X' records the kernel itself.
                    chrome_events.append(dict(
                        name=evt.name, ph='s',
                        ts=evt.cpu_interval.start,
                        tid=evt.thread, pid='CPU functions',
                        id=next_id, cat='cpu_to_cuda', args={}))
                    chrome_events.append(dict(
                        name=k.name, ph='f', ts=k.interval.start,
                        tid=k.device, pid='CUDA functions',
                        id=next_id, cat='cpu_to_cuda', args={}))
                    chrome_events.append(dict(
                        name=k.name, ph='X',
                        ts=k.interval.start,
                        dur=k.interval.elapsed_us(),
                        tid=k.device, pid='CUDA functions', args={}))
                    next_id += 1

            json.dump(chrome_events, f)
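
    # Example (a sketch; `model`, `x` and the output path are placeholders):
    #
    #   with torch.autograd.profiler.profile() as prof:
    #       y = model(x)
    #   prof.export_chrome_trace('trace.json')
    #   # ...then open chrome://tracing and load trace.json.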

    def key_averages(self):
        """Averages all function events over their keys.

        Returns:
            An EventList containing FunctionEventAvg objects.
        """
        stats = defaultdict(FunctionEventAvg)
        for evt in self:
            stats[evt.key] += evt
        return EventList(stats.values())
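
    # Example (a sketch): collapse repeated invocations of each op into a
    # single averaged row, then print the result:
    #
    #   averages = events.key_averages()
    #   print(averages.table(sort_by='count'))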

    def total_average(self):
        """Averages all events.

        Returns:
            A FunctionEventAvg object.
        """
        total_stat = FunctionEventAvg()
        for evt in self:
            total_stat += evt
            total_stat.key = None
        total_stat.key = 'Total'
        return total_stat


class profile(object):
    """Context manager that manages autograd profiler state and holds a summary of results.

    Arguments:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Example:
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof)
        -------------------------------------  ---------------  ---------------
        Name                                          CPU time        CUDA time
        -------------------------------------  ---------------  ---------------
        PowConstant                                  142.036us          0.000us
        N5torch8autograd9GraphRootE                   63.524us          0.000us
        PowConstantBackward                          184.228us          0.000us
        MulConstant                                   50.288us          0.000us
        PowConstant                                   28.439us          0.000us
        N5torch8autograd14AccumulateGradE             13.790us          0.000us
        N5torch8autograd5CloneE                        4.088us          0.000us
    """
    def __init__(self, enabled=True, use_cuda=False):
        self.enabled = enabled
        self.use_cuda = use_cuda
        self.function_events = None
        if not self.enabled:
            return
        self.entered = False

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("autograd profiler traces are not reentrant")
        self.entered = True
        profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
            else torch.autograd.ProfilerState.CPU
        torch.autograd._enable_profiler(profiler_kind)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        records = torch.autograd._disable_profiler()
        self.function_events = EventList(parse_cpu_trace(records))
        return False

    def __repr__(self):
        if self.function_events is None:
            return '<unfinished torch.autograd.profile>'
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return '<unfinished torch.autograd.profile>'
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("can't export a trace that didn't finish running")

    def table(self, sort_by=None):
        self._check_finish()
        return self.function_events.table(sort_by)
    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        return self.function_events.export_chrome_trace(path)
    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def key_averages(self):
        self._check_finish()
        return self.function_events.key_averages()
    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        return self.function_events.total_average()
    total_average.__doc__ = EventList.total_average.__doc__
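
# Sketch: because of the `enabled` flag, profiling code can stay in place and
# be switched on only when needed (`args.profile` and `train_step` are
# hypothetical; note that `prof` is None when profiling is disabled):
#
#   with torch.autograd.profiler.profile(enabled=args.profile) as prof:
#       train_step()
#   if args.profile:
#       print(prof.table())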


class emit_nvtx(object):
    """Context manager that makes every autograd operation emit an NVTX range.

    It is useful when running the program under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately, there's no way to force nvprof to flush the data it collected
    to disk, so for CUDA profiling one has to use this context manager to annotate
    nvprof traces and wait for the process to exit before inspecting them.
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection,
    e.g. in a Python REPL.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Arguments:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.

    Example:
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)

    **Forward-backward correlation**

    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
    generates.

    During the forward pass, each function range is decorated with ``seq=<N>``.  ``seq`` is a running
    counter, incremented each time a new backward Function object is created and stashed for backward.
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
    if a backward Function object is created by this forward function,
    the backward object will receive sequence number N.
    During the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``.  ``M`` is the sequence number that
    the backward object was created with.  By comparing ``stashed seq`` numbers in backward with ``seq``
    numbers in forward, you can track down which forward op created each backward Function.

    Any functions executed during the backward pass are also decorated with ``seq=<N>``.  During
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
    ``N`` may simply be 0 for all such functions.  Only the top-level ranges associated with
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
    objects with the earlier forward pass.

    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
    if you are setting up for a double-backward), each function's execution during backward
    is given a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
    to be executed later during double-backward, just as the original functions in the forward pass did.
    The relationship between backward and double-backward is conceptually the same as the relationship
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
    the Function objects they create still stash those sequence numbers, and during the eventual
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
    numbers, which can be compared to ``seq`` numbers from the backward pass.

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object.  You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
    """
    def __init__(self, enabled=True):
        self.enabled = enabled
        self.entered = False

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
        torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        torch.cuda.synchronize()
        torch.autograd._disable_profiler()
        return False


def load_nvprof(path):
    """Opens an nvprof trace file and parses autograd annotations.

    Arguments:
        path (str): path to nvprof trace
    """
    return EventList(parse_nvprof_trace(path))
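
# Sketch of the full nvprof round trip (file names are placeholders):
#
#   # shell: nvprof --profile-from-start off -o trace.prof -- python script.py
#   # (script.py wraps the profiled region in emit_nvtx, as shown above)
#   events = load_nvprof('trace.prof')
#   print(events.table())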


def format_time(time_us):
    """Defines how to format time in FunctionEvent"""
    return '{:.3f}us'.format(time_us)


def attr_formatter(name):
    # Returns a read-only property that renders the named time attribute
    # (in microseconds) as a string.
    return property(lambda self: format_time(getattr(self, name)))


class FormattedTimesMixin(object):
    """Helpers for FunctionEvent and FunctionEventAvg.

    The subclass should define `*_time_total` and `count` attributes.
    """
    cpu_time_str = attr_formatter('cpu_time')
    cuda_time_str = attr_formatter('cuda_time')
    cpu_time_total_str = attr_formatter('cpu_time_total')
    cuda_time_total_str = attr_formatter('cuda_time_total')

    @property
    def cpu_time(self):
        return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count

    @property
    def cuda_time(self):
        return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count
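
    # For instance, an entry seen 4 times with cpu_time_total == 40.0 (us)
    # has cpu_time == 10.0, which cpu_time_str renders as '10.000us'.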


class Interval(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def elapsed_us(self):
        return self.end - self.start


class Kernel(object):
    def __init__(self, name, device, interval):
        self.name = name
        self.device = device
        self.interval = interval


class FunctionEvent(FormattedTimesMixin):
    """Profiling information about a single function."""
    def __init__(self, id, name, thread, cpu_start, cpu_end):
        self.id = id
        self.name = name
        self.cpu_interval = Interval(cpu_start, cpu_end)
        self.thread = thread
        self.kernels = []
        self.count = 1

    def append_kernel(self, name, device, start, end):
        self.kernels.append(Kernel(name, device, Interval(start, end)))

    @property
    def cuda_time_total(self):
        return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)

    @property
    def cpu_time_total(self):
        return self.cpu_interval.elapsed_us()

    @property
    def key(self):
        return self.name

    def __repr__(self):
        return '<FunctionEvent id={} cpu_time={} cuda_time={} name={} thread={}>'.format(
            self.id, self.cpu_time_str, self.cuda_time_str, self.name, self.thread)
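
# Sketch: assembling an event by hand (all times in microseconds):
#
#   evt = FunctionEvent(id=0, name='MyOp', thread=0, cpu_start=0, cpu_end=10)
#   evt.append_kernel('my_kernel', 0, 2, 8)
#   evt.cpu_time_total    # -> 10
#   evt.cuda_time_total   # -> 6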


class FunctionEventAvg(FormattedTimesMixin):
    """Used to average stats over multiple FunctionEvent objects."""
    def __init__(self):
        self.key = None
        self.count = self.cpu_time_total = self.cuda_time_total = 0

    def __iadd__(self, other):
        if self.key is None:
            self.key = other.key
        assert isinstance(other, FunctionEvent)
        assert other.key == self.key
        self.cpu_time_total += other.cpu_time
        self.cuda_time_total += other.cuda_time
        self.count += 1
        return self

    def __repr__(self):
        return '<FunctionEventAvg cpu_time={} cuda_time={} key={}>'.format(
            self.cpu_time_str, self.cuda_time_str, self.key)


class StringTable(defaultdict):
    def __missing__(self, key):
        # Demangle C++ symbol names the first time they are seen and cache
        # the result.
        self[key] = torch._C._demangle(key)
        return self[key]


def parse_cpu_trace(thread_records):
    next_id = 0
    start_record = None
    cuda_records = {}
    functions = []
    record_stack = []
    string_table = StringTable()

    # CUDA start events and the profiler's own start event are not recorded at
    # exactly the same time, so kernel timestamps are shifted onto the CPU
    # timeline of the start record before being reported.
    def adjusted_time(cuda_record):
        assert cuda_record.device() != -1
        cuda_time_0 = cuda_records[cuda_record.device()]
        return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0)

    # '__start_profile' is not guaranteed to come first, so it has to be found.
    for record in itertools.chain(*thread_records):
        if record.name() == '__start_profile':
            start_record = record
        elif record.name() == '__cuda_start_event':
            assert record.device() != -1
            cuda_records[record.device()] = record
    assert start_record is not None

    for record in itertools.chain(*thread_records):
        if record.kind() == 'mark':
            continue
        elif record.kind() == 'push':
            record_stack.append((next_id, record))
            next_id += 1
        elif record.kind() == 'pop':
            function_id, start = record_stack.pop()
            fe = FunctionEvent(
                id=function_id,
                name=string_table[start.name()],
                thread=start.thread_id(),
                cpu_start=start_record.cpu_elapsed_us(start),
                cpu_end=start_record.cpu_elapsed_us(record))
            if start.has_cuda():
                cuda_start = adjusted_time(start)
                cuda_end = adjusted_time(record)
                fe.append_kernel(start.name(),
                                 start.device(),
                                 cuda_start,
                                 cuda_end)
            functions.append(fe)

    functions.sort(key=lambda evt: evt.cpu_interval.start)
    return functions
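
# Sketch of the pairing logic above on a toy per-thread stream (kinds only):
# push A, push B, pop B, pop A produces two FunctionEvents, with B's interval
# nested inside A's.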


class EnforceUnique(object):
    """Raises an error if a key is seen more than once."""
    def __init__(self):
        self.seen = set()

    def see(self, *key):
        if key in self.seen:
            raise RuntimeError('duplicate key: ' + str(key))
        self.seen.add(key)


def parse_nvprof_trace(path):
    import sqlite3
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row

    # Parse strings table
    strings = {}
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
        strings[r["id"]] = torch._C._demangle(r["value"])

    # First, find all functions and create FunctionEvents for them
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    functions = []
    functions_map = {}
    unique = EnforceUnique()
    for row in conn.execute(marker_query):
        unique.see(row['marker_id'])
        evt = FunctionEvent(id=row['marker_id'],
                            name=strings[row['name']],
                            cpu_start=row['start_time'],
                            cpu_end=row['end_time'],
                            thread=0)  # TODO: find in sqlite database
        functions.append(evt)
        functions_map[evt.id] = evt

    # Now, correlate all kernels with FunctionEvents
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
            ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
            ON kernel.correlationId = runtime.correlationId
    """
    unique = EnforceUnique()
    for row in conn.execute(kernel_query):
        unique.see(row['marker_id'], row['runtime_id'])
        assert row['cbid'] == 13  # 13 == Launch
        evt = functions_map[row['marker_id']]
        evt.append_kernel(row['kernel_name'],
                          0,
                          row['kernel_start'],
                          row['kernel_end'])

    functions.sort(key=lambda evt: evt.cpu_interval.start)
    return functions
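
# The nvprof trace is an ordinary sqlite database, so the tables queried
# above can also be inspected by hand (the path is a placeholder):
#
#   sqlite3 trace.prof "SELECT value FROM StringTable LIMIT 5"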


def build_table(events, sort_by=None, header=None):
    """Prints a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
    if sort_by is not None:
        events = sorted(events, key=lambda evt: getattr(evt, sort_by))

    name_lengths = [len(evt.key) for evt in events]
    if len(name_lengths) == 0:
        return ''
    max_name_length = max(name_lengths)
    max_name_length += 4  # Add some padding
    col_width = 15
    col_format = '  {: >' + str(col_width) + '}'
    row_format = '{: <' + str(max_name_length) + '}' + col_format * 5
    header_sep = '-' * max_name_length + ('  ' + '-' * col_width) * 5

    # Collect lines in a list; ''.join at the end avoids quadratic appends.
    result = []

    def append(s):
        result.append(s)
        result.append('\n')

    if header is not None:
        line_length = max_name_length + (col_width + 2) * 5
        append('=' * line_length)
        append(header)
    append(header_sep)
    append(row_format.format('Name', 'CPU time', 'CUDA time', 'Calls',
                             'CPU total', 'CUDA total'))
    append(header_sep)
    for evt in events:
        append(row_format.format(evt.key, evt.cpu_time_str, evt.cuda_time_str,
                                 evt.count, evt.cpu_time_total_str, evt.cuda_time_total_str))

    return ''.join(result)
 