| from contextlib import contextmanager | 
 |  | 
 | try: | 
 |     from torch._C import _nvtx | 
 | except ImportError: | 
 |  | 
 |     class _NVTXStub: | 
 |         @staticmethod | 
 |         def _fail(*args, **kwargs): | 
 |             raise RuntimeError( | 
 |                 "NVTX functions not installed. Are you sure you have a CUDA build?" | 
 |             ) | 
 |  | 
 |         rangePushA = _fail | 
 |         rangePop = _fail | 
 |         markA = _fail | 
 |  | 
 |     _nvtx = _NVTXStub()  # type: ignore[assignment] | 
 |  | 
 | __all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"] | 
 |  | 
 |  | 
 | def range_push(msg): | 
 |     """ | 
 |     Pushes a range onto a stack of nested range span.  Returns zero-based | 
 |     depth of the range that is started. | 
 |  | 
 |     Args: | 
 |         msg (str): ASCII message to associate with range | 
 |     """ | 
 |     return _nvtx.rangePushA(msg) | 
 |  | 
 |  | 
 | def range_pop(): | 
 |     """ | 
 |     Pops a range off of a stack of nested range spans.  Returns the | 
 |     zero-based depth of the range that is ended. | 
 |     """ | 
 |     return _nvtx.rangePop() | 
 |  | 
 |  | 
 | def range_start(msg) -> int: | 
 |     """ | 
 |     Mark the start of a range with string message. It returns an unique handle | 
 |     for this range to pass to the corresponding call to rangeEnd(). | 
 |  | 
 |     A key difference between this and range_push/range_pop is that the | 
 |     range_start/range_end version supports range across threads (start on one | 
 |     thread and end on another thread). | 
 |  | 
 |     Returns: A range handle (uint64_t) that can be passed to range_end(). | 
 |  | 
 |     Args: | 
 |         msg (str): ASCII message to associate with the range. | 
 |     """ | 
 |     return _nvtx.rangeStartA(msg) | 
 |  | 
 |  | 
 | def range_end(range_id) -> None: | 
 |     """ | 
 |     Mark the end of a range for a given range_id. | 
 |  | 
 |     Args: | 
 |         range_id (int): an unique handle for the start range. | 
 |     """ | 
 |     _nvtx.rangeEnd(range_id) | 
 |  | 
 |  | 
 | def mark(msg): | 
 |     """ | 
 |     Describe an instantaneous event that occurred at some point. | 
 |  | 
 |     Args: | 
 |         msg (str): ASCII message to associate with the event. | 
 |     """ | 
 |     return _nvtx.markA(msg) | 
 |  | 
 |  | 
 | @contextmanager | 
 | def range(msg, *args, **kwargs): | 
 |     """ | 
 |     Context manager / decorator that pushes an NVTX range at the beginning | 
 |     of its scope, and pops it at the end. If extra arguments are given, | 
 |     they are passed as arguments to msg.format(). | 
 |  | 
 |     Args: | 
 |         msg (str): message to associate with the range | 
 |     """ | 
 |     range_push(msg.format(*args, **kwargs)) | 
 |     try: | 
 |         yield | 
 |     finally: | 
 |         range_pop() |