commit | f84755bcac016729a1a2d3d850f50aee8beb58e8 | [log] [tgz] |
---|---|---|
author | Nikita Shulga <2453524+malfet@users.noreply.github.com> | Mon Oct 16 23:26:54 2023 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Mon Oct 16 23:26:58 2023 +0000 |
tree | e836b0763767ffc80740582f0c48faa9820e38c7 | |
parent | 9683a26c55a8178e852d17fb9753e9907a7c6174 [diff] |
Fix _CudaStreamBase type annotations (#111387) Make it inherit from `Stream` as indeed it is, see https://github.com/pytorch/pytorch/blob/97a513ed077323550b808e690a0b5a0452f87334/torch/csrc/cuda/Stream.cpp#L208 and ``` python3 -c "import torch;print(torch._C._CudaStreamBase.__base__)" <class 'torch.Stream'> ``` Fixes https://github.com/pytorch/pytorch/issues/111268 TODO (in separate PR): Revive `test_typing` and add regression test Pull Request resolved: https://github.com/pytorch/pytorch/pull/111387 Approved by: https://github.com/jeanschmidt, https://github.com/Skylion007
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 6bf3b39..6e8dfcd 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in
@@ -1760,7 +1760,7 @@ def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... # Defined in torch/csrc/cuda/Stream.cpp -class _CudaStreamBase: +class _CudaStreamBase(Stream): stream_id: _int device_index: _int device_type: _int