Produce constant variables in cases where a SymNode is created with a constant (#100144)
`AOT_DYNAMIC_SHAPES=1 TORCHDYNAMO_DYNAMIC_SHAPES=1 benchmarks/dynamo/huggingface.py --performance --training --amp --backend eager --disable-cudagraphs --device cuda --only AllenaiLongformerBase --explain`
Looks promising! On AllenaiLongformerBase, this goes from:

Dynamo produced 173 graphs covering 2760 ops with 160 graph breaks (14 unique)

to:

Dynamo produced 6 graphs covering 2298 ops with 15 graph breaks (7 unique)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100144
Approved by: https://github.com/ezyang
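
For context, the heart of the change is the `SymNodeVariable.create` hunk in `torch/_dynamo/variables/tensor.py` below: when the wrapped value turns out to be a plain `int`/`sympy.Integer` rather than a genuinely symbolic expression, Dynamo now returns a `ConstantVariable`, which downstream builtins can constant-fold instead of graph-breaking on. A minimal sketch of the dispatch, using simplified stand-ins for the real VariableTracker classes (the classes below are illustrative, not Dynamo's actual API):

```python
# Illustrative stand-ins only; the real classes live in torch/_dynamo/variables/.
import sympy


class ConstantVariable:
    """Wraps a Python constant that can be folded at trace time."""

    def __init__(self, value):
        self.value = value


class SymNodeVariable:
    """Wraps a symbolic value that must stay in the traced graph."""

    def __init__(self, sym_num):
        self.sym_num = sym_num

    @classmethod
    def create(cls, sym_num):
        # Core of the change: a "symbolic" number that is actually a known
        # constant becomes a ConstantVariable, so later ops constant-fold.
        if isinstance(sym_num, (sympy.Integer, int)):
            return ConstantVariable(int(sym_num))
        return cls(sym_num)


# A known constant folds; a free symbol stays symbolic.
assert isinstance(SymNodeVariable.create(sympy.Integer(4)), ConstantVariable)
assert isinstance(SymNodeVariable.create(sympy.Symbol("s0")), SymNodeVariable)
```

Because the op counts Dynamo emits now differ between the assume-static-by-default and fully dynamic modes, the test updates introduce a small helper, `ifdynstaticdefault(count1, count2)`, which returns `count1` when `torch._dynamo.config.assume_static_by_default` is set and `count2` otherwise.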
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv
index b641f92..826ceb1 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_inference.csv
@@ -1,7 +1,7 @@
name,accuracy,graph_breaks
AlbertForMaskedLM,pass,0
AlbertForQuestionAnswering,pass,0
-AllenaiLongformerBase,pass,152
+AllenaiLongformerBase,pass,136
BartForCausalLM,pass,0
BertForMaskedLM,pass,0
BertForQuestionAnswering,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv
index 51f19ca..e8fb2a4 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_dynamic_training.csv
@@ -1,7 +1,7 @@
name,accuracy,graph_breaks
AlbertForMaskedLM,pass,7
AlbertForQuestionAnswering,pass,7
-AllenaiLongformerBase,pass,160
+AllenaiLongformerBase,pass,144
BartForCausalLM,pass,7
BertForMaskedLM,pass,7
BertForQuestionAnswering,pass,7
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index d6f6805..e94d721 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -34,7 +34,7 @@
unsupported,
)
-from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec
+from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
@@ -154,7 +154,11 @@
return o
torch._dynamo.testing.standard_test(
- self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8
+ self,
+ unpack4,
+ 2,
+ expected_ops=5,
+ expected_ops_dynamic=ifdynstaticdefault(6, 7),
)
def test_unpack5(self):
@@ -167,7 +171,11 @@
return o
torch._dynamo.testing.standard_test(
- self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8
+ self,
+ unpack5,
+ 2,
+ expected_ops=5,
+ expected_ops_dynamic=ifdynstaticdefault(6, 7),
)
def test_matmul1(self):
@@ -191,7 +199,7 @@
return x + y
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=11
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
)
def test_shape_int_inplace_binops(self):
@@ -207,7 +215,7 @@
return x + p
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=10
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
)
def test_int_shape_inplace_binops(self):
@@ -231,7 +239,7 @@
return x + y
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=10
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
)
def test_int_int_comparisons(self):
@@ -276,7 +284,7 @@
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=9
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
)
def test_int_shape_comparisons(self):
@@ -301,7 +309,7 @@
# expect for dynamic: size, index, 6 comparison ops, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=9
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
)
def test_param_shape_binops(self):
@@ -333,7 +341,12 @@
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
- expected_op_count = 13 if torch._dynamo.testing.config.dynamic_shapes else 1
+
+ expected_op_count = (
+ ifdynstaticdefault(3, 12)
+ if torch._dynamo.testing.config.dynamic_shapes
+ else 1
+ )
self.assertEqual(counts.op_count, expected_op_count)
def test_user_defined_binop(self):
@@ -358,7 +371,11 @@
self.assertTrue(same(ref, res))
self.assertEqual(counts.frame_count, 1)
- expected_op_count = 4 if torch._dynamo.testing.config.dynamic_shapes else 1
+ expected_op_count = (
+ ifdynstaticdefault(2, 4)
+ if torch._dynamo.testing.config.dynamic_shapes
+ else 1
+ )
self.assertEqual(counts.op_count, expected_op_count)
def test_compare_shapes_eq(self):
@@ -511,16 +528,19 @@
return _fn
- # expect for dynamic:
- # 2 * (size, getitem) ops +
- # 1 add op +
- # 4 * 2 min / max ops +
- # 4 final add ops = 17
torch._dynamo.testing.standard_test(
- self, get_test_fn(func=min), 2, expected_ops=1, expected_ops_dynamic=17
+ self,
+ get_test_fn(func=min),
+ 2,
+ expected_ops=1,
+ expected_ops_dynamic=ifdynstaticdefault(3, 14),
)
torch._dynamo.testing.standard_test(
- self, get_test_fn(func=max), 2, expected_ops=1, expected_ops_dynamic=17
+ self,
+ get_test_fn(func=max),
+ 2,
+ expected_ops=1,
+ expected_ops_dynamic=ifdynstaticdefault(3, 17),
)
def test_config_obj(self):
@@ -773,7 +793,11 @@
)
return torch._dynamo.testing.standard_test(
- self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8
+ self,
+ fn=fn,
+ nargs=1,
+ expected_ops=5,
+ expected_ops_dynamic=ifdynstaticdefault(6, 8),
)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
@@ -916,7 +940,7 @@
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), [2, 3] * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
- self.assertEqual(cnts.op_count, ifunspec(14, 0))
+ self.assertEqual(cnts.op_count, ifunspec(2, 0))
def test_tuple_mul(self):
def fn(count):
@@ -927,7 +951,7 @@
opt_fn = torch._dynamo.optimize(cnts)(fn)
self.assertEqual(opt_fn(2), (2, 3) * 4)
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
- self.assertEqual(cnts.op_count, ifunspec(14, 0))
+ self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
def test_tuple_mul_with_shape(self):
def fn(a):
@@ -937,7 +961,7 @@
# expect 3 ops post folding for dynamic case: size, index, add
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=1, expected_ops_dynamic=3
+ self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
)
def test_tuple_iadd_with_shape(self):
@@ -951,7 +975,7 @@
# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=4, expected_ops_dynamic=12
+ self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
)
def test_list_iadd_with_shape(self):
@@ -964,8 +988,9 @@
return output
# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
+
torch._dynamo.testing.standard_test(
- self, fn, 1, expected_ops=6, expected_ops_dynamic=18
+ self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
)
def test_user_getattr1(self):
diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py
index a52b87b..95f39ec 100644
--- a/test/dynamo/test_recompile_ux.py
+++ b/test/dynamo/test_recompile_ux.py
@@ -163,7 +163,7 @@
cache_fail_test(
a,
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
- "tensor 'L['a']' strides mismatch at index 0. expected 20, actual 1",
+ "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
)
cache_fail_test(
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 00f545b..386f70d 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -31,7 +31,7 @@
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
-from torch._dynamo.utils import ifdyn, ifunspec
+from torch._dynamo.utils import ifdyn, ifdynstaticdefault, ifunspec
from torch.nn import functional as F
@@ -890,7 +890,7 @@
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(6, 1))
+ self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(3, 6), 1))
def _reformer(self, nopython):
input = torch.randn([1, 64, 256])
@@ -981,7 +981,7 @@
self.assertTrue(same(opt_fn(input2), correct2))
self.assertEqual(cnt.frame_count, 2)
- self.assertEqual(cnt.op_count, ifunspec(42, ifdyn(38, 4)))
+ self.assertEqual(cnt.op_count, ifunspec(37, ifdyn(20, 4)))
def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
@@ -992,7 +992,7 @@
self.assertTrue(same(opt_model(input), correct))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(13, 11))
+ self.assertEqual(cnt.op_count, ifdyn(12, 11))
def test_module_in_skipfiles(self):
model = nn.Linear(10, 10)
@@ -1283,7 +1283,7 @@
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertTrue(same(opt_fn(x), correct))
self.assertEqual(cnt.frame_count, 1)
- self.assertEqual(cnt.op_count, ifdyn(28, 14))
+ self.assertEqual(cnt.op_count, ifdyn(ifdynstaticdefault(21, 27), 14))
def test_recursive_map(self):
# https://github.com/pytorch/torchdynamo/issues/132
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 4894578..dd766a0 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1368,6 +1368,13 @@
return count2
+def ifdynstaticdefault(count1, count2):
+ if torch._dynamo.config.assume_static_by_default:
+ return count1
+ else:
+ return count2
+
+
def ifunspec(count1, count2):
if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int:
return count1
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 3afe664..1183d6d 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -901,7 +901,10 @@
)
)
fake_tensor_value = None
- example_value = unspec_var.proxy.node.meta["example_value"]
+ if isinstance(unspec_var, ConstantVariable):
+ example_value = unspec_var.value
+ else:
+ example_value = unspec_var.proxy.node.meta["example_value"]
if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
fake_tensor_value = example_value
proxy.node.meta["grapharg"] = GraphArg(
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index d8dfd77..f92ab23 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1263,11 +1263,16 @@
sym_num=None,
)
+ if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable):
+ return ConstantVariable(op(left.value, right.value))
+
_unimplemented()
# and_ is a constant fold function, so we only get here if constant fold is not valid
def call_and_(self, tx, a, b):
- if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
+ if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
+ b, (SymNodeVariable, ConstantVariable)
+ ):
return SymNodeVariable.create(
tx,
tx.output.create_proxy(
@@ -1280,7 +1285,9 @@
# or_ is a constant fold function, so we only get here if constant fold is not valid
def call_or_(self, tx, a, b):
- if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
+ if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance(
+ b, (SymNodeVariable, ConstantVariable)
+ ):
return SymNodeVariable.create(
tx,
tx.output.create_proxy(
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 67a9bbb..3e5a7b0 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -4,6 +4,8 @@
import types
from typing import Dict, List
+import sympy
+
import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import guard_scalar, SymTypes
@@ -238,8 +240,13 @@
length = self.size[0]
else:
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
- assert isinstance(dyn_length, SymNodeVariable)
- length = dyn_length.evaluate_expr(tx.output)
+ # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced
+ # through symbolic_shapes that end up as int/sympy.Integer
+ assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
+ if isinstance(dyn_length, SymNodeVariable):
+ length = dyn_length.evaluate_expr(tx.output)
+ else:
+ length = dyn_length.value
idxes = range(length)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
@@ -495,6 +502,10 @@
if sym_num is None:
sym_num = get_fake_value(proxy.node, tx)
proxy.node.meta["example_value"] = sym_num
+
+ if isinstance(sym_num, (sympy.Integer, int)):
+ return ConstantVariable(int(sym_num))
+
return SymNodeVariable(proxy, sym_num, **options)
def __init__(self, proxy, sym_num, **kwargs):