Corrects Error Codes from cudaHostRegister (#132089)
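The raw `cudaError` values returned by `cudaHostRegister`/`cudaHostUnregister` were formatted directly into the assertion message, producing unhelpful error messages, e.g.: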
```
# printing directly: cudaError.???
# casting to int first: 712
Traceback (most recent call last):
File "/data/users/lpasqualin/fbsource/fbcode/scripts/lpasqualin/playground.py", line 15, in <module>
main()
File "/data/users/lpasqualin/fbsource/fbcode/scripts/lpasqualin/playground.py", line 11, in main
_create_cpu_state_dict(sd, share_memory=True, pin_memory=True)
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 436, in _create_cpu_state_dict
ret = _iterate_state_dict(
^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 143, in _iterate_state_dict
ret = {
^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 144, in <dictcomp>
key: _iterate_state_dict(
^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 125, in _iterate_state_dict
ret = tensor_func(iter_object, pg, device, companion_obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 428, in tensor_func
succ == 0
AssertionError: Pinning shared memory failed with error-code: cudaError.???
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132089
Approved by: https://github.com/Skylion007
diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py
index a0c89b0..2961368 100644
--- a/torch/distributed/_state_dict_utils.py
+++ b/torch/distributed/_state_dict_utils.py
@@ -411,16 +411,18 @@
if pin_memory:
def unpin_memory(t):
- succ = torch.cuda.cudart().cudaHostUnregister(t.data_ptr())
+ succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
assert (
succ == 0
), f"Unpinning shared memory failed with error-code: {succ}"
weakref.finalize(t, unpin_memory, t)
- succ = torch.cuda.cudart().cudaHostRegister(
- t.data_ptr(),
- t.numel() * t.element_size(),
- 1, # lines up with 'cudaHostRegisterPortable'
+ succ = int(
+ torch.cuda.cudart().cudaHostRegister(
+ t.data_ptr(),
+ t.numel() * t.element_size(),
+ 1, # lines up with 'cudaHostRegisterPortable'
+ )
)
assert (
succ == 0