Excise uses of the old custom ops APIs (#124134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124134
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index c9824fb..3e7e296 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -9,7 +9,6 @@
from functools import partial
import torch
-import torch._custom_ops as custom_ops
import torch.library
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor.codegen.common import device_codegens, register_backend_for_device
@@ -280,30 +279,20 @@
@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_item_to_inputs_kernel_nobreak(self, device):
- with torch.library._scoped_library("test", "DEF") as lib:
- try:
+ @torch.library.custom_op("test::foo", mutates_args=())
+ def foo(x: torch.Tensor, y: int) -> torch.Tensor:
+ return x.clone()
- @custom_ops.custom_op("test::foo")
- def foo(x: torch.Tensor, y: int) -> torch.Tensor:
- raise NotImplementedError
+ @foo.register_fake
+ def _(x: torch.Tensor, y: int) -> torch.Tensor:
+ return x.clone()
- @custom_ops.impl("test::foo")
- def foo_impl(x: torch.Tensor, y: int) -> torch.Tensor:
- return x.clone()
+ @torch.compile(fullgraph=True)
+ def f(x, r):
+ y = x.item()
+ return torch.ops.test.foo(r, y)
- @torch.library.impl_abstract("test::foo", lib=lib)
- def foo_meta(x: torch.Tensor, y: int) -> torch.Tensor:
- return x.clone()
-
- @torch.compile(fullgraph=True)
- def f(x, r):
- y = x.item()
- return torch.ops.test.foo(r, y)
-
- f(torch.tensor([3], device=device), torch.randn(10, device=device))
-
- finally:
- custom_ops._destroy("test::foo")
+ f(torch.tensor([3], device=device), torch.randn(10, device=device))
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
@@ -396,34 +385,24 @@
)
@torch._inductor.config.patch(implicit_fallbacks=True)
def test_dynamic_stride_nobreak(self, device):
- with torch.library._scoped_library("test", "DEF") as lib:
- try:
+ @torch.library.custom_op("test::foo", mutates_args=())
+ def foo(x: torch.Tensor) -> torch.Tensor:
+ stride = x.item()
+ return torch.empty_strided((1,), (stride,), device=x.device)
- @custom_ops.custom_op("test::foo")
- def foo(x: torch.Tensor) -> torch.Tensor:
- raise NotImplementedError
+ @foo.register_fake
+ def _(x: torch.Tensor) -> torch.Tensor:
+ ctx = torch.library.get_ctx()
+ stride = ctx.new_dynamic_size()
+ return torch.empty_strided((1,), (stride,), device=x.device)
- @custom_ops.impl("test::foo")
- def foo_impl(x: torch.Tensor) -> torch.Tensor:
- stride = x.item()
- return torch.empty_strided((1,), (stride,), device=x.device)
+ @torch.compile(fullgraph=True)
+ def f(x):
+ r = torch.ops.test.foo(x)
+ y = r.stride(0)
+ return torch.empty(y, device=x.device)
- @torch.library.impl_abstract("test::foo", lib=lib)
- def foo_meta(x: torch.Tensor) -> torch.Tensor:
- ctx = torch.library.get_ctx()
- stride = ctx.new_dynamic_size()
- return torch.empty_strided((1,), (stride,), device=x.device)
-
- @torch.compile(fullgraph=True)
- def f(x):
- r = torch.ops.test.foo(x)
- y = r.stride(0)
- return torch.empty(y, device=x.device)
-
- f(torch.tensor([3], device=device))
-
- finally:
- custom_ops._destroy("test::foo")
+ f(torch.tensor([3], device=device))
@torch._inductor.config.patch(disable_cpp_codegen=True)
def test_floor(self):
diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py
index 00fe67b..9ebbf11 100644
--- a/test/onnx/test_fx_passes.py
+++ b/test/onnx/test_fx_passes.py
@@ -3,7 +3,6 @@
import torch._dynamo
import torch.fx
-from torch._custom_op import impl as custom_op
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils
@@ -58,32 +57,25 @@
), f"Expected all names to be unique, got {nodes}"
def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
- @custom_op.custom_op("mylibrary::foo_op")
+ @torch.library.custom_op(
+ "mylibrary::foo_op", device_types="cpu", mutates_args=()
+ )
def foo_op(x: torch.Tensor) -> torch.Tensor:
- ...
-
- @custom_op.custom_op("mylibrary::bar_op")
- def bar_op(x: torch.Tensor) -> torch.Tensor:
- ...
-
- @foo_op.impl_abstract()
- def foo_op_impl_abstract(x):
- return torch.empty_like(x)
-
- @foo_op.impl("cpu")
- def foo_op_impl(x):
return x + 1
- @bar_op.impl_abstract()
- def bar_op_impl_abstract(x):
- return torch.empty_like(x)
-
- @bar_op.impl("cpu")
- def bar_op_impl(x):
+ @torch.library.custom_op(
+ "mylibrary::bar_op", device_types="cpu", mutates_args=()
+ )
+ def bar_op(x: torch.Tensor) -> torch.Tensor:
return x + 2
- torch._dynamo.allow_in_graph(foo_op)
- torch._dynamo.allow_in_graph(bar_op)
+ @foo_op.register_fake
+ def _(x):
+ return torch.empty_like(x)
+
+ @bar_op.register_fake
+ def _(x):
+ return torch.empty_like(x)
def func(x, y, z):
return foo_op(x) + bar_op(y) + z
diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py
index d4d7a0c..ea3854d 100644
--- a/torch/_prims/debug_prims.py
+++ b/torch/_prims/debug_prims.py
@@ -1,8 +1,7 @@
import contextlib
-from typing import Optional, Sequence
+from typing import Optional
import torch
-from torch._custom_op.impl import custom_op
from torch.utils._content_store import ContentStoreReader
LOAD_TENSOR_READER: Optional[ContentStoreReader] = None
@@ -26,18 +25,12 @@
def register_debug_prims():
- @custom_op("debugprims::load_tensor")
- def load_tensor( # type: ignore[empty-body]
- name: str,
- size: Sequence[int],
- stride: Sequence[int],
- *,
- dtype: torch.dtype,
- device: torch.device,
- ) -> torch.Tensor:
- ...
+ torch.library.define(
+ "debugprims::load_tensor",
+ "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
+ )
- @load_tensor.impl_factory()
+ @torch.library.impl("debugprims::load_tensor", "BackendSelect")
def load_tensor_factory(name, size, stride, dtype, device):
if LOAD_TENSOR_READER is None:
from torch._dynamo.testing import rand_strided
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 9bd6d25..f5f830c 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -6,6 +6,7 @@
import torch
import torchgen
+import torchgen.model
from torch._C import (
_get_dispatch_stack_at,
_len_torch_dispatch_stack,