Update cuda amp to also check xla device (#63413)
Summary:
Fixes https://github.com/pytorch/xla/issues/3086. PyTorch/XLA:GPU also uses CUDA amp. I verified the pt/xla `test_autocast` with this fix and all tests passed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63413
Reviewed By: ngimel
Differential Revision: D30380785
Pulled By: bdhirsh
fbshipit-source-id: fd1a1de7d224c616fc3fa90b80a688a21f6b1ecc
diff --git a/torch/autocast_mode.py b/torch/autocast_mode.py
index edf36d2..ec9fdb0 100644
--- a/torch/autocast_mode.py
+++ b/torch/autocast_mode.py
@@ -135,7 +135,7 @@
self.fast_dtype = torch.get_autocast_cpu_dtype()
else:
raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
- if not torch.cuda.is_available() and self.device == 'cuda':
+ if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
enabled = False
for key, value in kwargs.items():
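The effect of the change above is that autocast no longer warns and disables itself purely because `torch.cuda.is_available()` is False: `amp_definitely_not_available()` also accounts for an XLA:GPU setup. A minimal sketch of that check, assuming the real helper treats AMP as unavailable only when neither CUDA nor the `torch_xla` package is present (the `cuda_available` parameter is a stand-in for `torch.cuda.is_available()` so the sketch runs without PyTorch):

```python
from importlib.util import find_spec

def amp_definitely_not_available(cuda_available: bool) -> bool:
    # Sketch of torch.cuda.amp.common.amp_definitely_not_available():
    # AMP is "definitely not available" only when CUDA is absent AND the
    # torch_xla package cannot be found, so XLA:GPU users are no longer
    # disabled by a bare `not torch.cuda.is_available()` check.
    return not (cuda_available or find_spec("torch_xla") is not None)

# With CUDA available, AMP is never reported as unavailable.
print(amp_definitely_not_available(True))
```

With this sketch, an XLA:GPU environment (where `torch_xla` is importable but `torch.cuda.is_available()` may be False) still passes the availability check, which is the behavior the diff restores.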