# blob: 8b0b4b3a13cde52d7968fd18128a74f7cd4d17e6
import os
import tempfile
import unittest
from tools.gen_vulkan_glsl import GLSLGenerator
from yaml.constructor import ConstructorError
class TestGLSLCodegen(unittest.TestCase):
    """Tests for how ``GLSLGenerator`` consumes parameter YAML files.

    Every test writes a YAML document to a temporary file and passes that
    file's path to ``GLSLGenerator.add_params_yaml``.
    """

    # NOTE(review): the nesting of the YAML fixtures below was restored from a
    # copy whose leading whitespace had been stripped -- confirm it matches the
    # layout ``add_params_yaml`` expects.
    #
    # YAML is whitespace-sensitive, so each fixture literal starts at column 0
    # instead of following the indentation of the surrounding method.

    def test_assert_on_duplicate_key_yaml(self) -> None:
        """A YAML file that defines ``conv2d_pw`` twice must be rejected.

        The test expects ``add_params_yaml`` to raise ``ConstructorError``
        whose message contains "while constructing a mapping".
        """
        yaml_with_duplicate_keys = """
conv2d_pw:
  parameter_names_with_default_values:
    TILE_SIZE_X: 1
    TILE_SIZE_Y: 1
  parameter_values:
    - TILE_SIZE_X: 2
      TILE_SIZE_Y: 2
    - TILE_SIZE_X: 2
      TILE_SIZE_Y: 4
    - TILE_SIZE_X: 4
      TILE_SIZE_Y: 2
    - TILE_SIZE_X: 4
      TILE_SIZE_Y: 4
conv2d_pw:
  parameter_names_with_default_values:
    - TILE_SIZE_X: 1
    - TILE_SIZE_Y: 1
  parameter_values:
    - TILE_SIZE_X: 2
      TILE_SIZE_Y: 2
    - TILE_SIZE_X: 2
      TILE_SIZE_Y: 4
    - TILE_SIZE_X: 4
      TILE_SIZE_Y: 2
    - TILE_SIZE_X: 4
      TILE_SIZE_Y: 4
"""

        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
        with tempfile.NamedTemporaryFile(mode="w") as fp:
            fp.write(yaml_with_duplicate_keys)
            # Flush so the content is on disk before it is re-read by path.
            fp.flush()
            with self.assertRaisesRegex(
                ConstructorError, r"while constructing a mapping"
            ):
                generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]

    def test_assert_keys_mismatch(self) -> None:
        """A ``parameter_values`` entry with an undeclared key must be rejected.

        ``TILE_SIZE_Z`` is not listed under
        ``parameter_names_with_default_values``; the test expects a
        ``KeyError`` whose message names it as an invalid key.
        """
        yaml_with_key_mismatch = """
conv2d_pw:
  parameter_names_with_default_values:
    TILE_SIZE_X: 1
    TILE_SIZE_Y: 1
  parameter_values:
    - TILE_SIZE_X: 2
      TILE_SIZE_Z: 2
"""

        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
        with tempfile.NamedTemporaryFile(mode="w") as fp:
            fp.write(yaml_with_key_mismatch)
            fp.flush()
            # Braces are escaped so the pattern cannot be read as a regex
            # repetition quantifier; it still matches the literal text
            # "Invalid keys {'TILE_SIZE_Z'}".
            with self.assertRaisesRegex(
                KeyError, r"Invalid keys \{'TILE_SIZE_Z'\}"
            ):
                generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]

    def test_missing_key_default_val(self) -> None:
        """A key omitted from a ``parameter_values`` entry uses its default.

        The single entry sets only ``TILE_SIZE_Y``. After generating from a
        ``conv2d_pw.glslt`` template, the test expects two outputs:
        ``conv2d_pw_1x1.glsl`` containing "1 + 1" and ``conv2d_pw_1x2.glsl``
        containing "1 + 2".
        """
        yaml_with_missing_key = """
conv2d_pw:
  parameter_names_with_default_values:
    TILE_SIZE_X: 1
    TILE_SIZE_Y: 1
  parameter_values:
    - TILE_SIZE_Y: 2
"""
        file_content = """
x = $TILE_SIZE_X + $TILE_SIZE_Y
"""
        # (output file name, text that must appear in that file)
        expected_outputs = (
            ("conv2d_pw_1x1.glsl", "1 + 1"),
            ("conv2d_pw_1x2.glsl", "1 + 2"),
        )

        generator = GLSLGenerator()  # type: ignore[no-untyped-call]
        with tempfile.NamedTemporaryFile(mode="w") as fp:
            fp.write(yaml_with_missing_key)
            fp.flush()
            generator.add_params_yaml(fp.name)  # type: ignore[no-untyped-call]

            with tempfile.TemporaryDirectory() as tmp_dir:
                template_file_name = os.path.join(tmp_dir, "conv2d_pw.glslt")
                # Close the template before generating so its content is
                # fully written when the generator opens it.
                with open(template_file_name, "w") as template_file:
                    template_file.write(file_content)

                generator.generate(template_file_name, tmp_dir)  # type: ignore[no-untyped-call]

                output_paths = [
                    os.path.join(tmp_dir, file_name)
                    for file_name, _ in expected_outputs
                ]
                # Check that every output exists before inspecting contents.
                for output_path in output_paths:
                    self.assertTrue(
                        os.path.exists(output_path),
                        f"expected generated file {output_path}",
                    )
                for output_path, (_, expected_text) in zip(
                    output_paths, expected_outputs
                ):
                    with open(output_path, "r") as f:
                        contents = f.read()
                    self.assertIn(expected_text, contents)