Remove config check in specialize (#102098)

Instead of gating tensor specialization on the global config.dynamic_shapes
flag, TensorVariable now records constant size/stride/contiguity props
whenever the example value has no free symbols, i.e. its shape is fully
static. Expected op counts in the dynamo tests are updated accordingly.

Fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102098
Approved by: https://github.com/ezyang
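
For context, a minimal standalone sketch of the new specialization rule (not
the actual TensorVariable code; the helper name specialized_props_sketch is
made up for illustration, and this assumes free_symbols accepts a plain
tensor, as it is called in the diff below):

```python
import torch
from torch.fx.experimental.symbolic_shapes import free_symbols


def specialized_props_sketch(value):
    # Sketch of the decision this PR changes: specialize on the example value
    # itself being fully static (no free symbols in its sizes/strides) rather
    # than on the global config.dynamic_shapes flag.
    props = {"dtype": value.dtype, "device": value.device, "ndim": value.dim()}
    if not free_symbols(value):
        # Cast to int so constant-backed SymInts never end up in props.
        props["size"] = tuple(int(s) for s in value.size())
        props["stride"] = tuple(int(s) for s in value.stride())
    return props


print(specialized_props_sketch(torch.randn(2, 3))["size"])  # -> (2, 3)
```
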
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 4d422d7..ed4bf62 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -36,9 +36,6 @@
         # Could not infer dtype of torch._C.SymIntNode
         "test_convert_boxes_to_pooler_format",
     ],
-    "SubGraphTests": [
-        "test_enumerate_not_break_graph",
-    ],
 }
 
 XFAIL_HITS = 0
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index ab0b22c..4e688cd 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -163,7 +163,7 @@
             unpack4,
             2,
             expected_ops=5,
-            expected_ops_dynamic=ifdynstaticdefault(6, 7),
+            expected_ops_dynamic=ifdynstaticdefault(5, 7),
         )
 
     def test_unpack5(self):
@@ -180,7 +180,7 @@
             unpack5,
             2,
             expected_ops=5,
-            expected_ops_dynamic=ifdynstaticdefault(6, 7),
+            expected_ops_dynamic=ifdynstaticdefault(5, 7),
         )
 
     def test_matmul1(self):
@@ -204,7 +204,7 @@
             return x + y
 
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
         )
 
     def test_shape_int_inplace_binops(self):
@@ -220,7 +220,7 @@
             return x + p
 
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
         )
 
     def test_int_shape_inplace_binops(self):
@@ -244,7 +244,7 @@
             return x + y
 
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
         )
 
     def test_int_int_comparisons(self):
@@ -289,7 +289,7 @@
 
         # expect for dynamic: size, index, 6 comparison ops, add
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 9)
         )
 
     def test_int_shape_comparisons(self):
@@ -314,7 +314,7 @@
 
         # expect for dynamic: size, index, 6 comparison ops, add
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 9)
         )
 
     def test_param_shape_binops(self):
@@ -348,7 +348,7 @@
         self.assertEqual(counts.frame_count, 1)
 
         expected_op_count = (
-            ifdynstaticdefault(3, 12)
+            ifdynstaticdefault(1, 11)
             if torch._dynamo.testing.config.dynamic_shapes
             else 1
         )
@@ -377,7 +377,7 @@
         self.assertTrue(same(ref, res))
         self.assertEqual(counts.frame_count, 1)
         expected_op_count = (
-            ifdynstaticdefault(2, 4)
+            ifdynstaticdefault(1, 4)
             if torch._dynamo.testing.config.dynamic_shapes
             else 1
         )
@@ -504,7 +504,11 @@
 
         # expect extra size node for dynamic
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=20, expected_ops_dynamic=21
+            self,
+            fn,
+            1,
+            expected_ops=20,
+            expected_ops_dynamic=ifdynstaticdefault(20, 21),
         )
 
     def test_empty_list(self):
@@ -538,14 +542,14 @@
             get_test_fn(func=min),
             2,
             expected_ops=1,
-            expected_ops_dynamic=ifdynstaticdefault(3, 14),
+            expected_ops_dynamic=ifdynstaticdefault(1, 14),
         )
         torch._dynamo.testing.standard_test(
             self,
             get_test_fn(func=max),
             2,
             expected_ops=1,
-            expected_ops_dynamic=ifdynstaticdefault(3, 17),
+            expected_ops_dynamic=ifdynstaticdefault(1, 17),
         )
 
     def test_config_obj(self):
@@ -786,7 +790,11 @@
             return (a + a.numel() + torch.numel(a), a + a.nelement())
 
         return torch._dynamo.testing.standard_test(
-            self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6
+            self,
+            fn=fn,
+            nargs=1,
+            expected_ops=3,
+            expected_ops_dynamic=ifdynstaticdefault(3, 6),
         )
 
     def test_pair(self):
@@ -802,7 +810,7 @@
             fn=fn,
             nargs=1,
             expected_ops=5,
-            expected_ops_dynamic=ifdynstaticdefault(6, 8),
+            expected_ops_dynamic=ifdynstaticdefault(5, 8),
         )
 
     @patch.object(torch._dynamo.config, "dynamic_shapes", True)
@@ -901,7 +909,11 @@
 
         # expect 1 more op (size call) for dynamic
         return torch._dynamo.testing.standard_test(
-            self, fn=fn, nargs=1, expected_ops=9, expected_ops_dynamic=10
+            self,
+            fn=fn,
+            nargs=1,
+            expected_ops=9,
+            expected_ops_dynamic=ifdynstaticdefault(9, 10),
         )
 
     def test_build_tuple_unpack(self):
@@ -966,7 +978,7 @@
 
         # expect 3 ops post folding for dynamic case: size, index, add
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
+            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3)
         )
 
     def test_tuple_iadd_with_shape(self):
@@ -980,7 +992,7 @@
 
         # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
+            self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12)
         )
 
     def test_list_iadd_with_shape(self):
@@ -995,7 +1007,7 @@
         # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
 
         torch._dynamo.testing.standard_test(
-            self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
+            self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18)
         )
 
     def test_user_getattr1(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 036d981..cd5179d 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -890,7 +890,7 @@
         self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
 
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1))
+        self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(1, 6), 1))
 
     def _reformer(self, nopython):
         input = torch.randn([1, 64, 256])
@@ -964,8 +964,8 @@
         with torch.enable_grad():
             cnt = self._reformer(nopython=False)
         # cant inline torch.autograd.Function means graph break
-        self.assertEqual(cnt.frame_count, 3)
-        self.assertEqual(cnt.op_count, 10)
+        self.assertEqual(cnt.frame_count, ifunspec(ifdyn(3, 1), 3))
+        self.assertEqual(cnt.op_count, ifunspec(ifdyn(10, 11), 10))
 
     def test_longformer_chunk(self):
         input1 = torch.randn([1, 4096, 1])
@@ -981,7 +981,9 @@
         self.assertTrue(same(opt_fn(input2), correct2))
 
         self.assertEqual(cnt.frame_count, 2)
-        self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4)))
+        self.assertEqual(
+            cnt.op_count, ifunspec(37, ifdyn(ifdynstaticdefault(15, 20), 4))
+        )
 
     def test_hf_t5_forward(self):
         input = torch.randn([1, 2048, 512])
@@ -992,7 +994,7 @@
         self.assertTrue(same(opt_model(input), correct))
 
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, ifdyn(12, 11))
+        self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(11, 12), 11))
 
     def test_module_in_skipfiles(self):
         model = nn.Linear(10, 10)
@@ -1089,7 +1091,7 @@
             self.assertTrue(same(opt_model(a, b, c, d), correct))
 
         if torch._dynamo.config.assume_static_by_default:
-            self.assertEqual(cnt.frame_count, 5)
+            self.assertEqual(cnt.frame_count, 4)
         else:
             self.assertEqual(cnt.frame_count, 6)
 
@@ -1282,7 +1284,7 @@
         opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
         self.assertTrue(same(opt_fn(x), correct))
         self.assertEqual(cnt.frame_count, 1)
-        self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14))
+        self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(14, 27), 14))
 
     def test_recursive_map(self):
         # https://github.com/pytorch/torchdynamo/issues/132
@@ -2829,6 +2831,7 @@
         self.assertRaises(AttributeError, lambda: fn(x, obj1))
 
     @torch._dynamo.config.patch("dynamic_shapes", True)
+    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
     def test_dynamic_shapes_implicit_guard(self):
         def f(x):
             y = x * x.size(x.shape[0])
diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py
index b832a63..32c0c48 100644
--- a/test/dynamo/test_subgraphs.py
+++ b/test/dynamo/test_subgraphs.py
@@ -8,7 +8,7 @@
 import torch._dynamo.testing
 from torch._dynamo import config
 from torch._dynamo.testing import unsupported
-from torch._dynamo.utils import disable_cache_limit, ifunspec
+from torch._dynamo.utils import disable_cache_limit, ifdyn, ifdynstaticdefault, ifunspec
 
 globalmod = torch.nn.ReLU()
 
@@ -313,7 +313,7 @@
             return a * x + len_(b)
 
         if config.dynamic_shapes:
-            self._common(fn, 2, 5)
+            self._common(fn, 2, ifdynstaticdefault(4, 5))
         else:
             self._common(fn, 2, 4)
 
@@ -630,7 +630,7 @@
                 b = b + x * i
             return b
 
-        self._common(fn, 1, 2)
+        self._common(fn, 1, ifdyn(ifdynstaticdefault(2, 7), 2))
 
 
 if __name__ == "__main__":
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 874b033..65c730d 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -8,7 +8,7 @@
 
 import torch.fx
 import torch.random
-from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes
+from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes
 
 from .. import config, utils, variables
 from ..bytecode_transformation import create_call_function, Instruction
@@ -133,8 +133,14 @@
             "is_sparse": value.is_sparse,
             "class_type": type(value),
         }
-        if not config.dynamic_shapes:
-            props["size"] = tuple(value.size())
+        if not free_symbols(value):
+            # This is a fully static shape; the keys we set on props here inform specialization.
+            # We have to cast to int because these values may be accessed via ConstantVariable,
+            # which has a strict no-SymInt policy. Since we only reach this branch when there are
+            # no free symbols, each size is already a known constant. We could avoid the discrepancy
+            # by making ConstantVariable more permissive for constant-backed SymInts, but its strict
+            # assert has been a good source of signal when hunting bugs, so it stays for now.
+            props["size"] = tuple([int(s) for s in value.size()])
             props["stride"] = tuple(value.stride())
             props["is_contiguous"] = tuple(
                 [
@@ -285,14 +291,6 @@
                 dim = kwargs.pop("dim")
                 constant_result = constant_result.getitem_const(dim)
 
-        elif name == "size" and self.size is not None:
-            sizes = [variables.ConstantVariable(x) for x in self.size]
-            constant_result = SizeVariable(sizes, **options)
-
-            if "dim" in kwargs:
-                dim = kwargs.pop("dim")
-                constant_result = constant_result.getitem_const(dim)
-
         elif name == "size" and self.size is None and config.dynamic_shapes:
             return wrap_fx_proxy(
                 tx,
@@ -303,6 +301,14 @@
                 ),
                 **options,
             )
+        elif name == "size" and self.size is not None:
+            sizes = [variables.ConstantVariable(x) for x in self.size]
+            constant_result = SizeVariable(sizes, **options)
+
+            if "dim" in kwargs:
+                dim = kwargs.pop("dim")
+                constant_result = constant_result.getitem_const(dim)
+
         elif name in ("numel", "nelement") and self.size is not None:
             constant_result = ConstantVariable(product(self.size), **options)
         elif name in ("ndimension", "dim") and self.ndim is not None: