Move _run_autocast_outofplace to a base class named TestAutocast to reduce redundancy (#134460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134460
Approved by: https://github.com/EikanWang, https://github.com/ezyang
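
This consolidates three near-identical copies of _run_autocast_outofplace
(test_autocast.py, test_cuda.py, test_xpu.py) into a single helper on a
shared TestAutocast base class, with the target backend passed through a
new required device argument. A minimal sketch of the post-refactor
pattern, using the CPU backend as the example (the other backends differ
only in their test lists, the device string, and the amp_dtype they pass):

    import torch
    from torch.testing._internal.autocast_test_lists import (
        AutocastCPUTestLists,
        TestAutocast,
    )

    class TestAutocastCPU(TestAutocast):
        def setUp(self):
            super().setUp()
            self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

        def tearDown(self):
            del self.autocast_lists
            super().tearDown()

        def test_autocast_torch_16(self):
            for op_with_args in self.autocast_lists.torch_16:
                op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
                # The shared helper now takes the device explicitly; amp_dtype
                # defaults to torch.bfloat16, so the fp16 variants (and the
                # CUDA/XPU tests) pass amp_dtype=torch.float16 as well.
                self._run_autocast_outofplace(
                    op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
                )
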
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 4c3f031..8e702f9 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -1,10 +1,12 @@
# Owner(s): ["module: unknown"]
-import collections
import unittest
import torch
-from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
+from torch.testing._internal.autocast_test_lists import (
+ AutocastCPUTestLists,
+ TestAutocast,
+)
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
@@ -14,7 +16,7 @@
from torch.utils._python_dispatch import TorchDispatchMode
-class TestAutocastCPU(TestCase):
+class TestAutocastCPU(TestAutocast):
def setUp(self):
super().setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))
@@ -23,100 +25,6 @@
del self.autocast_lists
super().tearDown()
- def _run_autocast_outofplace(
- self,
- op,
- args,
- run_as_type,
- out_type=None,
- module=torch,
- add_kwargs=None,
- amp_dtype=torch.bfloat16,
- ):
- # helper to cast args
- def cast(val, to_type):
- if isinstance(val, torch.Tensor):
- return val.to(to_type) if val.is_floating_point() else val
- elif isinstance(val, collections.abc.Iterable):
- return type(val)(cast(v, to_type) for v in val)
- else:
- return val
-
- if add_kwargs is None:
- add_kwargs = {}
-
- self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
- with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
- self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
- out_type = out_type if out_type is not None else run_as_type
- output = output_method = None
-
- # Try module.* variant, if requested:
- if module is not None and hasattr(module, op):
- output = getattr(module, op)(*args, **add_kwargs)
- if isinstance(output, torch.Tensor):
- self.assertTrue(
- out_type == output.dtype,
- f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
- )
- # Try Tensor.* variant:
- if hasattr(torch.Tensor, op):
- output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
- if isinstance(output_method, torch.Tensor):
- self.assertTrue(
- out_type == output_method.dtype,
- f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
- )
-
- self.assertTrue(
- (output is not None) or (output_method is not None),
- f"{op} not found as an attribute on either Tensor or the requested module {module}",
- )
-
- # Accounts for ops that return Tensors, iterables, and other non-Tensors.
- # For example, lstm_cell returns a tuple and equal returns bool.
- def compare(first, second):
- if isinstance(first, torch.Tensor):
- return torch.equal(first, second)
- elif isinstance(first, collections.abc.Iterable):
- return all(compare(f, s) for f, s in zip(first, second))
- else:
- return first == second
-
- # If both torch.* and Tensor.* variants were found, check outputs are identical
- if (output is not None) and (output_method is not None):
- self.assertTrue(type(output) == type(output_method))
- comparison = compare(output, output_method)
- self.assertTrue(
- comparison, f"torch.{op} result did not match Tensor.{op} result"
- )
-
- # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
- # as the C++-side autocasting, and should be bitwise accurate.
- output_to_compare = output if output is not None else output_method
- with torch.amp.autocast(device_type="cpu", enabled=False):
- self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
- if module is not None and hasattr(module, op):
- control = getattr(module, op)(
- *cast(args, run_as_type), **add_kwargs
- )
- else:
- control = getattr(args[0].to(run_as_type), op)(
- *cast(args[1:], run_as_type), **add_kwargs
- )
- self.assertTrue(type(output_to_compare) == type(control))
- comparison = compare(output_to_compare, control)
- self.assertTrue(comparison, f"torch.{op} result did not match control")
- self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
- self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
-
- def args_maybe_kwargs(self, op_with_args):
- if len(op_with_args) == 2:
- return op_with_args[0], op_with_args[1], {}
- else:
- return op_with_args[0], op_with_args[1], op_with_args[2]
-
@skipIfTorchDynamo()
def test_autocast_torch_expect_builtin_promote(self):
for (
@@ -125,9 +33,16 @@
args2,
out_type,
) in self.autocast_lists.torch_expect_builtin_promote:
- self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
self._run_autocast_outofplace(
- op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
+ op, args1, torch.float32, device="cpu", out_type=out_type
+ )
+ self._run_autocast_outofplace(
+ op,
+ args2,
+ torch.float32,
+ device="cpu",
+ out_type=out_type,
+ amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
@@ -139,12 +54,13 @@
out_type,
) in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(
- op, args1, torch.float32, module=None, out_type=out_type
+ op, args1, torch.float32, device="cpu", module=None, out_type=out_type
)
self._run_autocast_outofplace(
op,
args2,
torch.float32,
+ device="cpu",
module=None,
out_type=out_type,
amp_dtype=torch.float16,
@@ -155,12 +71,13 @@
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
- op, args, torch.bfloat16, add_kwargs=maybe_kwargs
+ op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
+ device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@@ -170,12 +87,18 @@
for op_with_args in self.autocast_lists.nn_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
- op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+ op,
+ args,
+ torch.bfloat16,
+ device="cpu",
+ module=torch._C._nn,
+ add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
+ device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
@@ -186,12 +109,13 @@
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
- op, args, torch.float32, add_kwargs=maybe_kwargs
+ op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
+ device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@@ -201,12 +125,18 @@
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
- op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+ op,
+ args,
+ torch.float32,
+ device="cpu",
+ module=torch._C._nn,
+ add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
+ device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
@@ -215,9 +145,9 @@
@skipIfTorchDynamo()
def test_autocast_torch_need_autocast_promote(self):
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
- self._run_autocast_outofplace(op, args1, torch.float32)
+ self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
self._run_autocast_outofplace(
- op, args2, torch.float32, amp_dtype=torch.float16
+ op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
diff --git a/test/test_cuda.py b/test/test_cuda.py
index c6a61e1..2195e71 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: cuda"]
-import collections
import contextlib
import ctypes
import gc
@@ -29,7 +28,7 @@
segment_plot,
trace_plot,
)
-from torch.testing._internal.autocast_test_lists import AutocastTestLists
+from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_cuda import (
_create_scaling_case,
_get_torch_cuda_version,
@@ -61,7 +60,6 @@
IS_WINDOWS,
load_tests,
NO_MULTIPROCESSING_SPAWN,
- NoTest,
parametrize,
run_tests,
serialTest,
@@ -85,10 +83,6 @@
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
-if not TEST_CUDA:
- print("CUDA not available, skipping tests", file=sys.stderr)
- TestCase = NoTest # noqa: F811
-
try:
import torchvision.models # noqa: F401
from torchvision.models import resnet18 # noqa: F401
@@ -113,6 +107,7 @@
_cycles_per_ms = None
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCuda(TestCase):
_do_cuda_memory_leak_check = True
@@ -121,10 +116,8 @@
def setUp(self):
super().setUp()
- self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
def tearDown(self):
- del self.autocast_lists
super().tearDown()
@property
@@ -1558,545 +1551,6 @@
for t in range(num_threads):
self.assertEqual(results[t].sum().item(), size * size)
- def _run_autocast_outofplace(
- self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
- ):
- # helper to cast args
- def cast(val, to_type):
- if isinstance(val, torch.Tensor):
- return val.to(to_type) if val.is_floating_point() else val
- elif isinstance(val, collections.abc.Iterable):
- return type(val)(cast(v, to_type) for v in val)
- else:
- return val
-
- if add_kwargs is None:
- add_kwargs = {}
- fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
- self.assertFalse(torch.is_autocast_enabled())
- with torch.autocast("cuda", dtype=fast_dtype):
- self.assertTrue(torch.is_autocast_enabled())
-
- out_type = out_type if out_type is not None else run_as_type
- output = output_method = None
-
- # Try module.* variant, if requested:
- if module is not None and hasattr(module, op):
- output = getattr(module, op)(*args, **add_kwargs)
- if isinstance(output, torch.Tensor):
- self.assertTrue(
- out_type == output.dtype,
- f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
- )
-
- # Try Tensor.* variant:
- if hasattr(torch.Tensor, op):
- output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
- if isinstance(output_method, torch.Tensor):
- self.assertTrue(
- out_type == output_method.dtype,
- f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
- )
-
- self.assertTrue(
- (output is not None) or (output_method is not None),
- f"{op} not found as an attribute on either Tensor or the requested module {module}",
- )
-
- # Accounts for ops that return Tensors, iterables, and other non-Tensors.
- # For example, lstm_cell returns a tuple and equal returns bool.
- def compare(first, second):
- if isinstance(first, torch.Tensor):
- return torch.equal(first, second)
- elif isinstance(first, collections.abc.Iterable):
- return all(compare(f, s) for f, s in zip(first, second))
- else:
- return first == second
-
- # If both torch.* and Tensor.* variants were found, check outputs are identical
- if (output is not None) and (output_method is not None):
- self.assertTrue(type(output) == type(output_method))
- comparison = compare(output, output_method)
- self.assertTrue(
- comparison, f"torch.{op} result did not match Tensor.{op} result"
- )
-
- # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
- # as the C++-side autocasting, and should be bitwise accurate.
- output_to_compare = output if output is not None else output_method
- with torch.autocast("cuda", enabled=False):
- self.assertFalse(torch.is_autocast_enabled())
-
- if module is not None and hasattr(module, op):
- control = getattr(module, op)(
- *cast(args, run_as_type), **add_kwargs
- )
- else:
- control = getattr(args[0].to(run_as_type), op)(
- *cast(args[1:], run_as_type), **add_kwargs
- )
- self.assertTrue(type(output_to_compare) == type(control))
- comparison = compare(output_to_compare, control)
- self.assertTrue(comparison, f"torch.{op} result did not match control")
- self.assertTrue(torch.is_autocast_enabled())
- self.assertFalse(torch.is_autocast_enabled())
-
- def args_maybe_kwargs(self, op_with_args):
- if len(op_with_args) == 2:
- return op_with_args[0], op_with_args[1], {}
- else:
- return op_with_args[0], op_with_args[1], op_with_args[2]
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_torch_fp16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op_with_args in self.autocast_lists.torch_fp16:
- skip_test = False
- op, args = op_with_args[0], op_with_args[1]
- if len(op_with_args) == 3:
- skip_test = op_with_args[2] # TEST_WITH_ROCM
- if not skip_test:
- self._run_autocast_outofplace(op, args, torch.float16)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_torch_bf16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op_with_args in self.autocast_lists.torch_fp16:
- skip_test = False
- op, args = op_with_args[0], op_with_args[1]
- if len(op_with_args) == 3:
- skip_test = op_with_args[2] # TEST_WITH_ROCM
- should_error_from_cudnn = "cudnn" in op and (
- "TORCH_CUDNN_V8_API_DISABLED" in os.environ
- and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
- or torch.cuda.get_device_capability() < (8, 0)
- )
- should_error_from_not_implemented = should_error_from_cudnn
- if not skip_test:
- if should_error_from_not_implemented:
- with self.assertRaises(
- RuntimeError,
- msg=str(op) + " should not be supported for bfloat16!",
- ):
- self._run_autocast_outofplace(op, args, torch.bfloat16)
- else:
- if torch.cuda.is_bf16_supported():
- self._run_autocast_outofplace(op, args, torch.bfloat16)
- else:
- with self.assertRaisesRegex(
- RuntimeError, "Device does not support bfloat16"
- ):
- self._run_autocast_outofplace(op, args, torch.bfloat16)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_torch_fp32(self):
- for op_with_args in self.autocast_lists.torch_fp32:
- op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
- self._run_autocast_outofplace(
- op, args, torch.float32, add_kwargs=maybe_kwargs
- )
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_torch_need_autocast_promote(self):
- for op, args in self.autocast_lists.torch_need_autocast_promote:
- self._run_autocast_outofplace(op, args, torch.float32)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_torch_expect_builtin_promote(self):
- for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
- self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_nn_fp16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op, args in self.autocast_lists.nn_fp16:
- self._run_autocast_outofplace(
- op, args, torch.float16, module=torch._C._nn
- )
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_nn_bf16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op, args in self.autocast_lists.nn_fp16:
- if torch.cuda.is_bf16_supported():
- self._run_autocast_outofplace(
- op, args, torch.bfloat16, module=torch._C._nn
- )
- else:
- with self.assertRaisesRegex(
- RuntimeError, "Device does not support bfloat16"
- ):
- self._run_autocast_outofplace(
- op, args, torch.bfloat16, module=torch._C._nn
- )
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_nn_fp32(self):
- for op, args in self.autocast_lists.nn_fp32:
- self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_linalg_fp16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op, args in self.autocast_lists.linalg_fp16:
- self._run_autocast_outofplace(
- op, args, torch.float16, module=torch._C._linalg
- )
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_methods_fp16(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- for op, args in self.autocast_lists.methods_fp16:
- self._run_autocast_outofplace(op, args, torch.float16, module=None)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_methods_fp32(self):
- for op, args in self.autocast_lists.methods_fp32:
- self._run_autocast_outofplace(op, args, torch.float32, module=None)
-
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_methods_expect_builtin_promote(self):
- for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
- self._run_autocast_outofplace(
- op, args, torch.float32, module=None, out_type=out_type
- )
-
- def test_autocast_banned(self):
- with torch.autocast("cuda"):
- for op, args, module in self.autocast_lists.banned:
- with self.assertRaises(RuntimeError):
- getattr(module, op)(*args)
-
- def test_autocast_ignored_types(self):
- with torch.autocast("cuda"):
- for ignore_type in (torch.double, torch.int32):
- a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
- b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
- c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
-
- # Tests if CastPolicy::fp16 ops ignore double and int
- # Currently, no ops belonging to this policy support integer inputs.
- if ignore_type is torch.double:
- with self.assertRaises(RuntimeError):
- torch.mm(a_ignore, c_16)
- with torch.autocast("cuda", enabled=False):
- type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
- self.assertTrue(
- torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
- )
-
- # Tests if CastPolicy::fp32 ops ignore double and int
- with torch.autocast("cuda", enabled=False):
- type_no_autocast = torch.pow(a_ignore, 2.0).dtype
- self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
-
- # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
- with torch.autocast("cuda", enabled=False):
- type_no_autocast = torch.sum(a_ignore).dtype
- self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
-
- # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
- # Currently, no ops belonging to this policy support integer inputs.
- if ignore_type is torch.double:
- with torch.autocast("cuda", enabled=False):
- type_no_autocast = torch.norm(a_ignore).dtype
- self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
-
- def test_autocast_custom_enabled(self):
- class MyMM(torch.autograd.Function):
- @staticmethod
- @torch.amp.custom_fwd(device_type="cuda")
- def forward(ctx, a, b):
- self.assertTrue(a.dtype is torch.float32)
- self.assertTrue(b.dtype is torch.float32)
- self.assertTrue(torch.is_autocast_enabled())
- ctx.save_for_backward(a, b)
- return a.mm(b)
-
- @staticmethod
- @torch.amp.custom_bwd(device_type="cuda")
- def backward(ctx, grad):
- self.assertTrue(torch.is_autocast_enabled())
- a, b = ctx.saved_tensors
- a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
- self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
- return a_grad, b_grad
-
- mymm = MyMM.apply
-
- x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
- y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
-
- dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
- for dtype in dtypes:
- with torch.cuda.amp.autocast(dtype=dtype):
- output = mymm(x, y)
- self.assertTrue(output.dtype is dtype)
- loss = output.sum()
- loss.backward()
-
- def test_autocast_custom_cast_inputs(self):
- class MyMM(torch.autograd.Function):
- @staticmethod
- @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
- def forward(ctx, a, container, expect_type):
- b = container[1][0]
- self.assertTrue(a.dtype is expect_type)
- self.assertTrue(b.dtype is expect_type)
- self.assertFalse(torch.is_autocast_enabled())
- ctx.save_for_backward(a, b)
- return a.mm(b)
-
- @staticmethod
- @torch.amp.custom_bwd(device_type="cuda")
- def backward(ctx, grad):
- self.assertFalse(torch.is_autocast_enabled())
- a, b = ctx.saved_tensors
- return grad.mm(b.t()), None, None
-
- mymm = MyMM.apply
-
- x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
- # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient,
- # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
- # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
- y = (
- 0,
- {
- 0: torch.randn(
- (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
- )
- },
- )
-
- with torch.autocast("cuda"):
- output = mymm(x, y, torch.float32)
- self.assertTrue(output.dtype is torch.float32)
- loss = output.sum()
- loss.backward()
-
- # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
- output = mymm(x, y, torch.float16)
- self.assertTrue(output.dtype is torch.float16)
- loss = output.sum()
- loss.backward()
-
- def test_autocast_custom_deprecated_warning(self):
- with warnings.catch_warnings(record=True) as w:
-
- class MyMM(torch.autograd.Function):
- @staticmethod
- @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
- def forward(ctx, x, y):
- ctx.save_for_backward(x, y)
- self.assertFalse(torch.is_autocast_enabled())
- return x + y
-
- @staticmethod
- @torch.cuda.amp.custom_bwd
- def backward(ctx, grad):
- _, _ = ctx.saved_tensors
- self.assertFalse(torch.is_autocast_enabled())
- return grad, grad
-
- self.assertRegex(
- str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
- )
- self.assertRegex(
- str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
- )
-
- mymm = MyMM.apply
- x = torch.randn(3, 3, requires_grad=True)
- y = torch.randn(3, 3, requires_grad=True)
- with torch.amp.autocast("cuda"):
- output = mymm(x, y)
- loss = output.sum()
- loss.backward()
-
- def test_autocast_cat_jit(self):
- # Reported at https://github.com/pytorch/pytorch/issues/38958
-
- class Model(torch.nn.Module):
- def forward(self):
- a = torch.randn(1)
- b = torch.randn(1)
- c = torch.cat((a, b), 0)
- d = torch.stack([c, c], 0)
- return d
-
- # The JIT here doesn't really matter, we just need to call
- # cat via the boxed API
- model = Model()
- model_jit_script = torch.jit.script(model)
-
- with torch.autocast("cuda", enabled=True):
- model()
- model_jit_script()
-
- # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
- # so they get a dedicated test.
- # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
- @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
- def test_autocast_rnn(self):
- with torch.backends.cudnn.flags(enabled=True, deterministic=True):
- # seq, batch, features, hidden size
- clses = ("RNN", "GRU", "LSTM")
- T, B, F, H = 3, 4, 5, 6
- dtypes = (torch.float16, torch.float32)
- input_layouts = ("seq_first", "batch_first", "packed")
-
- for (
- cls,
- num_layers,
- bias,
- input_layout,
- bidirectional,
- try_nonpreflattened_weights,
- input_dtype,
- hidden_dtype,
- weight_dtype,
- ) in product(
- clses,
- (1, 2),
- (True, False),
- input_layouts,
- (True, False),
- (True, False),
- dtypes,
- dtypes,
- dtypes,
- ):
- if input_layout == "seq_first":
- batch_first = False
- x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
- elif input_layout == "batch_first":
- batch_first = True
- x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
- elif input_layout == "packed":
- batch_first = False
- x = torch.nn.utils.rnn.pack_padded_sequence(
- torch.randn((T, B, F), device="cuda", dtype=input_dtype),
- lengths=(3, 2, 1, 3),
- enforce_sorted=False,
- )
-
- rnn = (
- getattr(torch.nn, cls)(
- F,
- H,
- num_layers=num_layers,
- bidirectional=bidirectional,
- bias=bias,
- batch_first=batch_first,
- )
- .cuda()
- .to(dtype=weight_dtype)
- )
-
- if try_nonpreflattened_weights:
- for p in rnn.parameters():
- with torch.no_grad():
- p.set_(p.clone())
-
- h = torch.randn(
- (num_layers * (2 if bidirectional else 1), B, H),
- device="cuda",
- dtype=hidden_dtype,
- )
- if cls == "LSTM":
- c = torch.randn(
- (num_layers * (2 if bidirectional else 1), B, H),
- device="cuda",
- dtype=hidden_dtype,
- )
- h = (h, c)
-
- with torch.autocast("cuda"):
- out, h_out = rnn(x, h)
- out = out.data if input_layout == "packed" else out
- self.assertEqual(out.dtype, torch.float16)
- # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee
- # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
- # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
- self.assertEqual(
- out.grad_fn.name(),
- "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
- )
- out.sum().backward()
- grads = [p.grad.clone() for p in rnn.parameters()]
-
- rnn.zero_grad()
-
- if cls == "LSTM":
- out_control, h_out_control = rnn.to(dtype=torch.float16)(
- x.half(), (h[0].half(), h[1].half())
- )
- else:
- out_control, h_out_control = rnn.to(dtype=torch.float16)(
- x.half(), h.half()
- )
- out_control = (
- out_control.data if input_layout == "packed" else out_control
- )
- out_control.sum().backward()
- grads_control = [p.grad.clone() for p in rnn.parameters()]
-
- # Compares with default tolerances, even for FP16 execution. Barring nondeterminism,
- # autocast and control results should be bitwise identical.
- self.assertEqual(out, out_control)
-
- if cls == "LSTM":
- self.assertTrue(
- h_out[0].dtype is torch.float16
- and h_out[1].dtype is torch.float16
- )
- self.assertEqual(h_out[0], h_out_control[0])
- self.assertEqual(h_out[1], h_out_control[1])
- else:
- self.assertEqual(h_out.dtype, torch.float16)
- self.assertEqual(h_out, h_out_control)
- for grad, grad_control in zip(grads, grads_control):
- self.assertEqual(grad.half(), grad_control)
-
- def test_autocast_cache_leak(self):
- # Reported at https://github.com/pytorch/pytorch/issues/48049
- # Test is used to check, if autocast recaches the same parameters
- # when executed in a `torch.no_grad()` block.
-
- linear = torch.nn.Linear(10, 10).to("cuda")
- data = torch.randn(1, 10, device="cuda")
-
- with torch.autocast("cuda"):
- with torch.no_grad():
- out = linear(data)
- first_iter_mem = torch.cuda.memory_allocated()
- for _ in range(3):
- out = linear(data)
- self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
-
- def test_autocast_checkpointing(self):
- model = torch.nn.Sequential(
- torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
- ).cuda()
- input = torch.rand(
- (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
- )
- for reentrant in (True, False):
- with torch.autocast("cuda"):
- output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
- self.assertTrue(output.requires_grad)
- self.assertTrue(output.dtype is torch.float16)
- output.sum().backward()
-
- def test_cuda_autocast_deprecated_warning(self):
- with self.assertWarnsRegex(
- FutureWarning,
- r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
- ):
- with torch.cuda.amp.autocast():
- _ = torch.ones(10)
-
@slowTest
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@serialTest()
@@ -3242,7 +2696,9 @@
relu_control = torch.nn.functional.relu
# This is a good stress test. It graphs four callables: two Modules and two python functions.
- with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+ with torch.amp.autocast(
+ device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+ ):
(
model_graphed[0],
model_graphed[1],
@@ -3282,7 +2738,9 @@
torch.cuda.manual_seed(5)
for data, target in zip(real_inputs, real_targets):
opt.zero_grad(set_to_none=True)
- with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+ with torch.amp.autocast(
+ device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+ ):
y_pred = m({"x": data, "unused_input": unused_input})["output"]
y_pred = relu(y_pred)
loss = loss_fn(y_pred, target)
@@ -3351,7 +2809,9 @@
y = torch.randn(N, D_in, device="cuda")
# This is a good stress test. It graphs four callables: two Modules and two python functions.
- with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+ with torch.amp.autocast(
+ device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+ ):
model_graphed[0] = torch.cuda.make_graphed_callables(
model_graphed[0],
({"x": x, "unused_input": unused_input},),
@@ -3366,8 +2826,10 @@
# so dropout math should be bitwise identical for both.
torch.manual_seed(5)
torch.cuda.manual_seed(5)
- for data, target in zip(real_inputs, real_targets):
- with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+ for data, _ in zip(real_inputs, real_targets):
+ with torch.amp.autocast(
+ device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled
+ ):
out = m({"x": data, "unused_input": unused_input})["output"]
# We graphed the models in training mode. Eval should still run ungraphed.
@@ -3826,6 +3288,7 @@
file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
@unittest.skipIf(
@@ -4492,7 +3955,7 @@
return t
-@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
+@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
@property
@@ -4855,6 +4318,7 @@
self.assertEqual(rc, "False", "Triton was imported when importing torch!")
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestMemPool(TestCase):
def test_mempool_id(self):
pool1 = torch.cuda.graph_pool_handle()
@@ -4980,6 +4444,7 @@
self.assertEqual(len(set(active_pool_ids)), 4)
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
# These tests will be instantiate with instantiate_device_type_tests
@@ -5322,6 +4787,7 @@
self.assertEqual(scaler._growth_tracker, growth_tracker)
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestGDS(TestCase):
def _get_tmp_dir_fs_type(self):
my_path = os.path.realpath("/tmp")
@@ -5356,6 +4822,526 @@
torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
+@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
+class TestCudaAutocast(TestAutocast):
+ def setUp(self):
+ super().setUp()
+ self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))
+
+ def tearDown(self):
+ del self.autocast_lists
+ super().tearDown()
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_torch_fp16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op_with_args in self.autocast_lists.torch_fp16:
+ skip_test = False
+ op, args = op_with_args[0], op_with_args[1]
+ if len(op_with_args) == 3:
+ skip_test = op_with_args[2] # TEST_WITH_ROCM
+ if not skip_test:
+ self._run_autocast_outofplace(
+ op, args, torch.float16, device="cuda", amp_dtype=torch.float16
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_torch_bf16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op_with_args in self.autocast_lists.torch_fp16:
+ skip_test = False
+ op, args = op_with_args[0], op_with_args[1]
+ if len(op_with_args) == 3:
+ skip_test = op_with_args[2] # TEST_WITH_ROCM
+ should_error_from_cudnn = "cudnn" in op and (
+ "TORCH_CUDNN_V8_API_DISABLED" in os.environ
+ and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
+ or torch.cuda.get_device_capability() < (8, 0)
+ )
+ should_error_from_not_implemented = should_error_from_cudnn
+ if not skip_test:
+ if should_error_from_not_implemented:
+ with self.assertRaises(
+ RuntimeError,
+ msg=str(op) + " should not be supported for bfloat16!",
+ ):
+ self._run_autocast_outofplace(
+ op, args, torch.bfloat16, device="cuda"
+ )
+ else:
+ if torch.cuda.is_bf16_supported():
+ self._run_autocast_outofplace(
+ op, args, torch.bfloat16, device="cuda"
+ )
+ else:
+ with self.assertRaisesRegex(
+ RuntimeError, "Device does not support bfloat16"
+ ):
+ self._run_autocast_outofplace(
+ op, args, torch.bfloat16, device="cuda"
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_torch_fp32(self):
+ for op_with_args in self.autocast_lists.torch_fp32:
+ op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="cuda",
+ add_kwargs=maybe_kwargs,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_torch_need_autocast_promote(self):
+ for op, args in self.autocast_lists.torch_need_autocast_promote:
+ self._run_autocast_outofplace(
+ op, args, torch.float32, device="cuda", amp_dtype=torch.float16
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_torch_expect_builtin_promote(self):
+ for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="cuda",
+ out_type=out_type,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_nn_fp16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op, args in self.autocast_lists.nn_fp16:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float16,
+ device="cuda",
+ module=torch._C._nn,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_nn_bf16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op, args in self.autocast_lists.nn_fp16:
+ if torch.cuda.is_bf16_supported():
+ self._run_autocast_outofplace(
+ op, args, torch.bfloat16, device="cuda", module=torch._C._nn
+ )
+ else:
+ with self.assertRaisesRegex(
+ RuntimeError, "Device does not support bfloat16"
+ ):
+ self._run_autocast_outofplace(
+ op, args, torch.bfloat16, device="cuda", module=torch._C._nn
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_nn_fp32(self):
+ for op, args in self.autocast_lists.nn_fp32:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="cuda",
+ module=torch._C._nn,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_linalg_fp16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op, args in self.autocast_lists.linalg_fp16:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float16,
+ device="cuda",
+ module=torch._C._linalg,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_methods_fp16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op, args in self.autocast_lists.methods_fp16:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float16,
+ device="cuda",
+ module=None,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_methods_fp32(self):
+ for op, args in self.autocast_lists.methods_fp32:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="cuda",
+ module=None,
+ amp_dtype=torch.float16,
+ )
+
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_methods_expect_builtin_promote(self):
+ for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="cuda",
+ module=None,
+ out_type=out_type,
+ amp_dtype=torch.float16,
+ )
+
+ def test_autocast_banned(self):
+ with torch.autocast("cuda"):
+ for op, args, module in self.autocast_lists.banned:
+ with self.assertRaises(RuntimeError):
+ getattr(module, op)(*args)
+
+ def test_autocast_ignored_types(self):
+ with torch.autocast("cuda"):
+ for ignore_type in (torch.double, torch.int32):
+ a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
+ b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
+ c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")
+
+ # Tests if CastPolicy::fp16 ops ignore double and int
+ # Currently, no ops belonging to this policy support integer inputs.
+ if ignore_type is torch.double:
+ with self.assertRaises(RuntimeError):
+ torch.mm(a_ignore, c_16)
+ with torch.autocast("cuda", enabled=False):
+ type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
+ self.assertTrue(
+ torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
+ )
+
+ # Tests if CastPolicy::fp32 ops ignore double and int
+ with torch.autocast("cuda", enabled=False):
+ type_no_autocast = torch.pow(a_ignore, 2.0).dtype
+ self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)
+
+ # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
+ with torch.autocast("cuda", enabled=False):
+ type_no_autocast = torch.sum(a_ignore).dtype
+ self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)
+
+ # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
+ # Currently, no ops belonging to this policy support integer inputs.
+ if ignore_type is torch.double:
+ with torch.autocast("cuda", enabled=False):
+ type_no_autocast = torch.norm(a_ignore).dtype
+ self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)
+
+ def test_autocast_custom_enabled(self):
+ class MyMM(torch.autograd.Function):
+ @staticmethod
+ @torch.amp.custom_fwd(device_type="cuda")
+ def forward(ctx, a, b):
+ self.assertTrue(a.dtype is torch.float32)
+ self.assertTrue(b.dtype is torch.float32)
+ self.assertTrue(torch.is_autocast_enabled())
+ ctx.save_for_backward(a, b)
+ return a.mm(b)
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad):
+ self.assertTrue(torch.is_autocast_enabled())
+ a, b = ctx.saved_tensors
+ a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
+ self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
+ return a_grad, b_grad
+
+ mymm = MyMM.apply
+
+ x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
+ y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
+
+ dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
+ for dtype in dtypes:
+ with torch.cuda.amp.autocast(dtype=dtype):
+ output = mymm(x, y)
+ self.assertTrue(output.dtype is dtype)
+ loss = output.sum()
+ loss.backward()
+
+ def test_autocast_custom_cast_inputs(self):
+ class MyMM(torch.autograd.Function):
+ @staticmethod
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+ def forward(ctx, a, container, expect_type):
+ b = container[1][0]
+ self.assertTrue(a.dtype is expect_type)
+ self.assertTrue(b.dtype is expect_type)
+ self.assertFalse(torch.is_autocast_enabled())
+ ctx.save_for_backward(a, b)
+ return a.mm(b)
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad):
+ self.assertFalse(torch.is_autocast_enabled())
+ a, b = ctx.saved_tensors
+ return grad.mm(b.t()), None, None
+
+ mymm = MyMM.apply
+
+ x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
+ # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient,
+ # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
+ # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
+ y = (
+ 0,
+ {
+ 0: torch.randn(
+ (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
+ )
+ },
+ )
+
+ with torch.autocast("cuda"):
+ output = mymm(x, y, torch.float32)
+ self.assertTrue(output.dtype is torch.float32)
+ loss = output.sum()
+ loss.backward()
+
+ # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region.
+ output = mymm(x, y, torch.float16)
+ self.assertTrue(output.dtype is torch.float16)
+ loss = output.sum()
+ loss.backward()
+
+ def test_autocast_custom_deprecated_warning(self):
+ with warnings.catch_warnings(record=True) as w:
+
+ class MyMM(torch.autograd.Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, x, y):
+ ctx.save_for_backward(x, y)
+ self.assertFalse(torch.is_autocast_enabled())
+ return x + y
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, grad):
+ _, _ = ctx.saved_tensors
+ self.assertFalse(torch.is_autocast_enabled())
+ return grad, grad
+
+ self.assertRegex(
+ str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
+ )
+ self.assertRegex(
+ str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
+ )
+
+ mymm = MyMM.apply
+ x = torch.randn(3, 3, requires_grad=True)
+ y = torch.randn(3, 3, requires_grad=True)
+ with torch.amp.autocast("cuda"):
+ output = mymm(x, y)
+ loss = output.sum()
+ loss.backward()
+
+ def test_autocast_cat_jit(self):
+ # Reported at https://github.com/pytorch/pytorch/issues/38958
+
+ class Model(torch.nn.Module):
+ def forward(self):
+ a = torch.randn(1)
+ b = torch.randn(1)
+ c = torch.cat((a, b), 0)
+ d = torch.stack([c, c], 0)
+ return d
+
+ # The JIT here doesn't really matter, we just need to call
+ # cat via the boxed API
+ model = Model()
+ model_jit_script = torch.jit.script(model)
+
+ with torch.autocast("cuda", enabled=True):
+ model()
+ model_jit_script()
+
+ # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
+ # so they get a dedicated test.
+ # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
+ @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
+ def test_autocast_rnn(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ # seq, batch, features, hidden size
+ clses = ("RNN", "GRU", "LSTM")
+ T, B, F, H = 3, 4, 5, 6
+ dtypes = (torch.float16, torch.float32)
+ input_layouts = ("seq_first", "batch_first", "packed")
+
+ for (
+ cls,
+ num_layers,
+ bias,
+ input_layout,
+ bidirectional,
+ try_nonpreflattened_weights,
+ input_dtype,
+ hidden_dtype,
+ weight_dtype,
+ ) in product(
+ clses,
+ (1, 2),
+ (True, False),
+ input_layouts,
+ (True, False),
+ (True, False),
+ dtypes,
+ dtypes,
+ dtypes,
+ ):
+ if input_layout == "seq_first":
+ batch_first = False
+ x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
+ elif input_layout == "batch_first":
+ batch_first = True
+ x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
+ elif input_layout == "packed":
+ batch_first = False
+ x = torch.nn.utils.rnn.pack_padded_sequence(
+ torch.randn((T, B, F), device="cuda", dtype=input_dtype),
+ lengths=(3, 2, 1, 3),
+ enforce_sorted=False,
+ )
+
+ rnn = (
+ getattr(torch.nn, cls)(
+ F,
+ H,
+ num_layers=num_layers,
+ bidirectional=bidirectional,
+ bias=bias,
+ batch_first=batch_first,
+ )
+ .cuda()
+ .to(dtype=weight_dtype)
+ )
+
+ if try_nonpreflattened_weights:
+ for p in rnn.parameters():
+ with torch.no_grad():
+ p.set_(p.clone())
+
+ h = torch.randn(
+ (num_layers * (2 if bidirectional else 1), B, H),
+ device="cuda",
+ dtype=hidden_dtype,
+ )
+ if cls == "LSTM":
+ c = torch.randn(
+ (num_layers * (2 if bidirectional else 1), B, H),
+ device="cuda",
+ dtype=hidden_dtype,
+ )
+ h = (h, c)
+
+ with torch.autocast("cuda"):
+ out, h_out = rnn(x, h)
+ out = out.data if input_layout == "packed" else out
+ self.assertEqual(out.dtype, torch.float16)
+ # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee
+ # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
+ # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
+ self.assertEqual(
+ out.grad_fn.name(),
+ "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
+ )
+ out.sum().backward()
+ grads = [p.grad.clone() for p in rnn.parameters()]
+
+ rnn.zero_grad()
+
+ if cls == "LSTM":
+ out_control, h_out_control = rnn.to(dtype=torch.float16)(
+ x.half(), (h[0].half(), h[1].half())
+ )
+ else:
+ out_control, h_out_control = rnn.to(dtype=torch.float16)(
+ x.half(), h.half()
+ )
+ out_control = (
+ out_control.data if input_layout == "packed" else out_control
+ )
+ out_control.sum().backward()
+ grads_control = [p.grad.clone() for p in rnn.parameters()]
+
+ # Compares with default tolerances, even for FP16 execution. Barring nondeterminism,
+ # autocast and control results should be bitwise identical.
+ self.assertEqual(out, out_control)
+
+ if cls == "LSTM":
+ self.assertTrue(
+ h_out[0].dtype is torch.float16
+ and h_out[1].dtype is torch.float16
+ )
+ self.assertEqual(h_out[0], h_out_control[0])
+ self.assertEqual(h_out[1], h_out_control[1])
+ else:
+ self.assertEqual(h_out.dtype, torch.float16)
+ self.assertEqual(h_out, h_out_control)
+ for grad, grad_control in zip(grads, grads_control):
+ self.assertEqual(grad.half(), grad_control)
+
+ def test_autocast_cache_leak(self):
+ # Reported at https://github.com/pytorch/pytorch/issues/48049
+ # Test is used to check, if autocast recaches the same parameters
+ # when executed in a `torch.no_grad()` block.
+
+ linear = torch.nn.Linear(10, 10).to("cuda")
+ data = torch.randn(1, 10, device="cuda")
+
+ with torch.autocast("cuda"):
+ with torch.no_grad():
+ out = linear(data)
+ first_iter_mem = torch.cuda.memory_allocated()
+ for _ in range(3):
+ out = linear(data)
+ self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
+
+ def test_autocast_checkpointing(self):
+ model = torch.nn.Sequential(
+ torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
+ ).cuda()
+ input = torch.rand(
+ (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
+ )
+ for reentrant in (True, False):
+ with torch.autocast("cuda"):
+ output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
+ self.assertTrue(output.requires_grad)
+ self.assertTrue(output.dtype is torch.float16)
+ output.sum().backward()
+
+ def test_cuda_autocast_deprecated_warning(self):
+ with self.assertWarnsRegex(
+ FutureWarning,
+ r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
+ ):
+ with torch.cuda.amp.autocast():
+ _ = torch.ones(10)
+
+
instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals())
diff --git a/test/test_xpu.py b/test/test_xpu.py
index 9dde7d8..e77a1e7 100644
--- a/test/test_xpu.py
+++ b/test/test_xpu.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: intel"]
-import collections
import subprocess
import sys
import tempfile
@@ -8,7 +7,7 @@
import torch
import torch.xpu._gpu_trace as gpu_trace
-from torch.testing._internal.autocast_test_lists import AutocastTestLists
+from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyXPU,
@@ -371,7 +370,7 @@
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
-class TestXpuAutocast(TestCase):
+class TestXpuAutocast(TestAutocast):
# These operators are not implemented on XPU backend and we can NOT fall back
# them to CPU. So we have to skip them at this moment.
# TODO: remove these operators from skip list when they are implemented on XPU backend.
@@ -385,89 +384,6 @@
del self.autocast_lists
super().tearDown()
- def _run_autocast_outofplace(
- self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
- ):
- # helper to cast args
- def cast(val, to_type):
- if isinstance(val, torch.Tensor):
- return val.to(to_type) if val.is_floating_point() else val
- elif isinstance(val, collections.abc.Iterable):
- return type(val)(cast(v, to_type) for v in val)
- else:
- return val
-
- if add_kwargs is None:
- add_kwargs = {}
- fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
- self.assertFalse(torch.is_autocast_enabled("xpu"))
- with torch.amp.autocast("xpu", dtype=fast_dtype):
- self.assertTrue(torch.is_autocast_enabled("xpu"))
-
- out_type = out_type if out_type is not None else run_as_type
- output = output_method = None
-
- # Try module.* variant, if requested:
- if module is not None and hasattr(module, op):
- output = getattr(module, op)(*args, **add_kwargs)
- if isinstance(output, torch.Tensor):
- self.assertTrue(
- out_type == output.dtype,
- f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
- )
-
- # Try Tensor.* variant:
- if hasattr(torch.Tensor, op):
- output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
- if isinstance(output_method, torch.Tensor):
- self.assertTrue(
- out_type == output_method.dtype,
- f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
- )
-
- self.assertTrue(
- (output is not None) or (output_method is not None),
- f"{op} not found as an attribute on either Tensor or the requested module {module}",
- )
-
- # Accounts for ops that return Tensors, iterables, and other non-Tensors.
- # For example, lstm_cell returns a tuple and equal returns bool.
- def compare(first, second):
- if isinstance(first, torch.Tensor):
- return torch.equal(first, second)
- elif isinstance(first, collections.abc.Iterable):
- return all(compare(f, s) for f, s in zip(first, second))
- else:
- return first == second
-
- # If both torch.* and Tensor.* variants were found, check outputs are identical
- if (output is not None) and (output_method is not None):
- self.assertTrue(type(output) == type(output_method))
- comparison = compare(output, output_method)
- self.assertTrue(
- comparison, f"torch.{op} result did not match Tensor.{op} result"
- )
-
- # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
- # as the C++-side autocasting, and should be bitwise accurate.
- output_to_compare = output if output is not None else output_method
- with torch.amp.autocast("xpu", enabled=False):
- self.assertFalse(torch.is_autocast_enabled("xpu"))
-
- if module is not None and hasattr(module, op):
- control = getattr(module, op)(
- *cast(args, run_as_type), **add_kwargs
- )
- else:
- control = getattr(args[0].to(run_as_type), op)(
- *cast(args[1:], run_as_type), **add_kwargs
- )
- self.assertTrue(type(output_to_compare) == type(control))
- comparison = compare(output_to_compare, control)
- self.assertTrue(comparison, f"torch.{op} result did not match control")
- self.assertTrue(torch.is_autocast_enabled("xpu"))
- self.assertFalse(torch.is_autocast_enabled("xpu"))
-
def test_autocast_torch_fp16(self):
for op_with_args in self.autocast_lists.torch_fp16:
skip_test = False
@@ -477,7 +393,9 @@
if len(op_with_args) == 3:
skip_test = True # skip cudnn op
if not skip_test:
- self._run_autocast_outofplace(op, args, torch.float16)
+ self._run_autocast_outofplace(
+ op, args, torch.float16, device="xpu", amp_dtype=torch.float16
+ )
def test_autocast_torch_bf16(self):
for op_with_args in self.autocast_lists.torch_fp16:
@@ -488,15 +406,24 @@
if len(op_with_args) == 3:
skip_test = True # skip cudnn op
if not skip_test:
- self._run_autocast_outofplace(op, args, torch.bfloat16)
+ self._run_autocast_outofplace(op, args, torch.bfloat16, device="xpu")
def test_autocast_torch_need_autocast_promote(self):
for op, args in self.autocast_lists.torch_need_autocast_promote:
- self._run_autocast_outofplace(op, args, torch.float32)
+ self._run_autocast_outofplace(
+ op, args, torch.float32, device="xpu", amp_dtype=torch.float16
+ )
def test_autocast_torch_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
- self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+ self._run_autocast_outofplace(
+ op,
+ args,
+ torch.float32,
+ device="xpu",
+ out_type=out_type,
+ amp_dtype=torch.float16,
+ )
def test_autocast_checkpointing(self):
model = torch.nn.Sequential(
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index 8527084..c9789f1 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -1,7 +1,10 @@
# mypy: ignore-errors
+import collections
+
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
class AutocastTestLists:
@@ -234,6 +237,7 @@
torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
]
+
class AutocastCPUTestLists:
# Supplies ops and arguments for test_autocast_* in test/test_cpu.py
def __init__(self, dev):
@@ -368,3 +372,103 @@
("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
]
+
+
+class TestAutocast(TestCase):
+ def args_maybe_kwargs(self, op_with_args):
+ if len(op_with_args) == 2:
+ return op_with_args[0], op_with_args[1], {}
+ else:
+ return op_with_args[0], op_with_args[1], op_with_args[2]
+
+ def _run_autocast_outofplace(
+ self,
+ op,
+ args,
+ run_as_type,
+ device,
+ out_type=None,
+ module=torch,
+ add_kwargs=None,
+ amp_dtype=torch.bfloat16,
+ ):
+ # helper to cast args
+ def cast(val, to_type):
+ if isinstance(val, torch.Tensor):
+ return val.to(to_type) if val.is_floating_point() else val
+ elif isinstance(val, collections.abc.Iterable):
+ return type(val)(cast(v, to_type) for v in val)
+ else:
+ return val
+
+ if add_kwargs is None:
+ add_kwargs = {}
+
+ self.assertFalse(torch.is_autocast_enabled(device_type=device))
+ with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+ self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+ out_type = out_type if out_type is not None else run_as_type
+ output = output_method = None
+
+ # Try module.* variant, if requested:
+ if module is not None and hasattr(module, op):
+ output = getattr(module, op)(*args, **add_kwargs)
+ if isinstance(output, torch.Tensor):
+ self.assertTrue(
+ out_type == output.dtype,
+ f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+ )
+ # Try Tensor.* variant:
+ if hasattr(torch.Tensor, op):
+ output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+ if isinstance(output_method, torch.Tensor):
+ self.assertTrue(
+ out_type == output_method.dtype,
+ f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+ )
+
+ self.assertTrue(
+ (output is not None) or (output_method is not None),
+ f"{op} not found as an attribute on either Tensor or the requested module {module}",
+ )
+
+ # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+ # For example, lstm_cell returns a tuple and equal returns bool.
+ def compare(first, second):
+ if isinstance(first, torch.Tensor):
+ return torch.equal(first, second)
+ elif isinstance(first, collections.abc.Iterable):
+ return all(compare(f, s) for f, s in zip(first, second))
+ else:
+ return first == second
+
+ # If both torch.* and Tensor.* variants were found, check outputs are identical
+ if (output is not None) and (output_method is not None):
+ self.assertTrue(type(output) == type(output_method))
+ comparison = compare(output, output_method)
+ self.assertTrue(
+ comparison, f"torch.{op} result did not match Tensor.{op} result"
+ )
+
+ # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+ # as the C++-side autocasting, and should be bitwise accurate.
+ output_to_compare = output if output is not None else output_method
+ with torch.amp.autocast(device_type=device, enabled=False):
+            self.assertFalse(torch.is_autocast_enabled(device_type=device))
+
+ if module is not None and hasattr(module, op):
+ control = getattr(module, op)(
+ *cast(args, run_as_type), **add_kwargs
+ )
+ else:
+ control = getattr(args[0].to(run_as_type), op)(
+ *cast(args[1:], run_as_type), **add_kwargs
+ )
+ self.assertTrue(type(output_to_compare) == type(control))
+ comparison = compare(output_to_compare, control)
+ self.assertTrue(comparison, f"torch.{op} result did not match control")
+ self.assertTrue(torch.is_autocast_enabled(device_type=device))
+ self.assertFalse(torch.is_autocast_enabled(device_type=device))