fix dupe deprecated warning in dynamo export (#120896)
Summary:
When we convert `dynamic_shapes` to `constraints` and pass them to `_dynamo.export`, we shouldn't give a deprecation warning. Such conversion happens when calling `torch.export.export`, e.g. But it can also happen when calling `capture_pre_autograd_graph` (which itself has this deprecation warning when `constraints` are passed directly as well).
Since `_log_export_usage` is an indicator of a top-level call (it is `True` by default but set to `False`, or at least passed through, by callers), we can (ab)use it to indicate when to give this deprecation warning.
Test Plan: none
Differential Revision: D54350172
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120896
Approved by: https://github.com/BoyuanFeng, https://github.com/zhxchen17
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 0a30b40..8dcd12e 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -1173,13 +1173,14 @@
def inner(*args, **kwargs):
nonlocal constraints
if constraints is not None:
- warnings.warn(
- "Using `constraints` to specify dynamic shapes for export is DEPRECATED "
- "and will not be supported in the future. "
- "Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
- DeprecationWarning,
- stacklevel=2,
- )
+ if _log_export_usage:
+ warnings.warn(
+ "Using `constraints` to specify dynamic shapes for export is DEPRECATED "
+ "and will not be supported in the future. "
+ "Please use `dynamic_shapes` instead (see docs on `torch.export.export`).",
+ DeprecationWarning,
+ stacklevel=2,
+ )
else:
constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes)
f = _f
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index fa4c40b..3fd7a1f 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -173,6 +173,7 @@
decomposition_table=decomp_table,
pre_dispatch=True,
aten_graph=True,
+ _log_export_usage=False,
)(
*args,
**kwargs,
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 8f8f0c1..c9c395b 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -322,7 +322,6 @@
if _log_export_usage:
log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})
- constraints = constraints or []
kwargs = kwargs or {}
if not isinstance(args, tuple):