# blob: c3dbf362fb896ab5c5cc3e062f0797cd1aca79a7
# Owner(s): ["module: inductor"]
import sys
import unittest
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_GPU
# triton is an optional dependency; this whole module is meaningless without it.
try:
    import triton  # noqa: F401
except ImportError:
    if __name__ != "__main__":
        # Imported by a test runner: raising SkipTest at import time marks the
        # module as skipped rather than failed.
        raise unittest.SkipTest("requires triton")  # noqa: TRY200
    # Run directly as a script: exit successfully without running anything.
    sys.exit(0)
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.triton_heuristics import triton_config
class TestTritonHeuristics(TestCase):
    """Tests for the helpers in ``torch._inductor.triton_heuristics``."""

    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
        # NOTE(review): arguments presumably are (size_hints, x, y) — confirm
        # against triton_config's signature.
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            # Only check the dimensions this config actually sets a block size for.
            if key not in cfg.kwargs:
                continue
            # subTest reports each dimension separately, and assertLessEqual
            # shows both values on failure (assertTrue(a <= b) would only
            # report "False is not true").
            with self.subTest(key=key):
                self.assertLessEqual(
                    cfg.kwargs[key], config.triton.max_block[label]
                )
# Entry point for direct execution. The suite only runs on Linux hosts with a
# GPU; on any other host the script exits without running tests.
if __name__ == "__main__" and IS_LINUX and HAS_GPU:
    run_tests()