Turn translation validation on for tests and accuracy runs by default. (#103611)

This PR turns translation validation on by default for tests and accuracy benchmark
runs. It also installs Z3 on CI.

The main changes are:

- Add `--no-translation-validation` as an option in _test/run_test.py_
    - Set `PYTORCH_TEST_WITH_TV` environment variable
- Add `TEST_WITH_TV` variable in _torch/testing/_internal/common_utils.py_
- Turn translation validation on for accuracy benchmarks in _benchmarks/dynamo/common.py_
- Add Z3 installation on CI scripts

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103611
Approved by: https://github.com/ezyang
diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh
index 238681c..4e6c776 100644
--- a/.ci/docker/common/install_inductor_benchmark_deps.sh
+++ b/.ci/docker/common/install_inductor_benchmark_deps.sh
@@ -9,6 +9,7 @@
   version=$(get_pinned_commit huggingface)
   pip_install pandas
   pip_install scipy
+  pip_install z3-solver
   pip_install "transformers==${version}"
 }
 
@@ -17,8 +18,9 @@
   commit=$(get_pinned_commit timm)
   pip_install pandas
   pip_install scipy
+  pip_install z3-solver
   pip_install "git+https://github.com/rwightman/pytorch-image-models@${commit}"
 }
 
 install_huggingface
-# install_timm
\ No newline at end of file
+# install_timm
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index a5bc24d..3cd3197 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -269,3 +269,8 @@
 #Description: This is used by pytest to invoke C++ tests
 #Pinned versions: 2.3.0
 #test that import:
+
+z3-solver
+#Description: The Z3 Theorem Prover Project
+#Pinned versions:
+#test that import:
diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh
index 8b86cfa..48a8707 100644
--- a/.ci/pytorch/common_utils.sh
+++ b/.ci/pytorch/common_utils.sh
@@ -194,6 +194,7 @@
   version=$(get_pinned_commit huggingface)
   pip_install pandas
   pip_install scipy
+  pip_install z3-solver
   pip_install "transformers==${version}"
 }
 
@@ -202,6 +203,7 @@
   commit=$(get_pinned_commit timm)
   pip_install pandas
   pip_install scipy
+  pip_install z3-solver
   pip_install "git+https://github.com/rwightman/pytorch-image-models@${commit}"
 }
 
diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh
index c4f222e..9933831 100755
--- a/.ci/pytorch/win-test.sh
+++ b/.ci/pytorch/win-test.sh
@@ -37,6 +37,9 @@
 # TODO: Move both of them to Windows AMI
 python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0
 
+# Install Z3 optional dependency for Windows builds.
+python -m pip install z3-solver
+
 run_tests() {
     # Run nvidia-smi if available
     for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do
diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt
index a763ad7..0df5ea3 100644
--- a/.github/requirements/pip-requirements-macOS.txt
+++ b/.github/requirements/pip-requirements-macOS.txt
@@ -24,3 +24,4 @@
 sympy==1.11.1
 pytest-cpp==2.3.0
 rockset==1.0.3
+z3-solver==4.12.2.0
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index b3cc4f0..fb5856d 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2854,6 +2854,9 @@
         torch.backends.cudnn.benchmark = False
         torch.backends.cuda.matmul.allow_tf32 = False
 
+        # Set translation validation on by default on accuracy runs.
+        torch._dynamo.config.translation_validation = True
+
         # Remove randomeness when torch manual seed is called
         patch_torch_manual_seed()
 
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 608b598..5bd2a27 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -1,4 +1,6 @@
 # Owner(s): ["module: dynamo"]
+import unittest
+
 from torch._dynamo import config
 from torch._dynamo.testing import make_test_cls_with_patches
 
@@ -45,6 +47,7 @@
         (config, "assume_static_by_default", automatic_dynamic_shapes),
         (config, "automatic_dynamic_shapes", automatic_dynamic_shapes),
         (config, "specialize_int", False),
+        (config, "translation_validation", True),
         xfail_prop="_expected_failure_automatic_dynamic"
         if automatic_dynamic_shapes
         else "_expected_failure_dynamic",
@@ -71,6 +74,12 @@
     make_dynamic_cls(test)
     make_dynamic_cls(test, automatic_dynamic_shapes=True)
 
+unittest.expectedFailure(
+    # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'.
+    # Ref: https://github.com/sympy/sympy/issues/25146
+    DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes
+)
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
 
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 8a32bbc..1179609 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -17,6 +17,9 @@
 from torch._dynamo.testing import expectedFailureDynamic, requires_numpy_pytorch_interop
 from torch._dynamo.utils import same
 from torch.nn import functional as F
+from torch.testing._internal.common_utils import (
+    disable_translation_validation_if_dynamic_shapes,
+)
 
 d = torch.ones(10, 10)
 e = torch.nn.Linear(10, 10)
@@ -982,6 +985,7 @@
         else:
             return x - 1
 
+    @disable_translation_validation_if_dynamic_shapes
     @make_test
     def test_torch_distributions_functions(x):
         normal = torch.distributions.Normal(x, torch.tensor(1))
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index c443fbe..efdbdfb 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -19,6 +19,7 @@
 from unittest.mock import patch
 
 import numpy as np
+import sympy
 import torch
 
 import torch._dynamo.test_case
@@ -44,7 +45,8 @@
 from torch.ao.quantization.qconfig import QConfig
 from torch.ao.quantization.quantize_fx import prepare_qat_fx
 from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
-from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
+from torch.fx.experimental.symbolic_shapes import ConstraintViolationError, FloorDiv
+from torch.fx.experimental.validator import SympyToZ3, TranslationValidator
 from torch.nn import functional as F
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FUSED_SDPA,
@@ -5870,6 +5872,122 @@
         self.assertEqual(counter.frame_count, 1)
         self.assertEqual(counter.op_count, 18)
 
+    def _prepare_for_translation_validator(self):
+        """Return ((s0, s1, s2), (z0, z1, z2), validator) for a fresh validator."""
+        validator = TranslationValidator()
+        sympy_symbols = sympy.symbols("s0 s1 s2", integer=True)
+
+        # Register every SymPy symbol with the validator as an integer variable.
+        for symbol in sympy_symbols:
+            validator.add_var(symbol, int)
+        z3_vars = tuple(validator.z3var(symbol) for symbol in sympy_symbols)
+
+        return sympy_symbols, z3_vars, validator
+
+    @torch._dynamo.config.patch(translation_validation=True)
+    def test_sympy_to_z3_translation(self):
+        import z3  # Local import; presumably because z3 is an optional dependency.
+
+        (
+            (s0, s1, s2),
+            (z0, z1, z2),
+            validator,
+        ) = self._prepare_for_translation_validator()
+        # Each case pairs a SymPy input with the Z3 expression expected from SympyToZ3.
+        test_cases = [
+            # Integer constants.
+            (sympy.S.Zero, z3.IntVal(0)),
+            (sympy.S.One, z3.IntVal(1)),
+            (sympy.S.NegativeOne, z3.IntVal(-1)),
+            (sympy.Integer(2), z3.IntVal(2)),
+            (
+                s0,
+                z0,
+            ),
+            # Arithmetic operations.
+            *[
+                (op(s0, s1), op(z0, z1))
+                for op in (
+                    operator.add,
+                    operator.mod,
+                    operator.mul,
+                    operator.pow,
+                )
+            ],
+            # Logical operations.
+            *[
+                (sympy_op(s0, s1), z3_op(z0, z1))
+                for sympy_op, z3_op in (
+                    (sympy.Eq, operator.eq),
+                    (sympy.Ne, operator.ne),
+                    (sympy.Lt, operator.lt),
+                    (sympy.Le, operator.le),
+                    (sympy.Gt, operator.gt),
+                    (sympy.Ge, operator.ge),
+                )
+            ],
+            # Other operations, whose expected Z3 form is not a one-to-one operator mapping.
+            (
+                s0 - s1,
+                z0 + z3.IntVal(-1) * z1,
+            ),
+            (
+                s0 / s1,
+                z3.ToReal(z0) * (z1**-1),
+            ),
+            (s2 % (s0 / s1), z2 % z3.ToInt(z3.ToReal(z0) * (z1**-1))),
+            (s2 % (s0**3), z2 % z3.ToInt(z0**3)),
+            (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
+        ]
+
+        toZ3 = SympyToZ3(validator)
+        for sympy_expr, z3_expr in test_cases:
+            result = toZ3.run(sympy_expr)
+            self.assertTrue(
+                z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
+            )
+
+    @torch._dynamo.config.patch(translation_validation=True)
+    def test_translation_validator_sat(self):
+        (
+            (s0, s1, s2),
+            (z0, z1, z2),
+            validator,
+        ) = self._prepare_for_translation_validator()
+        # Source expressions are built from the Z3 variables.
+        validator.add_source_expr(z0 > 5)
+        validator.add_source_expr(z1 / 2 > z0)
+
+        # Every solution of the target also satisfies the source, so this validates.
+        validator.add_target_expr(s0 > 20)
+        validator.add_target_expr(s1 > s0**2)
+
+        r = validator.validate()
+        self.assertEqual(r.success, True, msg=f"failed with model: {r.model}")
+        self.assertIsNone(r.model)
+        self.assertIsNone(r.failed_source_expr)
+
+    @torch._dynamo.config.patch(translation_validation=True)
+    def test_translation_validator_unsat(self):
+        (
+            (s0, s1, s2),
+            (z0, z1, z2),
+            validator,
+        ) = self._prepare_for_translation_validator()
+        # Source expressions are built from the Z3 variables.
+        validator.add_source_expr(z0 > 5)
+        validator.add_source_expr(z1 / 2 > z0)
+
+        # Some solutions of the target do NOT satisfy the source.
+        validator.add_target_expr(s0 > 20)
+        # This expression is less restrictive than its source counterpart.
+        validator.add_target_expr(s1 > s0 + 2)
+
+        r = validator.validate()
+        self.assertEqual(r.success, False, msg=f"failed with model: {r.model}")
+        self.assertIsNotNone(r.model)
+        self.assertIsNotNone(r.failed_source_expr)
+
 
 class TestTracer(JitTestCase):
     def test_jit_save(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9d5c659..855cf73 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -32,6 +32,9 @@
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import expectedFailureDynamic, rand_strided, same
 from torch.nn import functional as F
+from torch.testing._internal.common_utils import (
+    disable_translation_validation_if_dynamic_shapes,
+)
 
 
 _orig_module_call = torch.nn.Module.__call__
@@ -978,6 +981,7 @@
             self.assertExpectedInline(cnt.frame_count, """3""")
             self.assertExpectedInline(cnt.op_count, """10""")
 
+    @disable_translation_validation_if_dynamic_shapes
     def test_longformer_chunk(self):
         input1 = torch.randn([1, 4096, 1])
         input2 = torch.randn([12, 4096, 64])
@@ -1143,6 +1147,7 @@
             self.assertEqual(cnt.frame_count, 1)
             self.assertEqual(cnt.op_count, 1)
 
+    @disable_translation_validation_if_dynamic_shapes
     def test_create_rand_mask_from_inputs(self):
         args = [
             torch.randn([1, 64, 64]),
diff --git a/test/run_test.py b/test/run_test.py
index e619e46..54df46e 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1205,6 +1205,11 @@
             "doctest to run"
         ),
     )
+    parser.add_argument(
+        "--no-translation-validation",
+        action="store_true",  # False by default: TV stays on unless passed.
+        help="Run tests without translation validation.",
+    )
 
     group = parser.add_mutually_exclusive_group()
     group.add_argument(
@@ -1636,6 +1641,9 @@
     elif options.inductor:
         os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
 
+    if not options.no_translation_validation:
+        os.environ["PYTORCH_TEST_WITH_TV"] = "1"
+
     os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
 
     failure_messages = []
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e4e5e0e..dc10fec 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1420,8 +1420,14 @@
         fx_g = _trace(f, 2, 4, 8, 16, 32)
         self.assertExpectedInline(show_guards(fx_g), """""")
 
+    @torch._dynamo.config.patch(translation_validation=True)
+    def test_constant_specialization(self):
+        def f(t):
+            assert t.shape[0] == 10
+            return t
 
-
+        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
+        self.assertExpectedInline(show_guards(tensor), """""")
 
 
 make_fx_failures = {
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 2a80b2f..b9d224a 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1092,6 +1092,25 @@
     return decorator
 
 
+# Run PyTorch tests with translation validation on.
+TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
+
+if TEST_WITH_TV:
+    torch._dynamo.config.translation_validation = True
+
+# Some tests take too long when dynamic_shapes is combined with
+# translation_validation, so TV is turned off while such a test runs.
+def disable_translation_validation_if_dynamic_shapes(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch._dynamo.config.dynamic_shapes:
+            return fn(*args, **kwargs)
+        # config.patch restores the old value on exit, so the override cannot leak.
+        with torch._dynamo.config.patch(translation_validation=False):
+            return fn(*args, **kwargs)
+    return wrapper
+
+
 # Determine whether to enable cuda memory leak check.
 # CUDA mem leak check is expensive and thus we don't want to execute it on every
 # test case / configuration.