[BE] add parentheses to kwargs unpacking `func(*args, **(kwargs or {}))` (#115026)
This PR adds parentheses around the kwargs unpacking expression, i.e. `func(*args, **(kwargs or {}))`, for better readability.
The forms with and without parentheses are semantically equivalent, as both compile to the same bytecode:
```console
$ echo "func(*args, **kwargs or {})" | python3 -m dis -
0 0 RESUME 0
1 2 PUSH_NULL
4 LOAD_NAME 0 (func)
6 LOAD_NAME 1 (args)
8 BUILD_MAP 0
10 LOAD_NAME 2 (kwargs)
12 JUMP_IF_TRUE_OR_POP 1 (to 16)
14 BUILD_MAP 0
>> 16 DICT_MERGE 1
18 CALL_FUNCTION_EX 1
20 POP_TOP
22 LOAD_CONST 0 (None)
24 RETURN_VALUE
$ echo "func(*args, **(kwargs or {}))" | python3 -m dis -
0 0 RESUME 0
1 2 PUSH_NULL
4 LOAD_NAME 0 (func)
6 LOAD_NAME 1 (args)
8 BUILD_MAP 0
10 LOAD_NAME 2 (kwargs)
12 JUMP_IF_TRUE_OR_POP 1 (to 16)
14 BUILD_MAP 0
>> 16 DICT_MERGE 1
18 CALL_FUNCTION_EX 1
20 POP_TOP
22 LOAD_CONST 0 (None)
24 RETURN_VALUE
```
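As a small illustration (not part of this PR, and `log_call` is a made-up helper), the `or {}` guard is what lets callers pass `kwargs=None`; the added parentheses only make the unpacking target explicit:
```python
# Hypothetical helper for illustration only; not part of this PR.
def log_call(func, args, kwargs=None):
    # `kwargs or {}` substitutes an empty dict when kwargs is None (or empty);
    # the parentheses make it clear that `**` unpacks the whole `or` expression.
    return func(*args, **(kwargs or {}))

print(log_call(max, (1, 2, 3)))                           # kwargs is None -> 3
print(log_call(sorted, ([3, 1, 2],), {"reverse": True}))  # [3, 2, 1]
```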
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115026
Approved by: https://github.com/Skylion007
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index 729ba13..8079637 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -923,12 +923,12 @@
class FunctionLog(TorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
- return func(*args, **kwargs or {})
+ return func(*args, **(kwargs or {}))
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
- return func(*args, **kwargs or {})
+ return func(*args, **(kwargs or {}))
def f():
a = torch.rand(10, requires_grad=True)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c8fe14..97091ef 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4143,7 +4143,7 @@
# Don't use node.name() here as it is not consistent on windows
node_name = node.__class__.__name__ if node else "None"
pr.append(f"Running {func} from within {node_name}")
- return func(*args, **kwargs or {})
+ return func(*args, **(kwargs or {}))
with MyMode():
pr.append("FW")
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index dc53de0..334dabb 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -1147,7 +1147,7 @@
constraints,
disable_constraint_solver=disable_constraint_solver
)
- flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})
+ flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))
with torch.no_grad():
so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options) # type: ignore[arg-type]
diff --git a/torch/_ops.py b/torch/_ops.py
index 828c605..8998b40 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Type, Union
import torch._C
-
+import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch
@@ -369,7 +369,7 @@
def _to_flat_tuple(args, kwargs):
- return torch.utils._pytree.arg_tree_leaves(*args, **kwargs)
+ return pytree.arg_tree_leaves(*args, **kwargs)
def _compute_keyset(args, kwargs, non_fallthrough_keys):
@@ -506,7 +506,7 @@
)
def __call__(self, *args, **kwargs):
- return self._op(*args, **kwargs or {})
+ return self._op(*args, **(kwargs or {}))
def __hash__(self):
return hash(self._op)
@@ -601,9 +601,7 @@
with temporarily_pop_mode(curr_stack) as curr_mode:
assert hasattr(curr_mode, "__torch_dispatch__")
overload_types = []
- args_flattened, _ = torch.utils._pytree.tree_flatten(
- (args, kwargs.values())
- )
+ args_flattened = pytree.arg_tree_leaves(*args, **kwargs)
for a in args_flattened:
# TODO: need to double check the semantics of the "types" argument to torch_dispatch.
# It's generated in PyInterpreter.cpp, but seems to be generated in two places,
@@ -750,7 +748,7 @@
# is still callable from JIT
# We save the function ptr as the `op` attribute on
# OpOverloadPacket to access it here.
- return self._op(*args, **kwargs or {})
+ return self._op(*args, **(kwargs or {}))
# TODO: use this to make a __dir__
def overloads(self):
diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py
index 28d1ece..429fcaf 100644
--- a/torch/testing/_internal/common_subclass.py
+++ b/torch/testing/_internal/common_subclass.py
@@ -73,7 +73,7 @@
# For everything else, call the handler:
fn = cls.handled_ops.get(func.__name__, None)
if fn:
- return fn(*args, **kwargs or {})
+ return fn(*args, **(kwargs or {}))
else:
# Note that here, because we don't need to provide the autograd formulas
# we can have a default "fallback" that creates a plain Tensor based
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index 764e5d9..6562a96 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -159,7 +159,7 @@
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- all_args = pytree.arg_tree_leaves(*args, **kwargs or {})
+ all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
if not all_same_mode(modes):
raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
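For reference, a minimal sketch (not from this PR, with made-up inputs) of what `pytree.arg_tree_leaves` returns for the call pattern used above:
```python
# Sketch only: arg_tree_leaves on made-up example inputs.
import torch
import torch.utils._pytree as pytree

args = (torch.zeros(2), [torch.ones(3), 4])
kwargs = {"alpha": 0.5}

# Flattens the positional and keyword arguments into a single list of leaf
# values; the `or {}` guard lets this work even when kwargs is None.
leaves = pytree.arg_tree_leaves(*args, **(kwargs or {}))
print(leaves)  # leaf values of args and kwargs (tensors and scalars)
```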