Remove config check in specialize (#102098)
Fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102098
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 4d422d7..ed4bf62 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -36,9 +36,6 @@
# Could not infer dtype of torch._C.SymIntNode
"test_convert_boxes_to_pooler_format",
],
- "SubGraphTests": [
- "test_enumerate_not_break_graph",
- ],
}
XFAIL_HITS = 0
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index ab0b22c..4e688cd 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -163,7 +163,7 @@
unpack4,
2,
expected_ops=5,
- expected_ops_dynamic=ifdynstaticdefault(6, 7),
+ expected_ops_dynamic=ifdynstaticdefault(5, 7),
)
def test_unpack5(self):
@@ -180,7 +180,7 @@
unpack5,
2,
expected_ops=5,
- expected_ops_dynamic=ifdynstaticdefault(6, 7),
+ expected_ops_dynamic=ifdynstaticdefault(5, 7),
)
def test_matmul1(self):
@@ -204,7 +204,7 @@
return x + y
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
)
def test_shape_int_inplace_binops(self):
@@ -220,7 +220,7 @@
return x + p
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
)
def test_int_shape_inplace_binops(self):
@@ -244,7 +244,7 @@
return x + y
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
)
def test_int_int_comparisons(self):
@@ -289,7 +289,7 @@
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 9)
)
def test_int_shape_comparisons(self):
@@ -314,7 +314,7 @@
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 9)
)
def test_param_shape_binops(self):
@@ -348,7 +348,7 @@
self.assertEqual(counts.frame_count, 1)
expected_op_count = (
- ifdynstaticdefault(3, 12)
+ ifdynstaticdefault(1, 11)
if torch._dynamo.testing.config.dynamic_shapes
else 1
)
@@ -377,7 +377,7 @@
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
expected_op_count = (
- ifdynstaticdefault(2, 4)
+ ifdynstaticdefault(1, 4)
if torch._dynamo.testing.config.dynamic_shapes
else 1
)
@@ -504,7 +504,11 @@
# expect extra size node for dynamic
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=20, expected_ops_dynamic=21
+ self,
+ fn,
+ 1,
+ expected_ops=20,
+ expected_ops_dynamic=ifdynstaticdefault(20, 21),
)
def test_empty_list(self):
@@ -538,14 +542,14 @@
get_test_fn(func=min),
2,
expected_ops=1,
- expected_ops_dynamic=ifdynstaticdefault(3, 14),
+ expected_ops_dynamic=ifdynstaticdefault(1, 14),
)
torch._dynamo.testing.standard_test(
self,
get_test_fn(func=max),
2,
expected_ops=1,
- expected_ops_dynamic=ifdynstaticdefault(3, 17),
+ expected_ops_dynamic=ifdynstaticdefault(1, 17),
)
def test_config_obj(self):
@@ -786,7 +790,11 @@
return (a + a.numel() + torch.numel(a), a + a.nelement())
return torch._dynamo.testing.standard_test(
- self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6
+ self,
+ fn=fn,
+ nargs=1,
+ expected_ops=3,
+ expected_ops_dynamic=ifdynstaticdefault(3, 6),
)
def test_pair(self):
@@ -802,7 +810,7 @@
fn=fn,
nargs=1,
expected_ops=5,
- expected_ops_dynamic=ifdynstaticdefault(6, 8),
+ expected_ops_dynamic=ifdynstaticdefault(5, 8),
)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@@ -901,7 +909,11 @@
# expect 1 more op (size call) for dynamic
return torch._dynamo.testing.standard_test(
- self, fn=fn, nargs=1, expected_ops=9, expected_ops_dynamic=10
+ self,
+ fn=fn,
+ nargs=1,
+ expected_ops=9,
+ expected_ops_dynamic=ifdynstaticdefault(9, 10),
)
def test_build_tuple_unpack(self):
@@ -966,7 +978,7 @@
# expect 3 ops post folding for dynamic case: size, index, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3)
)
def test_tuple_iadd_with_shape(self):
@@ -980,7 +992,7 @@
# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
+ self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12)
)
def test_list_iadd_with_shape(self):
@@ -995,7 +1007,7 @@
# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
+ self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18)
)
def test_user_getattr1(self):
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 036d981..cd5179d 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -890,7 +890,7 @@
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1))
+ self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(1, 6), 1))
def _reformer(self, nopython):
input = torch.randn([1, 64, 256])
@@ -964,8 +964,8 @@
with torch.enable_grad():
cnt = self._reformer(nopython=False)
# cant inline torch.autograd.Function means graph break
- self.assertEqual(cnt.frame_count, 3)
- self.assertEqual(cnt.op_count, 10)
+ self.assertEqual(cnt.frame_count, ifunspec(ifdyn(3, 1), 3))
+ self.assertEqual(cnt.op_count, ifunspec(ifdyn(10, 11), 10))
def test_longformer_chunk(self):
input1 = torch.randn([1, 4096, 1])
@@ -981,7 +981,9 @@
self.assertTrue(same(opt_fn(input2), correct2))
self.assertEqual(cnt.frame_count, 2)
- self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4)))
+ self.assertEqual(
+ cnt.op_count, ifunspec(37, ifdyn(ifdynstaticdefault(15, 20), 4))
+ )
def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
@@ -992,7 +994,7 @@
self.assertTrue(same(opt_model(input), correct))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(12, 11))
+ self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(11, 12), 11))
def test_module_in_skipfiles(self):
model = nn.Linear(10, 10)
@@ -1089,7 +1091,7 @@
self.assertTrue(same(opt_model(a, b, c, d), correct))
if torch._dynamo.config.assume_static_by_default:
- self.assertEqual(cnt.frame_count, 5)
+ self.assertEqual(cnt.frame_count, 4)
else:
self.assertEqual(cnt.frame_count, 6)
@@ -1282,7 +1284,7 @@
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertTrue(same(opt_fn(x), correct))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14))
+ self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(14, 27), 14))
def test_recursive_map(self):
# https://github.com/pytorch/torchdynamo/issues/132
@@ -2829,6 +2831,7 @@
self.assertRaises(AttributeError, lambda: fn(x, obj1))
@torch._dynamo.config.patch("dynamic_shapes", True)
+ @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
def test_dynamic_shapes_implicit_guard(self):
def f(x):
y = x * x.size(x.shape[0])
diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py
index b832a63..32c0c48 100644
--- a/test/dynamo/test_subgraphs.py
+++ b/test/dynamo/test_subgraphs.py
@@ -8,7 +8,7 @@
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import unsupported
-from torch._dynamo.utils import disable_cache_limit, ifunspec
+from torch._dynamo.utils import disable_cache_limit, ifdyn, ifdynstaticdefault, ifunspec
globalmod = torch.nn.ReLU()
@@ -313,7 +313,7 @@
return a * x + len_(b)
if config.dynamic_shapes:
- self._common(fn, 2, 5)
+ self._common(fn, 2, ifdynstaticdefault(4, 5))
else:
self._common(fn, 2, 4)
@@ -630,7 +630,7 @@
b = b + x * i
return b
- self._common(fn, 1, 2)
+ self._common(fn, 1, ifdyn(ifdynstaticdefault(2, 7), 2))
if __name__ == "__main__":
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 874b033..65c730d 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -8,7 +8,7 @@
import torch.fx
import torch.random
-from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes
+from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes
from .. import config, utils, variables
from ..bytecode_transformation import create_call_function, Instruction
@@ -133,8 +133,14 @@
"is_sparse": value.is_sparse,
"class_type": type(value),
}
- if not config.dynamic_shapes:
- props["size"] = tuple(value.size())
+ if not free_symbols(value):
+ # this is a fully static shape, and the keys on props here inform specialization.
+ # We have to cast to int here, because these might get accessed as ConstantVariable, which has
+ # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
+ # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
+ # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
+ # I'd like to keep it around for now.
+ props["size"] = tuple([int(s) for s in value.size()])
props["stride"] = tuple(value.stride())
props["is_contiguous"] = tuple(
[
@@ -285,14 +291,6 @@
dim = kwargs.pop("dim")
constant_result = constant_result.getitem_const(dim)
- elif name == "size" and self.size is not None:
- sizes = [variables.ConstantVariable(x) for x in self.size]
- constant_result = SizeVariable(sizes, **options)
-
- if "dim" in kwargs:
- dim = kwargs.pop("dim")
- constant_result = constant_result.getitem_const(dim)
-
elif name == "size" and self.size is None and config.dynamic_shapes:
return wrap_fx_proxy(
tx,
@@ -303,6 +301,14 @@
),
**options,
)
+ elif name == "size" and self.size is not None:
+ sizes = [variables.ConstantVariable(x) for x in self.size]
+ constant_result = SizeVariable(sizes, **options)
+
+ if "dim" in kwargs:
+ dim = kwargs.pop("dim")
+ constant_result = constant_result.getitem_const(dim)
+
elif name in ("numel", "nelement") and self.size is not None:
constant_result = ConstantVariable(product(self.size), **options)
elif name in ("ndimension", "dim") and self.ndim is not None: