Fix nvprof mode in autograd profiler
apaszke authored and ezyang committed Oct 20, 2017
1 parent 17a8171 commit 76abc06
Showing 3 changed files with 131 additions and 78 deletions.
4 changes: 4 additions & 0 deletions docs/source/autograd.rst
@@ -59,3 +59,7 @@ Profiler
.. autoclass:: torch.autograd.profiler.profile
:members:

.. autoclass:: torch.autograd.profiler.emit_nvtx
:members:

.. autofunction:: torch.autograd.profiler.load_nvprof
195 changes: 117 additions & 78 deletions torch/autograd/profiler.py
@@ -19,27 +19,72 @@ def __str__(self):
def table(self, sort_by=None):
return build_table(self, sort_by)

def export_chrome_trace(self, path):
"""Exports an EventList as a Chrome tracing tools file.
The checkpoint can later be loaded and inspected under the ``chrome://tracing`` URL.
Arguments:
path (str): Path where the trace will be written.
"""
import json
with open(path, 'w') as f:
chrome_events = []
for evt in self:
chrome_events.append(dict(
name=evt.name,
ph='X',
ts=evt.start / 1000,
dur=evt.cpu_time_total / 1000,
tid='Autograd functions',
pid='Autograd functions',
args={},
))
json.dump(chrome_events, f)

def key_averages(self):
"""Averages all function events over their keys.
Returns:
An EventList containing FunctionEventAvg objects.
"""
stats = defaultdict(FunctionEventAvg)
for evt in self:
stats[evt.key] += evt
return EventList(stats.values())

def total_average(self):
"""Averages all events.
Returns:
A FunctionEventAvg object.
"""
total_stat = FunctionEventAvg()
for evt in self:
total_stat += evt
total_stat.key = None
total_stat.key = 'Total'
return total_stat
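
For illustration, a minimal sketch of how these EventList helpers are typically consumed, reusing the docstring example from the profile class below; the 'trace.json' path is just a placeholder:

>>> x = Variable(torch.randn(1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof:
...     y = x ** 2
...     y.backward()
>>> print(prof.key_averages().table())      # one averaged row per distinct key
>>> print(prof.total_average())             # a single FunctionEventAvg over all events
>>> prof.export_chrome_trace('trace.json')  # inspect under chrome://tracing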


class profile(object):
"""Context manager that manages autograd profiler state and holds a summary of results.
Arguments:
use_nvprof (bool, optional): If True, uses nvprof (might incur high overhead and
assumes that the whole process is running inside nvprof), otherwise uses a custom
CPU-only profiler (with negligible overhead). Default: False.
trace_path (str, optional): A path of the CUDA checkpoint. If specified, it will be left
unmodified after profiling finishes, so it can be opened and inspected in nvvp. Otherwise
it will be created in a temporary directory and removed after reading the results.
enabled (bool, optional): Setting this to False makes this context manager a no-op.
Default: True.
.. warning::
This context manager should not be called recursively, i.e. at most one
instance should be enabled at any given time.
Example:
>>> x = Variable(torch.randn(1, 1), requires_grad=True)
>>> with torch.autograd.profiler.profile() as prof:
... y = x ** 2
... y.backward()
>>> # NOTE: some columns were removed for brevity
... print(prof.key_averages())
... print(prof)
------------------------------------- --------------- ---------------
Name                                  CPU time        CUDA time
------------------------------------- --------------- ---------------
@@ -53,51 +98,27 @@ class profile(object):
N5torch8autograd5CloneE               4.088us         0.000us
"""

def __init__(self, use_nvprof=False, trace_path=None, enabled=True):
def __init__(self, enabled=True):
self.enabled = enabled
self.function_events = None
if not self.enabled:
return
self.entered = False
self.use_nvprof = use_nvprof
if use_nvprof:
if trace_path is None:
# The file will be deleted in the destructor
with tempfile.NamedTemporaryFile(delete=False) as f:
self.delete_trace = True
self.trace_path = f.name
else:
self.trace_path = trace_path
self.delete_trace = False
else:
self.trace_path = None
self.delete_trace = False

def __del__(self):
if not self.enabled:
return
if self.delete_trace:
os.unlink(trace_path)

def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("autograd profiler traces are not reentrant")
self.entered = True
if self.use_nvprof:
torch.cuda.profiler.initialize()
torch.autograd._enable_profiler(self.use_nvprof)
torch.autograd._enable_profiler(False)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
records = torch.autograd._disable_profiler()
if self.use_nvprof:
self.function_events = EventList(parse_nvprof_trace(self.used_cuda_path))
else:
self.function_events = EventList(parse_cpu_trace(records))
self.function_events = EventList(parse_cpu_trace(records))
return False

def __repr__(self):
@@ -111,57 +132,75 @@ def __str__(self):
return str(self.function_events)

def export_chrome_trace(self, path):
"""Exports a list of FunctionEvents as Chrome trace.
The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.
Arguments:
path (str): Path where the trace will be written.
"""
if self.function_events is None:
raise RuntimeError("can't export a trace that didn't finish running")
import json
with open(path, 'w') as f:
chrome_events = []
for evt in self.function_events:
chrome_events.append(dict(
name=evt.name,
ph='X',
ts=evt.start / 1000,
dur=evt.cpu_time_total / 1000,
tid='Autograd functions',
pid='Autograd functions',
args={},
))
json.dump(chrome_events, f)
return self.function_events.export_chrome_trace(path)
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

def key_averages(self):
"""Averages all function events over their keys.
Returns:
A list of FunctionEventAvg objects.
"""
def key_averages(self):
if self.function_events is None:
raise RuntimeError("can't export a trace that didn't finish running")
stats = defaultdict(FunctionEventAvg)
for evt in self.function_events:
stats[evt.key] += evt
return EventList(stats.values())
raise RuntimeError("can't average a trace that didn't finish running")
return self.function_events.key_averages()
key_averages.__doc__ = EventList.key_averages.__doc__

def total_average(self):
"""Averages all events.
Returns:
A FunctionEventAvg object.
"""
if self.function_events is None:
raise RuntimeError("can't export a trace that didn't finish running")
total_stat = FunctionEventAvg()
for evt in self.function_events:
total_stat += evt
total_stat.key = None
total_stat.key = 'Total'
return total_stat
raise RuntimeError("can't average a trace that didn't finish running")
return self.function_events.total_average()
total_average.__doc__ = EventList.total_average.__doc__
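
A small illustration of why the methods above guard on ``self.function_events``: the raw records are only parsed into an EventList in ``__exit__``, so calling an aggregation helper while the block is still active raises (the workload is the docstring example again):

>>> with torch.autograd.profiler.profile() as prof:
...     y = x ** 2
...     prof.total_average()  # too early: events are parsed only on exit
Traceback (most recent call last):
    ...
RuntimeError: can't average a trace that didn't finish running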


class emit_nvtx(object):
"""Context manager that makes every autograd operation emit an NVTX range.
It is useful when running the program under nvprof. Unfortunately, there's no
way to force nvprof to flush the data it collected to disk, so for CUDA profiling
one has to use this context manager to annotate nvprof traces, and then use
:func:`torch.autograd.profiler.load_nvprof` to analyze the checkpoint.
.. warning::
This context manager should not be called recursively, i.e. at most one
instance should be enabled at any given time.
Arguments:
enabled (bool, optional): Setting this to False makes this context manager a no-op.
Default: True.
Example:
>>> with torch.cuda.profiler.profile():
... model(x) # Warmup CUDA memory allocator and profiler
... with torch.autograd.profiler.emit_nvtx():
... model(x)
"""
def __init__(self, enabled=True):
self.enabled = enabled
self.entered = False

def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("NVTX annotation context manager is not reentrant")
self.entered = True
torch.cuda.synchronize()
torch.autograd._enable_profiler(True)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
torch.cuda.synchronize()
torch.autograd._disable_profiler()
return False


def load_nvprof(path):
"""Opens an nvprof trace file and parses autograd annotations.
Arguments:
path (str): path to nvprof trace
"""
return EventList(parse_nvprof_trace(path))
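
To tie the pieces together, a hedged end-to-end sketch of the nvprof workflow described in the emit_nvtx docstring; the script name, output path, and nvprof flags are illustrative and may differ across CUDA versions:

    # shell (flags illustrative):
    #     nvprof --profile-from-start off -o trace.prof python my_script.py
    #
    # my_script.py: annotate the region of interest with NVTX ranges
    with torch.cuda.profiler.profile():
        model(x)  # placeholder warm-up pass
        with torch.autograd.profiler.emit_nvtx():
            model(x)  # placeholder workload

    # afterwards, in any Python session, parse the annotations back
    events = torch.autograd.profiler.load_nvprof('trace.prof')
    print(events.table())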


################################################################################
10 changes: 10 additions & 0 deletions torch/cuda/profiler.py
@@ -1,5 +1,6 @@
import ctypes
import tempfile
import contextlib
from . import cudart, check_error


@@ -43,3 +44,12 @@ def start():

def stop():
check_error(cudart().cudaProfilerStop())


@contextlib.contextmanager
def profile():
try:
start()
yield
finally:
stop()
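
One note on the design above: wrapping start() and stop() in try/finally guarantees that cudaProfilerStop still runs when the profiled region raises. A small illustrative sketch (the failure is simulated):

    import torch.cuda.profiler

    try:
        with torch.cuda.profiler.profile():
            raise RuntimeError("simulated failure inside the profiled region")
    except RuntimeError:
        pass  # stop() already ran via the finally clause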
