| # Owner(s): ["oncall: jit"] |
| |
| import torch |
| import torch._lazy.metrics as metrics |
| import torch._lazy.ts_backend |
| from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase |
| |
# Initialize the TorchScript lazy-tensor backend once, at import time.
# NOTE(review): presumably required before the "lazy" device used by the
# tests below can be targeted — confirm against torch._lazy.ts_backend.
torch._lazy.ts_backend.init()
| |
| |
class LazyGeneratorTest(TestCase):
    """Tests for explicit ``torch.Generator`` use on the lazy TorchScript backend.

    All checks use unittest assertion methods instead of bare ``assert`` so
    they still execute when Python runs with ``-O`` (which strips asserts).
    """

    def test_generator(self):
        """
        Test that generators are inserted into the TorchScript graph.

        A different global seed is set before each call to ``generate_tensor``.
        Every random op is driven by its own explicitly seeded generator, so the
        CPU and lazy results must still match.
        """

        def generate_tensor():
            # Each tensor is filled from its own freshly seeded generator, so
            # the values should not depend on the global RNG state.
            g1 = torch.Generator()
            g1.manual_seed(2023)
            t1 = torch.tensor(1.0)
            t1.uniform_(generator=g1)

            g2 = torch.Generator()
            g2.manual_seed(2024)
            t2 = torch.tensor(1.0)
            t2.normal_(generator=g2)

            return t1, t2

        # Eager reference values on CPU.
        torch.manual_seed(1)

        with torch.device("cpu"):
            cpu_t1, cpu_t2 = generate_tensor()

        # Deliberately use a different global seed for the lazy run; only the
        # explicit generators should influence the result.
        torch.manual_seed(2)

        with torch.device("lazy"):
            lazy_t1, lazy_t2 = generate_tensor()

        # Flush the recorded lazy ops so they are compiled and executed.
        torch._lazy.mark_step()

        for expected, lazy in ((cpu_t1, lazy_t1), (cpu_t2, lazy_t2)):
            actual = lazy.to("cpu")
            self.assertTrue(
                torch.allclose(expected, actual),
                f"Expected {expected}, got {actual}",
            )

    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
    def test_generator_causes_multiple_compiles(self):
        """
        Test that inserting generators with different seeds causes a recompile,
        while reusing a previously seen seed hits the compile cache.
        """

        def generate_tensor(seed):
            t = torch.tensor(1.0)
            g = torch.Generator()
            g.manual_seed(seed)
            t.uniform_(-1, 1, generator=g)
            return t

        metrics.reset()
        # Clear the global lazy metrics counters however this test exits, so a
        # failing assertion cannot leak counts into other tests.
        self.addCleanup(metrics.reset)

        with torch.device("lazy"):
            # `t` is rebound on purpose: holding the result presumably keeps the
            # lazy tensor live so mark_step materializes (and compiles) it.
            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            self.assertEqual(
                uncached_compile,
                1,
                msg=f"Expected 1 uncached compiles, got {uncached_compile}",
            )

            # A new seed must produce a new graph, hence a second compile.
            t = generate_tensor(2)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            self.assertEqual(
                uncached_compile,
                2,
                msg=f"Expected 2 uncached compiles, got {uncached_compile}",
            )

            # Re-using seed 1 should hit the cache instead of recompiling.
            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            self.assertEqual(
                uncached_compile,
                2,
                msg=f"Expected 2 uncached compiles, got {uncached_compile}",
            )
            cached_compile = metrics.counter_value("CachedCompile")
            self.assertEqual(
                cached_compile,
                1,
                msg=f"Expected 1 cached compile, got {cached_compile}",
            )

        # The most recent graph is the cached seed-1 one; the seeded generator
        # should appear in it as a constant alongside the uniform op.
        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
        self.assertIn('torch.Generator(device="cpu", seed=1)', latest_graph)
        self.assertIn("aten::uniform", latest_graph)
| |
| |
if __name__ == "__main__":
    # When executed directly, hand off to PyTorch's shared test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()