[BE] add parentheses to kwargs unpacking `func(*args, **(kwargs or {}))` (#115026)

This PR adds parentheses around the `or` expression in kwargs unpacking, writing `func(*args, **(kwargs or {}))` instead of `func(*args, **kwargs or {})`, for better readability.

The forms with and without the parentheses are semantically equivalent, as shown by the identical bytecode:

```console
$ echo "func(*args, **kwargs or {})" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE

$ echo "func(*args, **(kwargs or {}))" | python3 -m dis -
  0           0 RESUME                   0

  1           2 PUSH_NULL
              4 LOAD_NAME                0 (func)
              6 LOAD_NAME                1 (args)
              8 BUILD_MAP                0
             10 LOAD_NAME                2 (kwargs)
             12 JUMP_IF_TRUE_OR_POP      1 (to 16)
             14 BUILD_MAP                0
        >>   16 DICT_MERGE               1
             18 CALL_FUNCTION_EX         1
             20 POP_TOP
             22 LOAD_CONST               0 (None)
             24 RETURN_VALUE
```
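
As an extra sanity check, the two spellings also parse to the same AST, since the call grammar takes a full expression after `**`, so `**kwargs or {}` already means `**(kwargs or {})`. The snippet below is a minimal illustration (not part of this PR), assuming only the standard-library `ast` module and placeholder names `func`, `args`, and `kwargs`:

```python
import ast

# Parse both spellings; the `or` binds inside the ** unpacking in both cases.
without_parens = ast.parse("func(*args, **kwargs or {})")
with_parens = ast.parse("func(*args, **(kwargs or {}))")

# The dumped ASTs are identical, confirming the change is purely cosmetic.
assert ast.dump(without_parens) == ast.dump(with_parens)
print("identical ASTs")
```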
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115026
Approved by: https://github.com/Skylion007
diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst
index 729ba13..8079637 100644
--- a/docs/source/notes/extending.rst
+++ b/docs/source/notes/extending.rst
@@ -923,12 +923,12 @@
   class FunctionLog(TorchFunctionMode):
       def __torch_function__(self, func, types, args, kwargs=None):
           print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
-          return func(*args, **kwargs or {})
+          return func(*args, **(kwargs or {}))
 
   class DispatchLog(TorchDispatchMode):
       def __torch_dispatch__(self, func, types, args, kwargs=None):
           print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
-          return func(*args, **kwargs or {})
+          return func(*args, **(kwargs or {}))
 
   def f():
       a = torch.rand(10, requires_grad=True)
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0c8fe14..97091ef 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4143,7 +4143,7 @@
                 # Don't use node.name() here as it is not consistent on windows
                 node_name = node.__class__.__name__ if node else "None"
                 pr.append(f"Running {func} from within {node_name}")
-                return func(*args, **kwargs or {})
+                return func(*args, **(kwargs or {}))
 
         with MyMode():
             pr.append("FW")
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index dc53de0..334dabb 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -1147,7 +1147,7 @@
         constraints,
         disable_constraint_solver=disable_constraint_solver
     )
-    flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})
+    flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))
 
     with torch.no_grad():
         so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options)  # type: ignore[arg-type]
diff --git a/torch/_ops.py b/torch/_ops.py
index 828c605..8998b40 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -7,7 +7,7 @@
 from typing import Any, Callable, Dict, List, Type, Union
 
 import torch._C
-
+import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._functorch.pyfunctorch import dispatch_functorch
 
@@ -369,7 +369,7 @@
 
 
 def _to_flat_tuple(args, kwargs):
-    return torch.utils._pytree.arg_tree_leaves(*args, **kwargs)
+    return pytree.arg_tree_leaves(*args, **kwargs)
 
 
 def _compute_keyset(args, kwargs, non_fallthrough_keys):
@@ -506,7 +506,7 @@
         )
 
     def __call__(self, *args, **kwargs):
-        return self._op(*args, **kwargs or {})
+        return self._op(*args, **(kwargs or {}))
 
     def __hash__(self):
         return hash(self._op)
@@ -601,9 +601,7 @@
                     with temporarily_pop_mode(curr_stack) as curr_mode:
                         assert hasattr(curr_mode, "__torch_dispatch__")
                         overload_types = []
-                        args_flattened, _ = torch.utils._pytree.tree_flatten(
-                            (args, kwargs.values())
-                        )
+                        args_flattened = pytree.arg_tree_leaves(*args, **kwargs)
                         for a in args_flattened:
                             # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
                             # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
@@ -750,7 +748,7 @@
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
         # OpOverloadPacket to access it here.
-        return self._op(*args, **kwargs or {})
+        return self._op(*args, **(kwargs or {}))
 
     # TODO: use this to make a __dir__
     def overloads(self):
diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py
index 28d1ece..429fcaf 100644
--- a/torch/testing/_internal/common_subclass.py
+++ b/torch/testing/_internal/common_subclass.py
@@ -73,7 +73,7 @@
         # For everything else, call the handler:
         fn = cls.handled_ops.get(func.__name__, None)
         if fn:
-            return fn(*args, **kwargs or {})
+            return fn(*args, **(kwargs or {}))
         else:
             # Note that here, because we don't need to provide the autograd formulas
             # we can have a default "fallback" that creates a plain Tensor based
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index 764e5d9..6562a96 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -159,7 +159,7 @@
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-            all_args = pytree.arg_tree_leaves(*args, **kwargs or {})
+            all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
             modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
             if not all_same_mode(modes):
                 raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")