# Owner(s): ["module: inductor"]
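
# Tests for how torch.compile with the Inductor backend handles "unbacked"
# symbolic ints: sizes that come from data-dependent ops such as nonzero()
# and Tensor.item(), and therefore have no concrete value at trace time.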

import unittest

import torch
from torch._dynamo import config as dynamo_config
from torch._inductor import config as inductor_config
from torch._inductor.utils import is_big_gpu
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_LINUX, TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, skipCUDAIf


class TestUnbackedSymints(TorchTestCase):
    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
    def test_expand(self, device):
        def fn(x, y):
            nz = torch.nonzero(x)
            # unbacked symint in nz.size
            x_exp = nz.expand([-1, 128])
            # unbacked symint in target sizes
            y_exp = y.expand([-1, nz.size(0)])
            return x_exp, y_exp

        example_inputs = (
            torch.randn((32), device=device),
            torch.randn((32, 1), device=device),
        )
        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)
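
    # expand() with a mismatched non-1 dim: nz has shape [u0, 2], so expanding
    # the size-2 dim to 128 should raise instead of broadcasting.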
    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
    def test_expand_mismatch(self, device):
        def fn(x):
            nz = x.nonzero()
            return nz.expand([-1, 128])

        x = make_tensor(32, 4, device=device, dtype=torch.float32, exclude_zero=True)
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
            torch.compile(fn, fullgraph=True)(x)

    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
    def test_broadcast_tensors(self, device):
        def fn(x):
            nz = x.nonzero()
            # a and b share an unbacked leading dim; b's size-1 dim broadcasts
            a = torch.zeros([nz.size(0), 512])
            b = torch.ones([nz.size(0), 1])
            return a * b

        x = torch.randn(32, 4, device=device)
        actual = torch.compile(fn, fullgraph=True)(x)
        expected = fn(x)
        torch.testing.assert_close(actual, expected)

    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
    def test_autotuning(self, device):
        def fn(x, y):
            nz = torch.nonzero(x)
            # unbacked symint in the GEMM input shape
            a = x.new_ones([nz.size(0), y.size(0)])
            return a @ y

        example_inputs = (
            torch.randn((64), device=device),
            torch.randn((32, 16), device=device),
        )
        with inductor_config.patch(
            {
                "max_autotune_gemm": True,
            }
        ):
            actual = torch.compile(fn, fullgraph=True)(*example_inputs)
            expected = fn(*example_inputs)
            torch.testing.assert_close(actual, expected)
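
    # Tensor.tolist() under capture_scalar_outputs yields unbacked symints,
    # used here both as split sizes and in plain Python arithmetic.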
    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_scalar_outputs": True})
    def test_split_with_sizes(self, device):
        def fn(x, y):
            sizes = y.tolist()
            s = torch.split(x, sizes)
            d = sizes[0] + sizes[1] + sizes[2]
            return s[0].sum(), d

        example_inputs = (torch.randn((32), device=device), torch.tensor((7, 16, 9)))
        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)

    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
    def test_view_of_slice(self, device):
        # Tests View.create(slice, size_with_unbacked_symint)
        def fn(x):
            nz = torch.nonzero(x)  # introduce unbacked symint
            squared = nz * nz  # avoid ReinterpretView when lowering Slice
            sliced = torch.ops.aten.slice.Tensor(squared, dim=1, start=-2, end=None)
            view = sliced.unsqueeze(dim=0)
            # make sure no unbacked symint ends up in the output's strides
            return view.squeeze(dim=0)

        example_inputs = (torch.randn(1, 1, 1, 1, device=device),)
        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)
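
    # user-defined Triton kernel launched with an unbacked symint in its grid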
    @skipCUDAIf(not HAS_CUDA, "requires cuda")
    @dynamo_config.patch({"capture_scalar_outputs": True})
    @inductor_config.patch({"abi_compatible": True})
    def test_triton_kernel_grid(self, device):
        if device == "cpu":
            raise unittest.SkipTest("Triton kernel requires GPU")
        from torch.testing._internal.triton_utils import add_kernel

        def fn(x):
            maxlen = max(x.item(), 512)
            a = torch.ones(maxlen, device=device)
            b = torch.ones(maxlen, device=device)
            out = torch.zeros_like(a)
            # unbacked symint in grid
            add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
            return out

        example_inputs = (torch.randint(high=1024, size=(1,), device=device),)
        actual = torch.compile(fn, fullgraph=True)(*example_inputs)
        expected = fn(*example_inputs)
        torch.testing.assert_close(actual, expected)


instantiate_device_type_tests(
    TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu")
)
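
# Note: when invoked directly, the suite runs only on Linux with CUDA and a
# sufficiently large GPU (see is_big_gpu); otherwise no tests are executed.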
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()