blob: 25bacdcc7885218192046fe3d60a7f65552ccec5 [file] [log] [blame]
import ctypes
import torch
from torch._streambase import _StreamBase
from ._utils import _dummy_type
# When torch._C does not expose ``_XpuStreamBase``, register a stand-in under
# that name so that ``torch._C._XpuStreamBase`` still resolves when it is used
# as a base class of ``Stream`` below.
# NOTE(review): ``_dummy_type`` is imported from ``._utils`` and is not visible
# here; presumably it builds a placeholder class for builds without XPU
# support -- confirm against its definition.
if not hasattr(torch._C, "_XpuStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Python handle for an XPU stream.

    An XPU stream is a linear sequence of execution tied to one specific
    device and independent from other streams.

    Args:
        device(torch.device or int, optional): the device on which the stream
            is allocated. When :attr:`device` is ``None`` (default) or a
            negative integer, the current device is used.
        priority(int, optional): stream priority; should be 0 or negative,
            with negative numbers indicating higher priority. Streams have
            priority 0 by default.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        r"""Create the stream, entering a device context only when required.

        ``priority`` and any extra keyword arguments are forwarded unchanged
        to the base class ``__new__``.
        """
        # Switching the active device is expensive, so it is skipped when no
        # device was requested or when the caller already supplied both
        # ``stream_id`` and ``device_index``.
        has_explicit_ids = "stream_id" in kwargs and "device_index" in kwargs
        if device is not None and not has_explicit_ids:
            with torch.xpu.device(device):
                return super().__new__(cls, priority=priority, **kwargs)
        return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        r"""Currently a no-op: ``event`` is ignored and ``None`` is returned."""
        return None

    def wait_stream(self, stream):
        r"""Currently a no-op: ``stream`` is ignored and ``None`` is returned."""
        return None

    def record_event(self, event=None):
        r"""Currently a no-op: ``event`` is ignored and ``None`` is returned."""
        return None

    def query(self):
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        all_done = super().query()
        return all_done

    def synchronize(self):
        r"""Wait for all the kernels in this stream to complete."""
        # The base class result is intentionally not returned.
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes looks up ``_as_parameter_`` when an object is passed to a
        # foreign function; expose the queue handle as a void pointer.
        queue_handle = self.sycl_queue
        return ctypes.c_void_p(queue_handle)

    def __eq__(self, o):
        # Anything that is not a Stream never compares equal; otherwise the
        # base class decides.
        if not isinstance(o, Stream):
            return False
        return super().__eq__(o)

    def __hash__(self):
        # Hash on the (queue handle, device) pair.
        identity = (self.sycl_queue, self.device)
        return hash(identity)

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"