blob: 24d8192844f21d4b564f8fd9a0891158130c4d57 [file] [log] [blame]
# Owner(s): ["module: inductor"]
import contextlib
import torch
from torch._inductor.dependencies import MemoryDep
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout, Pointwise
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
class TestDependencies(InductorTestCase):
    """Tests for Inductor dependency objects (``MemoryDep``, ``Pointwise`` reads).

    Every test runs with a ``GraphLowering`` built from a trivial traced
    module installed as the active graph handler (see ``setUp``).
    """

    def _create_buffer(self, name, shape, dtype=torch.float32):
        """Return a ``Buffer`` named ``name`` with a ``FixedLayout`` on ``GPU_TYPE``.

        Args:
            name: Name given to the buffer.
            shape: Size passed to ``FixedLayout``.
            dtype: Element dtype passed to ``FixedLayout``.
        """
        return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape))

    def setUp(self):
        """Trace a dummy module and install its ``GraphLowering`` as graph handler."""
        super().setUp()

        class DummyModule(torch.nn.Module):
            # Minimal module; it only exists so there is something to trace.
            def forward(self, x):
                return x * 2

        self._gm = torch.fx.symbolic_trace(DummyModule())
        self._graph = GraphLowering(self._gm)

        # The handler context is held open in an ExitStack so that it spans
        # the whole test and can be closed from tearDown().
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(V.set_graph_handler(self._graph))

    def tearDown(self):
        """Close the graph-handler context entered in ``setUp``."""
        self._stack.close()
        super().tearDown()

    def test_bucketize_dependencies(self):
        """A pointwise op built on ``ops.bucketize`` reports exactly one read."""
        offsets = self._create_buffer("offsets", (1025,), torch.int32)

        def inner_fn(index):
            # The bucketized values are the loop index itself; the only
            # buffer referenced by name is ``offsets``.
            idx = index[0]
            return ops.bucketize(
                values=idx,
                offsets_name=offsets.get_name(),
                offsets_size=offsets.get_size()[0],
                indexing_dtype=torch.int32,
                right=True,
            )

        pointwise = Pointwise.create(
            device=torch.device(GPU_TYPE),
            dtype=torch.int32,
            inner_fn=inner_fn,
            ranges=[1024 * 4],
        )

        self.assertEqual(len(pointwise.get_reads()), 1)

    def test_get_offset(self):
        """``MemoryDep.get_offset`` reports the constant added to the index."""
        x = sympy_index_symbol("x")
        y = sympy_index_symbol("y")
        var_ranges = {
            x: 1024,
            y: 2048,
        }
        # No constant term in the index expression.
        dep1 = MemoryDep(
            "dep1",
            x * 2048 + y,
            list(var_ranges.keys()),
            list(var_ranges.values()),
        )
        # Same expression shifted by a constant 1024.
        dep2 = MemoryDep(
            "dep2",
            x * 2048 + y + 1024,
            list(var_ranges.keys()),
            list(var_ranges.values()),
        )
        self.assertEqual(dep1.get_offset(), 0)
        self.assertEqual(dep2.get_offset(), 1024)

    def test_normalize_with_stride_order_equal(self):
        """Deps differing only in variable order compare equal once normalized."""
        x = sympy_index_symbol("x")
        y = sympy_index_symbol("y")

        loop_order1 = MemoryDep(
            "access_the_same_buffer",
            x * 2048 + y,
            [x, y],
            [1024, 2048],
        )
        # Same buffer name and index expression; variables and sizes are
        # listed in the opposite order.
        loop_order2 = MemoryDep(
            "access_the_same_buffer",
            x * 2048 + y,
            [y, x],
            [2048, 1024],
        )
        self.assertTrue(loop_order1 != loop_order2)
        normalized_loop_order1 = loop_order1.normalize_with_stride_order()
        normalized_loop_order2 = loop_order2.normalize_with_stride_order()
        self.assertTrue(normalized_loop_order1 == normalized_loop_order2)

    def test_normalize_with_stride_order_unequal(self):
        """Deps whose index expressions differ by a constant stay unequal."""
        x = sympy_index_symbol("x")
        y = sympy_index_symbol("y")

        loop_order1 = MemoryDep(
            "access_the_same_buffer",
            x * 2048 + y,
            [x, y],
            [1024, 2048],
        )
        # Reversed variable order *and* an extra constant 5 in the index.
        loop_order2 = MemoryDep(
            "access_the_same_buffer",
            x * 2048 + y + 5,
            [y, x],
            [2048, 1024],
        )
        self.assertTrue(loop_order1 != loop_order2)
        normalized_loop_order1 = loop_order1.normalize_with_stride_order()
        normalized_loop_order2 = loop_order2.normalize_with_stride_order()
        # unequal due to different offset
        self.assertTrue(normalized_loop_order1 != normalized_loop_order2)
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Launch the suite only when both the CPU and the GPU flags are set;
    # otherwise running this file directly does nothing.
    backends_available = HAS_CPU and HAS_GPU
    if backends_available:
        run_tests("sympy")