Revert "Refactored prim utils into _prims_utils folder (#81088)"

This reverts commit 80231d0a72453573728242daa0dbc9cc7b45669c.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
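
For orientation, the net effect of this revert is to move the prim utilities back from the standalone `torch._prims_common` package into submodules of `torch._prims`. A minimal sketch of the import paths on either side (which spelling resolves depends on the torch revision checked out):

```python
# With #81088 applied (pre-revert): utils live in a separate package.
import torch._prims_common as utils
from torch._prims_common.wrappers import out_wrapper

# With this revert applied: utils are submodules of torch._prims again.
import torch._prims.utils as utils
from torch._prims.wrappers import out_wrapper
```
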
diff --git a/.github/ci_commit_pins/torchdynamo.txt b/.github/ci_commit_pins/torchdynamo.txt
index 9a41dd3..4b29216 100644
--- a/.github/ci_commit_pins/torchdynamo.txt
+++ b/.github/ci_commit_pins/torchdynamo.txt
@@ -1 +1 @@
-30ab3e6cb3c39678c366f102a49cfb59f8507b15
+dbb005e7429300e6605abdf6533d0d6dac8dabe3
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index d8b403a..27d906e 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -315,13 +315,13 @@
 
     def test_no_ref_cycle(self):
         x = torch.rand([4])
-        mode = torch._prims.get_prim_fake_mode()
+        mode = torch._prims.utils.get_prim_fake_mode()
         y = mode.from_tensor(x)
-        assert mode is torch._prims.get_prim_fake_mode()
+        assert mode is torch._prims.utils.get_prim_fake_mode()
         self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
         del mode
         del y
-        new_mode = torch._prims.get_prim_fake_mode()
+        new_mode = torch._prims.utils.get_prim_fake_mode()
         self.assertEqual(len(new_mode.fake_tensor_converter.tensor_memo), 0)
 
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 0b04ae7..dc1b4e8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -362,7 +362,7 @@
     @onlyNativeDeviceTypes
     @ops(python_ref_db)
     def test_python_ref_meta(self, device, dtype, op):
-        mode = torch._prims.get_prim_fake_mode()
+        mode = torch._prims.utils.get_prim_fake_mode()
 
         def _to_tensormeta(x):
             if isinstance(x, torch.Tensor):
@@ -512,7 +512,7 @@
     @parametrize('executor', ['aten', 'nvfuser'])
     def test_python_ref_executor(self, device, dtype, op, executor):
         # TODO: Not all dtypes are supported with nvfuser
-        from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
+        from torch._prims.utils import _torch_dtype_to_nvfuser_dtype_map
         if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
             raise unittest.SkipTest(f"nvfuser doesn't support dtype {dtype}")
 
@@ -560,7 +560,7 @@
     @onlyNativeDeviceTypes
     @ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
     def test_python_ref_errors(self, device, op):
-        mode = torch._prims.get_prim_fake_mode()
+        mode = torch._prims.utils.get_prim_fake_mode()
 
         def _to_tensormeta(x):
             if isinstance(x, torch.Tensor):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 49b61ea..cfc62d7 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -6,8 +6,8 @@
 import torch.nn.functional as F
 import functools
 from torch.utils._pytree import tree_map, tree_flatten
-import torch._prims_common as utils
-from torch._prims_common.wrappers import out_wrapper
+import torch._prims.utils as utils
+from torch._prims.wrappers import out_wrapper
 
 # None of these functions are publicly accessible; get at them
 # from torch._decomps
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index c33ab30..5f3e3a0 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1,12 +1,12 @@
 import torch
 from torch import Tensor
-import torch._prims_common as utils
-from torch._prims_common import (
+from torch._prims import utils
+from torch._prims.utils import (
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     check,
     elementwise_dtypes,
 )
-from torch._prims_common.wrappers import out_wrapper
+from torch._prims.wrappers import out_wrapper
 
 from typing import List, Optional
 
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 8174146..e09763f 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1,8 +1,8 @@
 import torch
 from torch import Tensor, _TypedStorage
 
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
     check,
     TensorLike,
     TensorLikeType,
@@ -13,12 +13,12 @@
     StrideType,
     Number,
     NumberType,
-    type_to_dtype,
+    TensorMeta,
 )
 from torch.overrides import has_torch_function, handle_torch_function
 import torch.library
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch._subclasses.fake_tensor import FakeTensor
 
 import contextlib
 from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type
@@ -26,7 +26,6 @@
 from enum import Enum
 import operator
 import math
-import weakref
 
 prim = torch.library.Library("prims", "DEF")
 prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
@@ -192,83 +191,6 @@
     "fft_c2r",
 ]
 
-
-# In order to keep things like aliasing relationships and storage
-# consistent w.r.t. meta tensors, FakeTensors own a FakeTensorMode
-# which caches conversions to Meta Tensors. We would like to use
-# one consistent mode among all FakeTensors, which we store here.
-# We store a weakref, so that when all previous FakeTensors are
-# deallocated, the present mode will also deallocate. FakeTensorMode
-# holds onto tensors that are converted to Meta, so we don't want to
-# persist it longer than necessary.
-prim_fake_mode_ref = None
-
-
-def get_prim_fake_mode():
-    global prim_fake_mode_ref
-    if prim_fake_mode_ref is None or prim_fake_mode_ref() is None:
-        mode = FakeTensorMode()
-        prim_fake_mode_ref = weakref.ref(mode)
-        return mode
-    else:
-        return prim_fake_mode_ref()
-
-
-def TensorMeta(
-    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
-    *,
-    shape: Optional[ShapeType] = None,
-    strides: Optional[StrideType] = None,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[Union[torch.device, str]] = None,
-):
-    if isinstance(tensorlike, Number):
-        assert not shape and (shape is None or isinstance(shape, Sequence))
-        assert not strides and (strides is None or isinstance(strides, Sequence))
-        inferred_shape: Tuple[int, ...] = ()
-        inferred_strides: Tuple[int, ...] = ()
-        inferred_dtype = type_to_dtype(type(tensorlike))
-        inferred_device = torch.device("cpu")
-        # TODO: This looks wrong, a number that is wrapped into a tensor
-        # needs to behave differently than a scalar tensor for type
-        # promotion purposes
-    elif tensorlike is not None:
-        assert isinstance(tensorlike, torch.Tensor)
-        inferred_shape = tuple(tensorlike.shape)
-        inferred_strides = tuple(tensorlike.stride())
-        inferred_dtype = tensorlike.dtype
-        inferred_device = tensorlike.device
-    else:
-        # If no tensorlike "example" is given then all metadata
-        # must be provided explicitly
-        assert shape is not None
-        assert strides is not None
-        assert dtype is not None
-        assert device is not None
-
-    shape = inferred_shape if shape is None else tuple(shape)
-    strides = inferred_strides if strides is None else tuple(strides)
-    dtype = inferred_dtype if dtype is None else dtype
-    device = inferred_device if device is None else device
-
-    if isinstance(device, str):
-        device = torch.device(device)
-
-    if isinstance(tensorlike, FakeTensor):
-        mode = tensorlike.fake_mode
-    else:
-        mode = get_prim_fake_mode()
-
-    if device.type == "meta":
-        return torch.empty_strided(shape, strides, dtype=dtype, device="meta")
-    else:
-        return FakeTensor(
-            mode,
-            torch.empty(shape, dtype=dtype, device="meta"),
-            device,
-        )
-
-
 #
 # Common datastructures and helpers
 #
@@ -394,7 +316,7 @@
             and not isinstance(t, FakeTensor)
             and not t.device.type == "meta"
         ):
-            return FakeTensor.from_tensor(t, get_prim_fake_mode())
+            return FakeTensor.from_tensor(t, utils.get_prim_fake_mode())
         else:
             return t
 
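
The `get_prim_fake_mode` helper that this hunk moves back into `torch._prims.utils` is a weakref-backed singleton: the shared mode stays alive only while some FakeTensor holds a strong reference to it. A self-contained sketch of that pattern, using a stand-in `Mode` class rather than the real `FakeTensorMode` (the `del` behavior shown assumes CPython's refcount-based collection):

```python
import weakref

class Mode:  # stand-in for FakeTensorMode
    pass

_mode_ref = None  # module-level weakref, mirroring prim_fake_mode_ref

def get_mode() -> Mode:
    """Return the shared Mode, creating a fresh one if the old one died."""
    global _mode_ref
    mode = _mode_ref() if _mode_ref is not None else None
    if mode is None:
        mode = Mode()
        _mode_ref = weakref.ref(mode)
    return mode

m = get_mode()
assert m is get_mode()         # stable while a strong reference exists
del m                          # drop the last strong reference ...
assert get_mode() is not None  # ... and a fresh Mode is created on demand
```

This is the same lifecycle that `test_no_ref_cycle` in test/test_fake_tensor.py exercises above.
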
diff --git a/torch/_prims/context.py b/torch/_prims/context.py
index 0066aea..8f22ad6 100644
--- a/torch/_prims/context.py
+++ b/torch/_prims/context.py
@@ -5,7 +5,7 @@
 import torch
 import torch.overrides
 
-from torch._prims_common import torch_function_passthrough
+from torch._prims.utils import torch_function_passthrough
 
 import torch._refs
 import torch._refs.nn
@@ -34,9 +34,6 @@
         torch.Tensor.__and__: torch._refs.bitwise_and,
         torch.Tensor.__or__: torch._refs.bitwise_or,
         torch.Tensor.__eq__: torch._refs.eq,
-        # TODO: Should these methods be mapped some other way?
-        torch.Tensor.copy_: torch._prims.copy_to,
-        torch.Tensor.resize: torch._prims.resize,
     }
     for mod_torch, mod_refs in modules:
         for s in mod_refs.__all__:  # type: ignore[attr-defined]
diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py
index 9a917c4..f1d5cea 100644
--- a/torch/_prims/nvfuser_executor.py
+++ b/torch/_prims/nvfuser_executor.py
@@ -4,7 +4,7 @@
 import torch
 
 from torch.fx import GraphModule
-from torch._prims_common import getnvFuserDtype, Number
+from torch._prims.utils import getnvFuserDtype, Number
 import torch.overrides
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
diff --git a/torch/_prims_common/__init__.py b/torch/_prims/utils.py
similarity index 93%
rename from torch/_prims_common/__init__.py
rename to torch/_prims/utils.py
index b2608dc..5080eba 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims/utils.py
@@ -5,6 +5,7 @@
 from functools import reduce, cmp_to_key
 import operator
 import weakref
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 
 import torch
 
@@ -73,6 +74,82 @@
 TensorOrNumberLikeType = Union[TensorLikeType, NumberType]
 
 
+# In order to keep things like aliasing relationships and storage
+# consistent w.r.t. meta tensors, FakeTensors own a FakeTensorMode
+# which caches conversions to Meta Tensors. We would like to use
+# one consistent mode among all FakeTensors, which we store here.
+# We store a weakref, so that when all previous FakeTensors are
+# deallocated, the present mode will also deallocate. FakeTensorMode
+# holds onto tensors that are converted to Meta, so we don't want to
+# persist it longer than necessary.
+prim_fake_mode_ref = None
+
+
+def get_prim_fake_mode():
+    global prim_fake_mode_ref
+    if prim_fake_mode_ref is None or prim_fake_mode_ref() is None:
+        mode = FakeTensorMode()
+        prim_fake_mode_ref = weakref.ref(mode)
+        return mode
+    else:
+        return prim_fake_mode_ref()
+
+
+def TensorMeta(
+    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
+    *,
+    shape: Optional[ShapeType] = None,
+    strides: Optional[StrideType] = None,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[Union[torch.device, str]] = None,
+):
+    if isinstance(tensorlike, Number):
+        assert not shape and (shape is None or isinstance(shape, Sequence))
+        assert not strides and (strides is None or isinstance(strides, Sequence))
+        inferred_shape: Tuple[int, ...] = ()
+        inferred_strides: Tuple[int, ...] = ()
+        inferred_dtype = type_to_dtype(type(tensorlike))
+        inferred_device = torch.device("cpu")
+        # TODO: This looks wrong, a number that is wrapped into a tensor
+        # needs to behave differently than a scalar tensor for type
+        # promotion purposes
+    elif tensorlike is not None:
+        assert isinstance(tensorlike, torch.Tensor)
+        inferred_shape = tuple(tensorlike.shape)
+        inferred_strides = tuple(tensorlike.stride())
+        inferred_dtype = tensorlike.dtype
+        inferred_device = tensorlike.device
+    else:
+        # If no tensorlike "example" is given then all metadata
+        # must be provided explicitly
+        assert shape is not None
+        assert strides is not None
+        assert dtype is not None
+        assert device is not None
+
+    shape = inferred_shape if shape is None else tuple(shape)
+    strides = inferred_strides if strides is None else tuple(strides)
+    dtype = inferred_dtype if dtype is None else dtype
+    device = inferred_device if device is None else device
+
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    if isinstance(tensorlike, FakeTensor):
+        mode = tensorlike.fake_mode
+    else:
+        mode = get_prim_fake_mode()
+
+    if device.type == "meta":
+        return torch.empty_strided(shape, strides, dtype=dtype, device="meta")
+    else:
+        return FakeTensor(
+            mode,
+            torch.empty_strided(shape, strides, dtype=dtype, device="meta"),
+            device,
+        )
+
+
 def same_shape(a: ShapeType, b: ShapeType) -> bool:
     if len(a) != len(b):
         return False
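
For reference, a hedged usage sketch of the relocated `TensorMeta` helper; the import path matches the post-revert tree and only resolves on a torch build from this revision:

```python
import torch
from torch._prims.utils import TensorMeta  # post-revert path

t = torch.ones(2, 3)
fake = TensorMeta(t)  # FakeTensor mirroring t's shape, strides, dtype, device

# With no example tensor, every piece of metadata must be given explicitly;
# device="meta" short-circuits to a plain meta tensor per the code above.
meta = TensorMeta(shape=(4,), strides=(1,), dtype=torch.float32, device="meta")
```
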
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims/wrappers.py
similarity index 97%
rename from torch/_prims_common/wrappers.py
rename to torch/_prims/wrappers.py
index 542b977..fe3e5f5 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims/wrappers.py
@@ -1,12 +1,12 @@
 import torch
-from torch._prims_common import (
+from torch._prims.utils import (
     Number,
     NumberType,
     TensorLike,
     TensorLikeType,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
 )
-import torch._prims_common as utils
+import torch._prims.utils as utils
 from torch.utils._pytree import tree_flatten
 
 from typing import Callable, Sequence, Union, Tuple, NamedTuple
@@ -20,7 +20,6 @@
 def _maybe_convert_to_dtype(
     a: Union[TensorLikeType, NumberType, Sequence], dtype: torch.dtype
 ) -> Union[TensorLikeType, NumberType, Sequence]:
-    import torch._prims as prims
     if isinstance(a, TensorLike):
         if a.dtype != dtype:
             # NOTE: this is incorrect on the CPU
@@ -125,7 +124,7 @@
 # TODO: handle tuples of tensors
 def _maybe_resize_out(out: TensorLikeType, shape):
     if out.numel() == 0:
-        return out.resize_(shape)
+        return prims.resize(out, shape)
 
     if out.numel() != reduce(operator.mul, shape, 1):
         msg = (
@@ -138,7 +137,7 @@
             )
         )
         warnings.warn(msg)
-        return out.resize_(shape)
+        return prims.resize(out, shape)
 
     return out
 
@@ -167,7 +166,7 @@
             "but this can't be cast because it is not safe!",
         )
 
-    return copy_to.copy_(copy_from)
+    return prims.copy_to(copy_to, copy_from)
 
 
 def out_wrapper(*out_names: str, exact_dtype: bool = False):
@@ -282,3 +281,7 @@
 
     _fn.__signature__ = sig  # type: ignore[attr-defined]
     return _fn
+
+
+# avoid mypy import cycle
+import torch._prims as prims
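
The bottom-of-file `import torch._prims as prims` restored here (the diff's own comment says it avoids a mypy import cycle) defers binding `prims` until after this module's names are defined: `torch._prims` imports `wrappers` while it is itself still initializing, so an import at the top of this file would execute mid-cycle. A minimal sketch of the pattern, with hypothetical `pkg` module names:

```python
# pkg/wrappers.py (hypothetical two-module illustration)
def maybe_resize_out(out, shape):
    # `prims` is only dereferenced at call time, once both modules are loaded.
    return prims.resize(out, shape)

# Deferred to the bottom: pkg/prims.py imports pkg.wrappers at its top, so a
# top-of-file import here would run during pkg.prims's initialization.
import pkg.prims as prims
```
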
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 86e4d86..0dff4fc 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -1,8 +1,8 @@
 import torch
 
 import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
     check,
     DimsType,
     ShapeType,
@@ -20,7 +20,7 @@
     is_weakly_lesser_type,
     dtype_to_type,
 )
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
     elementwise_type_promotion_wrapper,
     out_wrapper,
     _maybe_convert_to_dtype,
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index e76cd4c..92996ea 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -1,13 +1,13 @@
 import torch
 import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
     check,
     TensorLikeType,
     ShapeType,
     DimsType,
 )
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
     out_wrapper,
 )
 from torch._decomp import register_decomposition
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index 52a4307..8916e8e 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -3,9 +3,9 @@
 
 import torch._prims as prims
 import torch._refs as refs
-from torch._prims_common.wrappers import out_wrapper
+from torch._prims.wrappers import out_wrapper
 
-from torch._prims_common import (
+from torch._prims.utils import (
     check,
     check_fp_or_complex,
     DimsType,
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 8d9a0fe..c0f5075 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,8 +1,8 @@
 import torch
 
 import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
     check,
     ShapeType,
     TensorLike,
@@ -12,7 +12,7 @@
 )
 import torch._refs as refs
 from torch._decomp import register_decomposition
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
     elementwise_type_promotion_wrapper,
     elementwise_unary_scalar_wrapper,
     out_wrapper,
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index 22649b6..4d576af 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -3,10 +3,10 @@
 from torch import Tensor
 from typing import Optional
 import torch._prims as prims
-import torch._prims_common as utils
+import torch._prims.utils as utils
 import torch._refs as refs
-from torch._prims_common import TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND
-from torch._prims_common.wrappers import out_wrapper, elementwise_type_promotion_wrapper
+from torch._prims.utils import TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._prims.wrappers import out_wrapper, elementwise_type_promotion_wrapper
 from torch._refs import (
     _make_elementwise_unary_reference,
     _make_elementwise_binary_reference,
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 25cd3f5..2905383 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -237,4 +237,4 @@
             # non-Tensor types don't count as hit or miss
             return t
 
-import torch._prims_common as utils
+import torch._prims.utils as utils
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8809870..17be376 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6362,7 +6362,7 @@
 
     # NOTE: primTorch is more strict about the type of the fill value argument
     # So we must cast it to the correct dtype
-    from torch._prims_common import dtype_to_type
+    from torch._prims.utils import dtype_to_type
     scalar_type = dtype_to_type(dtype)
 
     def drop_mode_argument(input, pad, mode=None, value=None):
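
The hunk above casts the fill value through `dtype_to_type`, which maps a `torch.dtype` to the corresponding Python scalar type. A small sketch (import path per the post-revert tree):

```python
import torch
from torch._prims.utils import dtype_to_type  # post-revert path

scalar_type = dtype_to_type(torch.float32)  # -> float
fill_value = scalar_type(1)                 # 1.0, satisfying primTorch's check
assert dtype_to_type(torch.int64) is int
assert dtype_to_type(torch.complex64) is complex
```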