| # Copyright 2019 Kakao Brain |
| # |
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| # |
| # This source code is licensed under the BSD license found in the |
| # LICENSE file in the root directory of this source tree. |
| """Utilities for eliminating boilerplate code to handle abstract streams with |
| CPU device. |
| """ |
| from contextlib import contextmanager |
| from typing import Generator, List, Union, cast |
| |
| import torch |
| |
| __all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream", |
| "use_device", "use_stream", "get_device", "wait_stream", "record_stream", |
| "is_cuda", "as_cuda"] |
| |
| |
class CPUStreamType:
    """Sentinel type standing in for a stream on the CPU device, where CUDA
    streams do not exist. A single shared instance (``CPUStream``) is created
    below and compared by identity via :func:`is_cuda`.
    """

    pass
| |
| |
# Singleton placeholder used in place of a stream for the CPU device, where
# CUDA streams do not apply. Checked by identity (``is``) in :func:`is_cuda`.
CPUStream = CPUStreamType()

# Union type covering both real CUDA streams and the CPU placeholder above.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
| |
| |
def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for either CPU or CUDA device."""
    if device.type == "cuda":
        return torch.cuda.Stream(device)
    # There is no stream concept on CPU; hand back the shared placeholder.
    return CPUStream
| |
| |
def current_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
    if device.type == "cuda":
        return torch.cuda.current_stream(device)
    # CPU work is always synchronous; the placeholder is the "current" stream.
    return CPUStream
| |
| |
def default_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
    if device.type == "cuda":
        return torch.cuda.default_stream(device)
    # CPU has a single implicit stream, represented by the placeholder.
    return CPUStream
| |
| |
@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
    """:func:`torch.cuda.device` for either CPU or CUDA device."""
    if device.type == "cuda":
        with torch.cuda.device(device):
            yield
    else:
        # No device selection exists for CPU; behave as a null context.
        yield
| |
| |
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
    """:func:`torch.cuda.stream` for either CPU or CUDA stream."""
    if is_cuda(stream):
        with torch.cuda.stream(as_cuda(stream)):
            yield
    else:
        # The CPU placeholder needs no stream selection; null context.
        yield
| |
| |
def get_device(stream: AbstractStream) -> torch.device:
    """Gets the device from CPU or CUDA stream."""
    if not is_cuda(stream):
        # The placeholder always corresponds to the (single) CPU device.
        return torch.device("cpu")
    return as_cuda(stream).device
| |
| |
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
    makes the source stream wait until the target stream completes work queued.
    """
    if not is_cuda(target):
        # The CPU "stream" is synchronous: by the time we get here its queued
        # work is already done, so no synchronization is required.
        return

    if is_cuda(source):
        # CUDA-on-CUDA: insert a wait without blocking the host.
        as_cuda(source).wait_stream(as_cuda(target))
    else:
        # CPU source must block the host until the CUDA target drains.
        as_cuda(target).synchronize()
| |
| # If the target is CPU, synchronization is not required. |
| |
| |
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.

    Marks ``tensor`` as used by ``stream`` so its memory is not reclaimed or
    reused until the stream's pending work finishes. For the CPU placeholder
    stream this is a no-op, since CPU execution is synchronous.
    """
    if is_cuda(stream):
        # NOTE(sublee): record_stream() on a shifted view tensor throws
        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
        # protect the tensor against unexpected reallocation, here we use a
        # temporal tensor associated with the same storage without shifting as
        # a workaround.
        #
        # Issue: https://github.com/pytorch/pytorch/issues/27366
        #
        # NOTE(review): ``_typed_storage()`` is a private PyTorch API and may
        # change between releases — verify on PyTorch upgrades.
        tensor = tensor.new_empty([0]).set_(tensor._typed_storage())

        # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
        tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
| |
| |
def is_cuda(stream: AbstractStream) -> bool:
    """Returns ``True`` if the given stream is a valid CUDA stream."""
    # Anything other than the shared CPU placeholder is a CUDA stream.
    return not (stream is CPUStream)
| |
| |
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    """Casts the given stream as :class:`torch.cuda.Stream`."""
    # Purely a typing-level cast: no runtime conversion or check happens here.
    cuda_stream = cast(torch.cuda.Stream, stream)
    return cuda_stream