Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72453573728242daa0dbc9cc7b45669c.
Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
diff --git a/.github/ci_commit_pins/torchdynamo.txt b/.github/ci_commit_pins/torchdynamo.txt
index 9a41dd3..4b29216 100644
--- a/.github/ci_commit_pins/torchdynamo.txt
+++ b/.github/ci_commit_pins/torchdynamo.txt
@@ -1 +1 @@
-30ab3e6cb3c39678c366f102a49cfb59f8507b15
+dbb005e7429300e6605abdf6533d0d6dac8dabe3
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index d8b403a..27d906e 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -315,13 +315,13 @@
def test_no_ref_cycle(self):
x = torch.rand([4])
- mode = torch._prims.get_prim_fake_mode()
+ mode = torch._prims.utils.get_prim_fake_mode()
y = mode.from_tensor(x)
- assert mode is torch._prims.get_prim_fake_mode()
+ assert mode is torch._prims.utils.get_prim_fake_mode()
self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
del mode
del y
- new_mode = torch._prims.get_prim_fake_mode()
+ new_mode = torch._prims.utils.get_prim_fake_mode()
self.assertEqual(len(new_mode.fake_tensor_converter.tensor_memo), 0)
diff --git a/test/test_ops.py b/test/test_ops.py
index 0b04ae7..dc1b4e8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -362,7 +362,7 @@
@onlyNativeDeviceTypes
@ops(python_ref_db)
def test_python_ref_meta(self, device, dtype, op):
- mode = torch._prims.get_prim_fake_mode()
+ mode = torch._prims.utils.get_prim_fake_mode()
def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
@@ -512,7 +512,7 @@
@parametrize('executor', ['aten', 'nvfuser'])
def test_python_ref_executor(self, device, dtype, op, executor):
# TODO: Not all dtypes are supported with nvfuser
- from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
+ from torch._prims.utils import _torch_dtype_to_nvfuser_dtype_map
if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
raise unittest.SkipTest(f"nvfuser doesn't support dtype {dtype}")
@@ -560,7 +560,7 @@
@onlyNativeDeviceTypes
@ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
def test_python_ref_errors(self, device, op):
- mode = torch._prims.get_prim_fake_mode()
+ mode = torch._prims.utils.get_prim_fake_mode()
def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 49b61ea..cfc62d7 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -6,8 +6,8 @@
import torch.nn.functional as F
import functools
from torch.utils._pytree import tree_map, tree_flatten
-import torch._prims_common as utils
-from torch._prims_common.wrappers import out_wrapper
+import torch._prims.utils as utils
+from torch._prims.wrappers import out_wrapper
# None of these functions are publicly accessible; get at them
# from torch._decomps
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index c33ab30..5f3e3a0 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1,12 +1,12 @@
import torch
from torch import Tensor
-import torch._prims_common as utils
-from torch._prims_common import (
+from torch._prims import utils
+from torch._prims.utils import (
ELEMENTWISE_TYPE_PROMOTION_KIND,
check,
elementwise_dtypes,
)
-from torch._prims_common.wrappers import out_wrapper
+from torch._prims.wrappers import out_wrapper
from typing import List, Optional
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 8174146..e09763f 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1,8 +1,8 @@
import torch
from torch import Tensor, _TypedStorage
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
check,
TensorLike,
TensorLikeType,
@@ -13,12 +13,12 @@
StrideType,
Number,
NumberType,
- type_to_dtype,
+ TensorMeta,
)
from torch.overrides import has_torch_function, handle_torch_function
import torch.library
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch._subclasses.fake_tensor import FakeTensor
import contextlib
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any, Type
@@ -26,7 +26,6 @@
from enum import Enum
import operator
import math
-import weakref
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
@@ -192,83 +191,6 @@
"fft_c2r",
]
-
-# In order to keep things like aliasing relationships and storage
-# consistent wrt/meta tensors, FakeTensors own a FakeTensorMode
-# which caches conversions to Meta Tensors. We would like to use
-# one consistent mode among along FakeTensors, which we store here.
-# We store a weakref, so that when all previous FakeTensors are
-# the present mode will also deallocate. FakeTensorMode holds onto
-# tensors that are converted to Meta so we don't want to persist it
-# longer than necessary.x
-prim_fake_mode_ref = None
-
-
-def get_prim_fake_mode():
- global prim_fake_mode_ref
- if prim_fake_mode_ref is None or prim_fake_mode_ref() is None:
- mode = FakeTensorMode()
- prim_fake_mode_ref = weakref.ref(mode)
- return mode
- else:
- return prim_fake_mode_ref()
-
-
-def TensorMeta(
- tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
- *,
- shape: Optional[ShapeType] = None,
- strides: Optional[StrideType] = None,
- dtype: Optional[torch.dtype] = None,
- device: Optional[Union[torch.device, str]] = None,
-):
- if isinstance(tensorlike, Number):
- assert not shape and (shape is None or isinstance(shape, Sequence))
- assert not strides and (strides is None or isinstance(strides, Sequence))
- inferred_shape: Tuple[int, ...] = ()
- inferred_strides: Tuple[int, ...] = ()
- inferred_dtype = type_to_dtype(type(tensorlike))
- inferred_device = torch.device("cpu")
- # TODO: This looks wrong, a number that is wrapped into a tensor
- # needs to behave differently than a scalar tensor for type
- # promotion purposes
- elif tensorlike is not None:
- assert isinstance(tensorlike, torch.Tensor)
- inferred_shape = tuple(tensorlike.shape)
- inferred_strides = tuple(tensorlike.stride())
- inferred_dtype = tensorlike.dtype
- inferred_device = tensorlike.device
- else:
- # If no tensorlike "example" is given then all metadata
- # must be provided explicitly
- assert shape is not None
- assert strides is not None
- assert dtype is not None
- assert device is not None
-
- shape = inferred_shape if shape is None else tuple(shape)
- strides = inferred_strides if strides is None else tuple(strides)
- dtype = inferred_dtype if dtype is None else dtype
- device = inferred_device if device is None else device
-
- if isinstance(device, str):
- device = torch.device(device)
-
- if isinstance(tensorlike, FakeTensor):
- mode = tensorlike.fake_mode
- else:
- mode = get_prim_fake_mode()
-
- if device.type == "meta":
- return torch.empty_strided(shape, strides, dtype=dtype, device="meta")
- else:
- return FakeTensor(
- mode,
- torch.empty(shape, dtype=dtype, device="meta"),
- device,
- )
-
-
#
# Common datastructures and helpers
#
@@ -394,7 +316,7 @@
and not isinstance(t, FakeTensor)
and not t.device.type == "meta"
):
- return FakeTensor.from_tensor(t, get_prim_fake_mode())
+ return FakeTensor.from_tensor(t, utils.get_prim_fake_mode())
else:
return t
diff --git a/torch/_prims/context.py b/torch/_prims/context.py
index 0066aea..8f22ad6 100644
--- a/torch/_prims/context.py
+++ b/torch/_prims/context.py
@@ -5,7 +5,7 @@
import torch
import torch.overrides
-from torch._prims_common import torch_function_passthrough
+from torch._prims.utils import torch_function_passthrough
import torch._refs
import torch._refs.nn
@@ -34,9 +34,6 @@
torch.Tensor.__and__: torch._refs.bitwise_and,
torch.Tensor.__or__: torch._refs.bitwise_or,
torch.Tensor.__eq__: torch._refs.eq,
- # TODO: Should these methods be mapped some other way?
- torch.Tensor.copy_: torch._prims.copy_to,
- torch.Tensor.resize: torch._prims.resize,
}
for mod_torch, mod_refs in modules:
for s in mod_refs.__all__: # type: ignore[attr-defined]
diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py
index 9a917c4..f1d5cea 100644
--- a/torch/_prims/nvfuser_executor.py
+++ b/torch/_prims/nvfuser_executor.py
@@ -4,7 +4,7 @@
import torch
from torch.fx import GraphModule
-from torch._prims_common import getnvFuserDtype, Number
+from torch._prims.utils import getnvFuserDtype, Number
import torch.overrides
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
diff --git a/torch/_prims_common/__init__.py b/torch/_prims/utils.py
similarity index 93%
rename from torch/_prims_common/__init__.py
rename to torch/_prims/utils.py
index b2608dc..5080eba 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims/utils.py
@@ -5,6 +5,7 @@
from functools import reduce, cmp_to_key
import operator
import weakref
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
import torch
@@ -73,6 +74,82 @@
TensorOrNumberLikeType = Union[TensorLikeType, NumberType]
+# In order to keep things like aliasing relationships and storage
+# consistent wrt/meta tensors, FakeTensors own a FakeTensorMode
+# which caches conversions to Meta Tensors. We would like to use
+# one consistent mode among along FakeTensors, which we store here.
+# We store a weakref, so that when all previous FakeTensors are
+# the present mode will also deallocate. FakeTensorMode holds onto
+# tensors that are converted to Meta so we don't want to persist it
+# longer than necessary.x
+prim_fake_mode_ref = None
+
+
+def get_prim_fake_mode():
+ global prim_fake_mode_ref
+ if prim_fake_mode_ref is None or prim_fake_mode_ref() is None:
+ mode = FakeTensorMode()
+ prim_fake_mode_ref = weakref.ref(mode)
+ return mode
+ else:
+ return prim_fake_mode_ref()
+
+
+def TensorMeta(
+ tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
+ *,
+ shape: Optional[ShapeType] = None,
+ strides: Optional[StrideType] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[torch.device, str]] = None,
+):
+ if isinstance(tensorlike, Number):
+ assert not shape and (shape is None or isinstance(shape, Sequence))
+ assert not strides and (strides is None or isinstance(strides, Sequence))
+ inferred_shape: Tuple[int, ...] = ()
+ inferred_strides: Tuple[int, ...] = ()
+ inferred_dtype = type_to_dtype(type(tensorlike))
+ inferred_device = torch.device("cpu")
+ # TODO: This looks wrong, a number that is wrapped into a tensor
+ # needs to behave differently than a scalar tensor for type
+ # promotion purposes
+ elif tensorlike is not None:
+ assert isinstance(tensorlike, torch.Tensor)
+ inferred_shape = tuple(tensorlike.shape)
+ inferred_strides = tuple(tensorlike.stride())
+ inferred_dtype = tensorlike.dtype
+ inferred_device = tensorlike.device
+ else:
+ # If no tensorlike "example" is given then all metadata
+ # must be provided explicitly
+ assert shape is not None
+ assert strides is not None
+ assert dtype is not None
+ assert device is not None
+
+ shape = inferred_shape if shape is None else tuple(shape)
+ strides = inferred_strides if strides is None else tuple(strides)
+ dtype = inferred_dtype if dtype is None else dtype
+ device = inferred_device if device is None else device
+
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ if isinstance(tensorlike, FakeTensor):
+ mode = tensorlike.fake_mode
+ else:
+ mode = get_prim_fake_mode()
+
+ if device.type == "meta":
+ return torch.empty_strided(shape, strides, dtype=dtype, device="meta")
+ else:
+ return FakeTensor(
+ mode,
+ torch.empty_strided(shape, strides, dtype=dtype, device="meta"),
+ device,
+ )
+
+
def same_shape(a: ShapeType, b: ShapeType) -> bool:
if len(a) != len(b):
return False
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims/wrappers.py
similarity index 97%
rename from torch/_prims_common/wrappers.py
rename to torch/_prims/wrappers.py
index 542b977..fe3e5f5 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims/wrappers.py
@@ -1,12 +1,12 @@
import torch
-from torch._prims_common import (
+from torch._prims.utils import (
Number,
NumberType,
TensorLike,
TensorLikeType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
)
-import torch._prims_common as utils
+import torch._prims.utils as utils
from torch.utils._pytree import tree_flatten
from typing import Callable, Sequence, Union, Tuple, NamedTuple
@@ -20,7 +20,6 @@
def _maybe_convert_to_dtype(
a: Union[TensorLikeType, NumberType, Sequence], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence]:
- import torch._prims as prims
if isinstance(a, TensorLike):
if a.dtype != dtype:
# NOTE: this is incorrect on the CPU
@@ -125,7 +124,7 @@
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape):
if out.numel() == 0:
- return out.resize_(shape)
+ return prims.resize(out, shape)
if out.numel() != reduce(operator.mul, shape, 1):
msg = (
@@ -138,7 +137,7 @@
)
)
warnings.warn(msg)
- return out.resize_(shape)
+ return prims.resize(out, shape)
return out
@@ -167,7 +166,7 @@
"but this can't be cast because it is not safe!",
)
- return copy_to.copy_(copy_from)
+ return prims.copy_to(copy_to, copy_from)
def out_wrapper(*out_names: str, exact_dtype: bool = False):
@@ -282,3 +281,7 @@
_fn.__signature__ = sig # type: ignore[attr-defined]
return _fn
+
+
+# avoid mypy import cycle
+import torch._prims as prims
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 86e4d86..0dff4fc 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -1,8 +1,8 @@
import torch
import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
check,
DimsType,
ShapeType,
@@ -20,7 +20,7 @@
is_weakly_lesser_type,
dtype_to_type,
)
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
elementwise_type_promotion_wrapper,
out_wrapper,
_maybe_convert_to_dtype,
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index e76cd4c..92996ea 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -1,13 +1,13 @@
import torch
import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
check,
TensorLikeType,
ShapeType,
DimsType,
)
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
out_wrapper,
)
from torch._decomp import register_decomposition
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index 52a4307..8916e8e 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -3,9 +3,9 @@
import torch._prims as prims
import torch._refs as refs
-from torch._prims_common.wrappers import out_wrapper
+from torch._prims.wrappers import out_wrapper
-from torch._prims_common import (
+from torch._prims.utils import (
check,
check_fp_or_complex,
DimsType,
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index 8d9a0fe..c0f5075 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,8 +1,8 @@
import torch
import torch._prims as prims
-import torch._prims_common as utils
-from torch._prims_common import (
+import torch._prims.utils as utils
+from torch._prims.utils import (
check,
ShapeType,
TensorLike,
@@ -12,7 +12,7 @@
)
import torch._refs as refs
from torch._decomp import register_decomposition
-from torch._prims_common.wrappers import (
+from torch._prims.wrappers import (
elementwise_type_promotion_wrapper,
elementwise_unary_scalar_wrapper,
out_wrapper,
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index 22649b6..4d576af 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -3,10 +3,10 @@
from torch import Tensor
from typing import Optional
import torch._prims as prims
-import torch._prims_common as utils
+import torch._prims.utils as utils
import torch._refs as refs
-from torch._prims_common import TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND
-from torch._prims_common.wrappers import out_wrapper, elementwise_type_promotion_wrapper
+from torch._prims.utils import TensorLikeType, ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._prims.wrappers import out_wrapper, elementwise_type_promotion_wrapper
from torch._refs import (
_make_elementwise_unary_reference,
_make_elementwise_binary_reference,
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 25cd3f5..2905383 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -237,4 +237,4 @@
# non-Tensor types don't count as hit or miss
return t
-import torch._prims_common as utils
+import torch._prims.utils as utils
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8809870..17be376 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6362,7 +6362,7 @@
# NOTE: primTorch is more strict about the type of the fill value argument
# So we must cast it to the correct dtype
- from torch._prims_common import dtype_to_type
+ from torch._prims.utils import dtype_to_type
scalar_type = dtype_to_type(dtype)
def drop_mode_argument(input, pad, mode=None, value=None):