Require at least one tensor to be marked dynamic with --dynamic-batch-only (#99620)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99620
Approved by: https://github.com/voznesenskym
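
For context: --dynamic-batch-only works by calling torch._dynamo.mark_dynamic on the first dimension of each example input whose size equals the configured batch size, and this change makes the run fail loudly when no input was marked at all. A minimal sketch of that marking step, assuming a made-up input tensor and batch size (the names below are placeholders, not part of this patch):

    import torch

    example_batch_size = 8
    example_input = torch.randn(example_batch_size, 3, 224, 224)

    # Mark the first dimension whose size matches the batch size as dynamic,
    # so torch.compile does not specialize the compiled graph on that size.
    for dim, size in enumerate(example_input.size()):
        if size == example_batch_size:
            torch._dynamo.mark_dynamic(example_input, dim)
            break

Before this patch, a model whose example inputs never expose the batch size would silently fall through this loop and be benchmarked with fully static shapes; the new assert turns that into an immediate failure.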
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 322d603..b653759 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -2412,14 +2412,19 @@
             # NB: This must be done late enough so that we don't do more
             # conversions on the inputs
             # NB: Assumes only the first batch-y like dimension is the batch
+            marked = False
+
             def detect_and_mark_batch(t):
+                nonlocal marked
                 for i, s in enumerate(t.size()):
                     if s == batch_size:
                         torch._dynamo.mark_dynamic(t, i)
+                        marked = True
                         break
 
             if args.dynamic_batch_only:
                 tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs)
+                assert marked, f"nothing in example_inputs had a dim with {batch_size}"
 
             if args.log_operator_inputs:
                 log_operator_inputs(
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index 475581e..31d7791 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -282,6 +282,10 @@
         if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
             batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])
 
+        # See https://github.com/pytorch/benchmark/issues/1560
+        if model_name == "speech_transformer":
+            batch_size = 10
+
         # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
         torch.backends.__allow_nonbracketed_mutation_flag = True
         extra_args = []
@@ -317,6 +321,10 @@
         # the right example_inputs
         if model_name == "yolov3":
             example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
+        # See https://github.com/pytorch/benchmark/issues/1561
+        if model_name == "maml_omniglot":
+            batch_size = 5
+            assert example_inputs[0].shape[0] == batch_size
         # global current_name, current_device
         # current_device = device
         # current_name = benchmark.name
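
Illustrative usage (assumed invocation, not part of the patch; only --dynamic-batch-only comes from this change, the other flags are standard dynamo benchmark options):

    python benchmarks/dynamo/torchbench.py --performance --inductor \
        --dynamic-batch-only --only speech_transformer

With this patch such a run asserts if none of the model's example inputs has a dimension equal to the configured batch size, instead of silently benchmarking with static shapes. The speech_transformer and maml_omniglot batch-size overrides above (see the linked pytorch/benchmark issues) keep those models' example inputs consistent with that check.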