Fix missing mandatory device_type argument in autocast docstring (#97223)
Fixes [#92803](https://github.com/pytorch/pytorch/issues/92803)



Pull Request resolved: https://github.com/pytorch/pytorch/pull/97223
Approved by: https://github.com/albanD, https://github.com/malfet
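`torch.autocast` takes `device_type` as a required positional argument, so the bare `autocast()` calls in the old docstring examples fail as written; the hunks below add the argument explicitly. A minimal sketch of the corrected usage, run on the CPU backend purely so it works without a GPU (the shapes and dtype here are illustrative assumptions, not part of this patch):

```python
import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

# torch.autocast() with no arguments raises a TypeError because
# device_type is a required positional argument.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = torch.mm(a, b)  # eligible ops run in the lower-precision dtype

print(c.dtype)  # torch.bfloat16
```

The docstring hunks below use the same signature with `device_type="cuda"`.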
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 255a935..9bb350c 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -41,7 +41,7 @@
optimizer.zero_grad()
# Enables autocasting for the forward pass (model + loss)
- with autocast():
+ with torch.autocast(device_type="cuda"):
    output = model(input)
    loss = loss_fn(output, target)
@@ -56,7 +56,7 @@
class AutocastModel(nn.Module):
    ...
-     @autocast()
+     @torch.autocast(device_type="cuda")
    def forward(self, input):
        ...
@@ -74,7 +74,7 @@
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
- with autocast():
+ with torch.autocast(device_type="cuda"):
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
@@ -153,9 +153,9 @@
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
- with autocast():
+ with torch.autocast(device_type="cuda"):
    e_float16 = torch.mm(a_float32, b_float32)
-     with autocast(enabled=False):
+     with torch.autocast(device_type="cuda", enabled=False):
        # Calls e_float16.float() to ensure float32 execution
        # (necessary because e_float16 was created in an autocasted region)
        f_float32 = torch.mm(c_float32, e_float16.float())
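For reference, the nested-disable pattern touched in the hunk just above can be exercised directly in eager mode. This is a small sketch using the corrected signature; it assumes a CUDA-capable machine and mirrors the docstring rather than adding anything new:

```python
import torch

a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    # torch.mm runs in float16 inside the autocast-enabled region.
    e_float16 = torch.mm(a_float32, b_float32)
    with torch.autocast(device_type="cuda", enabled=False):
        # Autocast is off here, so cast the float16 result back manually.
        f_float32 = torch.mm(a_float32, e_float16.float())

print(e_float16.dtype, f_float32.dtype)  # torch.float16 torch.float32
```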
diff --git a/torch/csrc/jit/JIT-AUTOCAST.md b/torch/csrc/jit/JIT-AUTOCAST.md
index c8e7ffa..f32833f 100644
--- a/torch/csrc/jit/JIT-AUTOCAST.md
+++ b/torch/csrc/jit/JIT-AUTOCAST.md
@@ -79,6 +79,9 @@
will be emitted)
```python
+import torch
+from torch.cpu.amp import autocast
+
@autocast(enabled=True)
def helper(x):
    ...
@@ -91,6 +94,9 @@
Another example
```python
+import torch
+from torch.cpu.amp import autocast
+
@torch.jit.script
@autocast() # not supported
def foo(a, b, c, d):
@@ -100,6 +106,9 @@
#### Autocast argument must be a compile-time constant
```python
+import torch
+from torch.cpu.amp import autocast
+
@torch.jit.script
def fn(a, b, use_amp: bool):
    # runtime values for autocast enable argument are not supported
@@ -111,6 +120,9 @@
#### Uncommon autocast usage patterns may not be supported
```python
+import torch
+from torch.cpu.amp import autocast
+
@torch.jit.script
def fn(a, b, c, d):
    with autocast(enabled=True) as autocast_instance: # not supported
@@ -140,6 +152,9 @@
> This is one known limitation where we don't have a way to emit a diagnostic!
```python
+import torch
+from torch.cpu.amp import autocast
+
def helper(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b) * 2.0
@@ -158,6 +173,9 @@
function from eager mode:
```python
+import torch
+from torch.cpu.amp import autocast
+
@torch.jit.script
def fn(a, b):
    return torch.mm(a, b)
@@ -176,6 +194,9 @@
within a scripted function, autocasting will still occur.
```python
+import torch
+from torch.cuda.amp import autocast
+
@torch.jit.script
def fn(a, b):
    with autocast(enabled=False):