Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"

This reverts commit fa895da968ec6f1ae128ee95fcb96ba9addac8a0.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py
index 2fe1811..0a820f0 100644
--- a/test/profiler/test_profiler_tree.py
+++ b/test/profiler/test_profiler_tree.py
@@ -22,7 +22,7 @@
 KEEP_NAME_AND_ELLIPSES = 3
 
 PRUNE_FUNCTIONS = {
-    "torch/utils/_pytree/api/python.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
+    "torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
     "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
     "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
     "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
@@ -699,14 +699,14 @@
                 ...
               aten::add
                 test_profiler_tree.py(...): __torch_dispatch__
-                  torch/utils/_pytree/api/python.py(...): tree_map
+                  torch/utils/_pytree.py(...): tree_map
                     ...
-                  torch/utils/_pytree/api/python.py(...): tree_map
+                  torch/utils/_pytree.py(...): tree_map
                     ...
                   torch/_ops.py(...): __call__
                     <built-in method  of PyCapsule object at 0xXXXXXXXXXXXX>
                       aten::add
-                  torch/utils/_pytree/api/python.py(...): tree_map
+                  torch/utils/_pytree.py(...): tree_map
                     ...
               torch/profiler/profiler.py(...): __exit__
                 torch/profiler/profiler.py(...): stop
diff --git a/test/test_pytree.py b/test/test_pytree.py
index d23cbd5..0c01203 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -4,8 +4,8 @@
 from collections import namedtuple, OrderedDict
 
 import torch
-import torch.utils._pytree.api.cxx as cxx_pytree
-import torch.utils._pytree.api.python as py_pytree
+import torch.utils._cxx_pytree as cxx_pytree
+import torch.utils._pytree as py_pytree
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
index e0fc1b4..36996a3 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -12,7 +12,7 @@
 from torch._custom_ops import impl_abstract
 
 try:
-    from torch.utils._pytree.api.cxx import tree_map_only
+    from torch.utils._cxx_pytree import tree_map_only
 except ImportError:
     from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]
 
diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index 3ac7034..7486f20 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -23,9 +23,9 @@
 from torch.distributed._tensor.sharding_prop import ShardingPropagator
 
 try:
-    from torch.utils._pytree.api import cxx as pytree
+    from torch.utils import _cxx_pytree as pytree
 except ImportError:
-    from torch.utils._pytree.api import python as pytree  # type: ignore[no-redef]
+    from torch.utils import _pytree as pytree  # type: ignore[no-redef]
 
 aten = torch.ops.aten
 
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index 2217756..e2cecba 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -7,7 +7,7 @@
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 try:
-    from torch.utils._pytree.api.cxx import tree_map_only, TreeSpec
+    from torch.utils._cxx_pytree import tree_map_only, TreeSpec
 except ImportError:
     from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
         tree_map_only,
diff --git a/torch/utils/_pytree/api/cxx.py b/torch/utils/_cxx_pytree.py
similarity index 97%
rename from torch/utils/_pytree/api/cxx.py
rename to torch/utils/_cxx_pytree.py
index ddf79e0..f89a563 100644
--- a/torch/utils/_pytree/api/cxx.py
+++ b/torch/utils/_cxx_pytree.py
@@ -13,7 +13,18 @@
 """
 
 import functools
-from typing import Any, Callable, Iterable, List, Optional, overload, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    overload,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 
@@ -23,20 +34,6 @@
 import optree
 from optree import PyTreeSpec  # direct import for type annotations
 
-from .typing import (
-    Context,
-    DumpableContext,
-    FlattenFunc,
-    FromDumpableContextFn,
-    PyTree,
-    R,
-    S,
-    T,
-    ToDumpableContextFn,
-    U,
-    UnflattenFunc,
-)
-
 
 __all__ = [
     "PyTree",
@@ -67,8 +64,21 @@
 ]
 
 
+T = TypeVar("T")
+S = TypeVar("S")
+U = TypeVar("U")
+R = TypeVar("R")
+
+
+Context = Optional[Any]
+PyTree = Any
 TreeSpec = PyTreeSpec
+FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
+UnflattenFunc = Callable[[Iterable, Context], PyTree]
 OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
+DumpableContext = Any  # Any json dumpable text
+ToDumpableContextFn = Callable[[Context], DumpableContext]
+FromDumpableContextFn = Callable[[DumpableContext], Context]
 
 
 def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@@ -213,7 +223,7 @@
         namespace=namespace,
     )
 
-    from . import python
+    from . import _pytree as python
 
     python._private_register_pytree_node(
         cls,
@@ -883,7 +893,7 @@
             f"treespec_dumps(spec): Expected `spec` to be instance of "
             f"TreeSpec but got item of type {type(treespec)}."
         )
-    from .python import (
+    from ._pytree import (
         tree_structure as _tree_structure,
         treespec_dumps as _treespec_dumps,
     )
@@ -894,7 +904,7 @@
 
 def treespec_loads(serialized: str) -> TreeSpec:
     """Deserialize a treespec from a JSON string."""
-    from .python import (
+    from ._pytree import (
         tree_unflatten as _tree_unflatten,
         treespec_loads as _treespec_loads,
     )
diff --git a/torch/utils/_pytree/api/python.py b/torch/utils/_pytree.py
similarity index 97%
rename from torch/utils/_pytree/api/python.py
rename to torch/utils/_pytree.py
index 120e37e..6b4e387 100644
--- a/torch/utils/_pytree/api/python.py
+++ b/torch/utils/_pytree.py
@@ -32,23 +32,10 @@
     overload,
     Tuple,
     Type,
+    TypeVar,
     Union,
 )
 
-from .typing import (
-    Context,
-    DumpableContext,
-    FlattenFunc,
-    FromDumpableContextFn,
-    PyTree,
-    R,
-    S,
-    T,
-    ToDumpableContextFn,
-    U,
-    UnflattenFunc,
-)
-
 
 __all__ = [
     "PyTree",
@@ -79,10 +66,22 @@
 ]
 
 
+T = TypeVar("T")
+S = TypeVar("S")
+U = TypeVar("U")
+R = TypeVar("R")
+
+
 DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
 NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
 
-
+Context = Any
+PyTree = Any
+FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
+UnflattenFunc = Callable[[Iterable, Context], PyTree]
+DumpableContext = Any  # Any json dumpable text
+ToDumpableContextFn = Callable[[Context], DumpableContext]
+FromDumpableContextFn = Callable[[DumpableContext], Context]
 ToStrFunc = Callable[["TreeSpec", List[str]], str]
 MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]
 
@@ -163,7 +162,7 @@
     )
 
     try:
-        from . import cxx
+        from . import _cxx_pytree as cxx
     except ImportError:
         pass
     else:
diff --git a/torch/utils/_pytree/__init__.py b/torch/utils/_pytree/__init__.py
deleted file mode 100644
index 0cd613d..0000000
--- a/torch/utils/_pytree/__init__.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Contains utility functions for working with nested python data structures.
-
-A *pytree* is Python nested data structure. It is a tree in the sense that
-nodes are Python collections (e.g., list, tuple, dict) and the leaves are
-Python values. Furthermore, a pytree should not contain reference cycles.
-
-pytrees are useful for working with nested collections of Tensors. For example,
-one can use `tree_map` to map a function over all Tensors inside some nested
-collection of Tensors and `tree_leaves` to get a flat list of all Tensors
-inside some nested collection. pytrees are helpful for implementing nested
-collection support for PyTorch APIs.
-"""
-
-from .api import (
-    Context,
-    DumpableContext,
-    FlattenFunc,
-    FromDumpableContextFn,
-    LeafSpec,
-    PyTree,
-    register_pytree_node,
-    ToDumpableContextFn,
-    tree_all,
-    tree_all_only,
-    tree_any,
-    tree_any_only,
-    tree_flatten,
-    tree_leaves,
-    tree_map,
-    tree_map_,
-    tree_map_only,
-    tree_map_only_,
-    tree_structure,
-    tree_unflatten,
-    TreeSpec,
-    treespec_dumps,
-    treespec_loads,
-    treespec_pprint,
-    UnflattenFunc,
-)
-from .api.python import (  # used by internals and/or third-party packages
-    _broadcast_to_and_flatten,
-    _dict_flatten,
-    _dict_unflatten,
-    _is_leaf,
-    _list_flatten,
-    _list_unflatten,
-    _namedtuple_flatten,
-    _namedtuple_unflatten,
-    _odict_flatten,
-    _odict_unflatten,
-    _register_pytree_node,
-    _tuple_flatten,
-    _tuple_unflatten,
-    arg_tree_leaves,
-    SUPPORTED_NODES,
-)
-
-
-__all__ = [
-    "PyTree",
-    "Context",
-    "FlattenFunc",
-    "UnflattenFunc",
-    "DumpableContext",
-    "ToDumpableContextFn",
-    "FromDumpableContextFn",
-    "TreeSpec",
-    "LeafSpec",
-    "register_pytree_node",
-    "tree_flatten",
-    "tree_unflatten",
-    "tree_leaves",
-    "tree_structure",
-    "tree_map",
-    "tree_map_",
-    "tree_map_only",
-    "tree_map_only_",
-    "tree_all",
-    "tree_any",
-    "tree_all_only",
-    "tree_any_only",
-    "treespec_dumps",
-    "treespec_loads",
-    "treespec_pprint",
-]
diff --git a/torch/utils/_pytree/api/__init__.py b/torch/utils/_pytree/api/__init__.py
deleted file mode 100644
index f7d5ab8..0000000
--- a/torch/utils/_pytree/api/__init__.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""
-Contains utility functions for working with nested python data structures.
-
-A *pytree* is Python nested data structure. It is a tree in the sense that
-nodes are Python collections (e.g., list, tuple, dict) and the leaves are
-Python values. Furthermore, a pytree should not contain reference cycles.
-
-pytrees are useful for working with nested collections of Tensors. For example,
-one can use `tree_map` to map a function over all Tensors inside some nested
-collection of Tensors and `tree_leaves` to get a flat list of all Tensors
-inside some nested collection. pytrees are helpful for implementing nested
-collection support for PyTorch APIs.
-"""
-
-from .python import (
-    LeafSpec,
-    register_pytree_node,
-    tree_all,
-    tree_all_only,
-    tree_any,
-    tree_any_only,
-    tree_flatten,
-    tree_leaves,
-    tree_map,
-    tree_map_,
-    tree_map_only,
-    tree_map_only_,
-    tree_structure,
-    tree_unflatten,
-    TreeSpec,
-    treespec_dumps,
-    treespec_loads,
-    treespec_pprint,
-)
-from .typing import (
-    Context,
-    DumpableContext,
-    FlattenFunc,
-    FromDumpableContextFn,
-    PyTree,
-    ToDumpableContextFn,
-    UnflattenFunc,
-)
-
-
-__all__ = [
-    "PyTree",
-    "Context",
-    "FlattenFunc",
-    "UnflattenFunc",
-    "DumpableContext",
-    "ToDumpableContextFn",
-    "FromDumpableContextFn",
-    "TreeSpec",
-    "LeafSpec",
-    "register_pytree_node",
-    "tree_flatten",
-    "tree_unflatten",
-    "tree_leaves",
-    "tree_structure",
-    "tree_map",
-    "tree_map_",
-    "tree_map_only",
-    "tree_map_only_",
-    "tree_all",
-    "tree_any",
-    "tree_all_only",
-    "tree_any_only",
-    "treespec_dumps",
-    "treespec_loads",
-    "treespec_pprint",
-]
diff --git a/torch/utils/_pytree/api/typing.py b/torch/utils/_pytree/api/typing.py
deleted file mode 100644
index c9e0189..0000000
--- a/torch/utils/_pytree/api/typing.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Any, Callable, Iterable, List, Tuple, TypeVar
-
-
-__all__ = [
-    "Context",
-    "PyTree",
-    "FlattenFunc",
-    "UnflattenFunc",
-    "DumpableContext",
-    "ToDumpableContextFn",
-    "FromDumpableContextFn",
-]
-
-
-T = TypeVar("T")
-S = TypeVar("S")
-U = TypeVar("U")
-R = TypeVar("R")
-
-Context = Any
-PyTree = Any
-FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
-UnflattenFunc = Callable[[Iterable, Context], PyTree]
-DumpableContext = Any  # Any json dumpable text
-ToDumpableContextFn = Callable[[Context], DumpableContext]
-FromDumpableContextFn = Callable[[DumpableContext], Context]