[custom ops] disable kernel temporarily (#130190)

Fixes #128621

Sometimes we want to temporarily disable a registered backend kernel for testing or benchmarking purposes.

For example:

```python
@custom_op("mylib::f", mutates_args=())
def f(x: Tensor) -> Tensor:
    return torch.zeros(1)

print(f(torch.randn(1))) # tensor([0.])

@f.register_kernel("cpu")
def _(x):
    return torch.ones(1)

print(f(torch.randn(1))) # tensor([1.])

with f.set_kernel_enabled("cpu", enabled=False):
    print(f(torch.randn(1))) # tensor([0.])
```
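A small sketch of the surrounding behavior, as exercised in the test below: asking for a state the kernel is already in is a no-op that only logs a warning, and the original state is restored when the context manager exits.

```python
# Disabling an already-disabled kernel is a no-op that only logs a warning,
# and the original state is restored when each context manager exits.
with f.set_kernel_enabled("cpu", enabled=False):
    with f.set_kernel_enabled("cpu", enabled=False):  # warns: already disabled
        print(f(torch.randn(1))) # tensor([0.])
    print(f(torch.randn(1))) # tensor([0.]), still disabled

print(f(torch.randn(1))) # tensor([1.]), CPU kernel active again on exit
```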
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130190
Approved by: https://github.com/williamwen42, https://github.com/zou3519
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 8214060..5631db0 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -3147,6 +3147,55 @@
         self.assertEqual(result.device, torch.device("cpu"))
         self.assertEqual(result, torch.ones(3))
 
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_set_kernel_enabled(self):
+        x = torch.ones(1)
+
+        @torch.library.custom_op("mylib::f", mutates_args=())
+        def f(x: Tensor) -> Tensor:
+            return x + 1
+
+        self.assertEqual(f(x), x + 1)
+        with self.assertLogs(
+            "torch._library.custom_ops",
+        ) as captured:
+            with f.set_kernel_enabled("gpu", enabled=False):
+                self.assertEqual(f(x), x + 1)
+            self.assertIn(
+                "no kernel was registered for this device type", captured.output[0]
+            )
+
+        @f.register_kernel("cpu")
+        def _(x):
+            return x + 2
+
+        self.assertEqual(f(x), x + 2)
+
+        with self.assertLogs(
+            "torch._library.custom_ops",
+        ) as captured:
+            with f.set_kernel_enabled("cpu", enabled=True):
+                self.assertEqual(f(x), x + 2)
+            self.assertIn("already enabled", captured.output[0])
+
+        with f.set_kernel_enabled("cpu", enabled=False):
+            self.assertEqual(f(x), x + 1)
+
+            with self.assertLogs(
+                "torch._library.custom_ops",
+            ) as captured:
+                with f.set_kernel_enabled("cpu", enabled=False):
+                    self.assertEqual(f(x), x + 1)
+                self.assertIn("already disabled", captured.output[0])
+
+            self.assertEqual(f(x), x + 1)
+
+        with f.set_kernel_enabled("cpu", enabled=True):
+            self.assertEqual(f(x), x + 2)
+
+        with f.set_kernel_enabled("cpu", enabled=False):
+            self.assertEqual(f(x), x + 1)
+
 
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index 9691b28..65139a4 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
 import inspect
+import logging
 import weakref
+from contextlib import contextmanager
 from typing import (
     Any,
     Callable,
@@ -10,6 +12,7 @@
     List,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
 )
@@ -21,6 +24,7 @@
 
 
 device_types_t = Optional[Union[str, Sequence[str]]]
+log = logging.getLogger(__name__)
 
 
 @exposed_in("torch.library")
@@ -178,6 +182,7 @@
 
         self._lib = get_library_allowing_overwrite(self._namespace, self._name)
         self._register_to_dispatcher()
+        self._disabled_kernel: Set = set()
         OPDEFS[self._qualname] = self
 
     @property
@@ -187,6 +192,55 @@
     def __repr__(self) -> str:
         return f"<CustomOpDef({self._qualname})>"
 
+    @contextmanager
+    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
+        """
+        Disable or re-enable an already registered kernel for this custom operator.
+
+        If the kernel is already disabled/enabled, this is a no-op.
+
+        Note:
+            If a kernel is first disabled and then registered, it is disabled until enabled again.
+
+        Args:
+            device_type (str): The device type to disable/enable the kernel for.
+            enabled (bool): Whether to enable or disable the kernel.
+        """
+        action = "enable" if enabled else "disable"
+        originally_disabled = device_type in self._disabled_kernel
+        if device_type not in self._backend_fns:
+            log.warning(
+                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
+                action,
+                device_type,
+            )
+
+        if not enabled:
+            if originally_disabled:
+                log.warning(
+                    "Attempted to disable kernel for %s but it was already disabled.",
+                    device_type,
+                )
+            else:
+                self._disabled_kernel.add(device_type)
+        else:  # enable the kernel
+            if not originally_disabled:
+                log.warning(
+                    "Attempted to enable kernel for  %s but it was already enabled.",
+                    device_type,
+                )
+            else:
+                self._disabled_kernel.remove(device_type)
+
+        try:
+            yield
+        finally:
+            # restore original state
+            if originally_disabled:
+                self._disabled_kernel.add(device_type)
+            else:
+                self._disabled_kernel.discard(device_type)
+
     def register_kernel(
         self, device_types: device_types_t, fn: Optional[Callable] = None, /
     ) -> Callable:
@@ -275,7 +329,16 @@
                             backend_impl,
                             _C._dispatch_key_for_device(device_type),
                         )
-                self._backend_fns[device_type] = fn
+
+                # Wrap the function to choose between the default implementation and the
+                # device-specific implementation depending on whether the kernel is disabled.
+                def wrapped_fn(*args, **kwargs):
+                    if device_type in self._disabled_kernel:
+                        return self._init_fn(*args, **kwargs)
+                    else:
+                        return fn(*args, **kwargs)
+
+                self._backend_fns[device_type] = wrapped_fn
             return fn
 
         from torch._library.utils import get_device_arg_index, has_tensor_arg