blob: 855484a13bfa4b13586f22c5d7c6f1474304e7c2 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import same
from torch.nn import functional as F
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_SDPA,
SM80OrLater,
)
class CtxManagerTests(torch._dynamo.test_case.TestCase):
def test_no_grad(self):
def fn1(a, b):
x = a + 1
# redundant no_grad should get ignored
with torch.no_grad():
x = x + b
x = x + 2
return x
def fn2(a, b):
x = a + 1
with torch.set_grad_enabled(False):
x = x + b
x = x + 2
return x
def fn3(a, b):
x = a + 1
with torch.enable_grad():
x = x + b
x = x + 2
return x
def fn4(a, b):
x = a + 1
with torch.set_grad_enabled(True):
if torch.is_grad_enabled():
x = x + b
x = x + 2
return x
with torch.no_grad():
torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
with torch.enable_grad():
torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
def test_grad_mode_guard(self):
def fn(a, b):
prev_grad = torch.is_grad_enabled()
torch.set_grad_enabled(False)
a = a + 1
a.tolist() # graph break
ret = a + b
torch.set_grad_enabled(prev_grad)
return ret
a = torch.randn([3, 4])
b = torch.randn([3, 4])
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
for _ in range(10):
opt_fn(a, b)
self.assertEqual(cnts.frame_count, 2)
def test_nested_grad_mode_graph_break(self):
def fn(x):
before = torch.is_grad_enabled()
with torch.set_grad_enabled(False):
torch._dynamo.graph_break()
with torch.set_grad_enabled(True):
x = torch.mul(x, 5)
torch._dynamo.graph_break()
x = torch.sqrt(x)
assert torch.is_grad_enabled()
assert not torch.is_grad_enabled()
assert torch.is_grad_enabled() == before
return x
a = torch.randn([3, 4])
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
for _ in range(10):
opt_fn(a)
self.assertEqual(cnts.frame_count, 3)
def test_torch_profiler(self):
# wrap torch.profiler.* as NullContextVariable and do nothing
def fn(x):
y = x**2
with torch.profiler.profile():
y = y + 2
with torch.profiler.record_function("my_function"):
z = y**3
z.tolist() # graph break
z = z + 1
return z
x = torch.randn((2, 2), requires_grad=True)
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
def test_autograd_profiler(self):
# wrap torch.autograd.profiler.* as NullContextVariable and do nothing
def fn(x):
y = x**2
with torch.autograd.profiler.profile():
y = y + 2
with torch.autograd.profiler.record_function("my_function"):
z = y**3
z.tolist() # graph break
z = z + 1
return z
x = torch.randn((2, 2), requires_grad=True)
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_context_manager1(self):
def fn(x):
s = torch.cuda.Stream()
x = torch.mul(x, 5)
x = torch.add(x, 2)
with torch.cuda.stream(s):
x = torch.relu(x)
x = torch.add(x, 1)
x = torch.cos(x)
return x
x = torch.randn((2, 2))
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 9)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_context_manager2(self):
def fn(x, s):
x = torch.mul(x, 5)
x = torch.add(x, 2)
with torch.cuda.stream(s):
x = torch.relu(x)
x = torch.add(x, 1)
x = torch.cos(x)
return x
x = torch.randn((2, 2))
s = torch.cuda.Stream()
ref = fn(x, s)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
res = opt_fn(x, s)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 8)
def test_autograd_profiler_enabled(self):
def fn(x):
if torch.autograd._profiler_enabled():
return x + 1
else:
return x - 1
x = torch.randn((2, 2), requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
if torch.autograd._profiler_enabled():
torch.autograd._disable_profiler()
assert not torch.autograd._profiler_enabled()
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
with torch.autograd.profiler.profile():
assert torch.autograd._profiler_enabled()
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast(self):
if not torch.cuda.is_bf16_supported():
raise unittest.SkipTest("requires bf16")
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cuda")
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.bfloat16)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_amp_autocast(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
with torch.cuda.amp.autocast(dtype=torch.torch.float64):
c_float64 = torch.mm(a_float32, b_float32)
return c_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cuda")
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.float64)
def test_is_autocast_cpu_enabled(self):
def fn(a_float32, b_float32):
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
c_float16 = torch.mm(a_float32, b_float32)
if torch.is_autocast_cpu_enabled():
c_float16 = c_float16 + 1
return c_float16
a = torch.rand((8, 8))
b = torch.rand((8, 8))
ref = fn(a, b)
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(a, b)
self.assertTrue(same(ref, res))
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
"Can't run fused SDPA on this platform",
)
@patch.object(torch._dynamo.config, "dynamic_shapes", False)
def test_autocast_sdpa(self):
class MyModule(torch.nn.Module):
def forward(self, query, key, value):
with torch.autocast("cpu"):
with torch.autocast("cuda", dtype=torch.float32):
out = F.scaled_dot_product_attention(
query, key, value, None, 0.5, True
)
return out
dtype = torch.float32
seq_len_q = 1
seq_len_k = 1
head_dim = 8
query = torch.ones(
1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
key = torch.ones(
1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
value = torch.ones(
1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
module = MyModule()
real = module(query, key, value)
real_device = real.device
real_dtype = real.dtype
opt_mod = torch._dynamo.optimize("inductor")(module)
compiled = opt_mod(query, key, value)
self.assertEqual(compiled.device, real_device)
self.assertEqual(compiled.dtype, real_dtype)
self.assertEqual(compiled.device.type, "cuda")
self.assertEqual(compiled.device.index, 0)
self.assertEqual(compiled.dtype, torch.float32)
def test_autocast_cpu(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
d_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cpu")
self.assertEqual(exported.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
torch._dynamo.graph_break()
d_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
torch._dynamo.graph_break()
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
opt = torch._dynamo.optimize("eager")(module)
res = opt(torch.tensor([0.5]))
self.assertEqual(res.device, real_device)
self.assertEqual(res.dtype, real_dtype)
self.assertEqual(res.device.type, "cpu")
self.assertEqual(res.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break_2(self):
# Regression for: https://github.com/pytorch/pytorch/issues/93890
def fn(x):
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
x = torch.mm(x, x)
torch._dynamo.graph_break()
x = torch.relu(x)
return x
x = torch.rand([4, 4])
self.assertEqual(x.dtype, torch.float32)
res = fn(x)
opt_fn = torch._dynamo.optimize("eager")(fn)
opt_res = opt_fn(x)
self.assertTrue(torch.allclose(res, opt_res))
self.assertEqual(res.dtype, torch.bfloat16)
self.assertEqual(opt_res.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break_inner_fn(self):
class MyModule(torch.nn.Module):
@staticmethod
def mm_breaks(x, y):
torch._dynamo.graph_break()
return torch.mm(x, y)
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
torch._dynamo.graph_break()
with torch.autocast(
device_type="cpu", dtype=torch.bfloat16, enabled=False
):
torch._dynamo.graph_break()
g_float32 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# Check that nested with non-inlineable function with graph break
torch._dynamo.graph_break()
f_float16_1 = self.mm_breaks(a_float32, b_float32)
# We remember to exit the inner autocast correctly to outer
# even after graph breaks
f_float16 = self.mm_breaks(a_float32, b_float32)
assert f_float16.dtype == f_float16_1.dtype
return f_float16, g_float32
module = MyModule()
real_16, real_32 = module(torch.tensor([0.5]))
real_device_16 = real_16.device
real_dtype_16 = real_16.dtype
real_device_32 = real_32.device
real_dtype_32 = real_32.dtype
graph = torch._dynamo.optimize("eager")(module)
out_16, out_32 = graph(torch.tensor([0.5]))
self.assertEqual(out_16.device, real_device_16)
self.assertEqual(out_16.dtype, real_dtype_16)
self.assertEqual(out_32.device, real_device_32)
self.assertEqual(out_32.dtype, real_dtype_32)
self.assertEqual(out_16.device.type, "cpu")
self.assertEqual(out_16.dtype, torch.bfloat16)
self.assertEqual(out_32.device.type, "cpu")
self.assertEqual(out_32.dtype, torch.float32)
def test_autocast_graph_break_method(self):
class MyModule(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def mm_not_break(self, x, y):
return torch.mm(x, y) + self.bias
def mm_breaks(self, x, y):
torch._dynamo.graph_break()
return torch.mm(x, y) + self.bias
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
with torch.autocast(
device_type="cpu", dtype=torch.bfloat16, enabled=False
):
g_float32 = torch.mm(a_float32, b_float32)
f_float16 = self.mm_breaks(a_float32, b_float32)
assert (
f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0]
)
return f_float16, g_float32
module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# Autocast doesn't work on addition, so we need the bias to be `bfloat16`
res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
(8, 8), device="cpu", dtype=torch.bfloat16
)
self.assertEqual(res.dtype, torch.float32)
real_16, real_32 = module(torch.tensor([0.5]))
real_device_16 = real_16.device
real_dtype_16 = real_16.dtype
real_device_32 = real_32.device
real_dtype_32 = real_32.dtype
graph = torch._dynamo.optimize("eager")(module)
out_16, out_32 = graph(torch.tensor([0.5]))
self.assertEqual(out_16.device, real_device_16)
self.assertEqual(out_16.dtype, real_dtype_16)
self.assertEqual(out_32.device, real_device_32)
self.assertEqual(out_32.dtype, real_dtype_32)
self.assertEqual(out_16.device.type, "cpu")
self.assertEqual(out_16.dtype, torch.bfloat16)
self.assertEqual(out_32.device.type, "cpu")
self.assertEqual(out_32.dtype, torch.float32)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast_float64(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float64):
e_float64 = torch.mm(a_float32, b_float32)
f_float64 = torch.mm(d_float32, e_float64)
return f_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.float64)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast_device(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast("cuda"):
e_float64 = torch.mm(a_float32, b_float32)
f_float64 = torch.mm(d_float32, e_float64)
return f_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.torch.float16)