[custom ops] disable kernel temporarily (#130190)
Fixes #128621
Sometimes we want to temporarily disable a registered backend kernel for testing or benchmarking purposes, letting calls fall back to the op's default implementation.
For example:
```python
import torch
from torch import Tensor
from torch.library import custom_op

@custom_op("mylib::f", mutates_args=())
def f(x: Tensor) -> Tensor:
    return torch.zeros(1)

print(f(torch.randn(1)))  # tensor([0.])

@f.register_kernel("cpu")
def _(x):
    return torch.ones(1)

print(f(torch.randn(1)))  # tensor([1.])

with f.set_kernel_enabled("cpu", enabled=False):
    print(f(torch.randn(1)))  # tensor([0.])
```
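Requesting a state the kernel is already in, or naming a device type with no registered kernel, is a no-op that logs a warning through the `torch._library.custom_ops` logger. A minimal sketch of that behavior, continuing the example above (warning text taken from the implementation below):

```python
import logging

logging.basicConfig(level=logging.WARNING)  # surface the library's warnings

with f.set_kernel_enabled("cpu", enabled=False):
    # Disabling again is a no-op and logs:
    # "Attempted to disable kernel for cpu but it was already disabled."
    with f.set_kernel_enabled("cpu", enabled=False):
        print(f(torch.randn(1)))  # tensor([0.]) -- the default implementation runs

# No kernel was ever registered for "cuda", so this logs:
# "Attempted to disable kernel for cuda but no kernel was registered for this device type."
with f.set_kernel_enabled("cuda", enabled=False):
    pass
```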
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130190
Approved by: https://github.com/williamwen42, https://github.com/zou3519
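Implementation-wise, each registered backend kernel is wrapped so the call falls back to the default implementation whenever its device type sits in a per-op `_disabled_kernel` set; `set_kernel_enabled` is a context manager that mutates that set and restores the original state on exit. Below is a minimal standalone sketch of the pattern; the `OpDef`, `default_fn`, and `register_kernel(device_type, fn)` names here are illustrative stand-ins, not the real API:

```python
from contextlib import contextmanager
from typing import Callable, Dict, Set


class OpDef:
    """Illustrative stand-in for torch._library.custom_ops.CustomOpDef."""

    def __init__(self, default_fn: Callable):
        self._default_fn = default_fn                # fallback implementation
        self._backend_fns: Dict[str, Callable] = {}  # per-device kernels
        self._disabled_kernel: Set[str] = set()      # device types currently disabled

    def register_kernel(self, device_type: str, fn: Callable) -> None:
        # Wrap the kernel so the disabled set is consulted at call time,
        # mirroring wrapped_fn in the diff below.
        def wrapped(*args, **kwargs):
            if device_type in self._disabled_kernel:
                return self._default_fn(*args, **kwargs)
            return fn(*args, **kwargs)

        self._backend_fns[device_type] = wrapped

    @contextmanager
    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
        originally_disabled = device_type in self._disabled_kernel
        if enabled:
            self._disabled_kernel.discard(device_type)
        else:
            self._disabled_kernel.add(device_type)
        try:
            yield
        finally:
            # Restore whatever state was observed on entry.
            if originally_disabled:
                self._disabled_kernel.add(device_type)
            else:
                self._disabled_kernel.discard(device_type)


op = OpDef(default_fn=lambda x: 0)
op.register_kernel("cpu", lambda x: 1)
print(op._backend_fns["cpu"](None))          # 1: kernel enabled
with op.set_kernel_enabled("cpu", enabled=False):
    print(op._backend_fns["cpu"](None))      # 0: falls back to the default
```

Checking the set at call time (rather than swapping dispatcher registrations) keeps the registration untouched, which is why a kernel registered while its device type is disabled stays disabled until re-enabled.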
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 8214060..5631db0 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -3147,6 +3147,55 @@
         self.assertEqual(result.device, torch.device("cpu"))
         self.assertEqual(result, torch.ones(3))
 
+    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
+    def test_set_kernel_enabled(self):
+        x = torch.ones(1)
+
+        @torch.library.custom_op("mylib::f", mutates_args=())
+        def f(x: Tensor) -> Tensor:
+            return x + 1
+
+        self.assertEqual(f(x), x + 1)
+        with self.assertLogs(
+            "torch._library.custom_ops",
+        ) as captured:
+            with f.set_kernel_enabled("gpu", enabled=False):
+                self.assertEqual(f(x), x + 1)
+            self.assertIn(
+                "no kernel was registered for this device type", captured.output[0]
+            )
+
+        @f.register_kernel("cpu")
+        def _(x):
+            return x + 2
+
+        self.assertEqual(f(x), x + 2)
+
+        with self.assertLogs(
+            "torch._library.custom_ops",
+        ) as captured:
+            with f.set_kernel_enabled("cpu", enabled=True):
+                self.assertEqual(f(x), x + 2)
+            self.assertIn("already enabled", captured.output[0])
+
+        with f.set_kernel_enabled("cpu", enabled=False):
+            self.assertEqual(f(x), x + 1)
+
+        with self.assertLogs(
+            "torch._library.custom_ops",
+        ) as captured:
+            with f.set_kernel_enabled("cpu", enabled=False):
+                self.assertEqual(f(x), x + 1)
+            self.assertIn("already disabled", captured.output[0])
+
+        self.assertEqual(f(x), x + 1)
+
+        with f.set_kernel_enabled("cpu", enabled=True):
+            self.assertEqual(f(x), x + 2)
+
+        with f.set_kernel_enabled("cpu", enabled=False):
+            self.assertEqual(f(x), x + 1)
+
 
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index 9691b28..65139a4 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
 import inspect
+import logging
 import weakref
+from contextlib import contextmanager
 from typing import (
     Any,
     Callable,
@@ -10,6 +12,7 @@
     List,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
 )
@@ -21,6 +24,7 @@
 
 device_types_t = Optional[Union[str, Sequence[str]]]
 
+log = logging.getLogger(__name__)
 
 
 @exposed_in("torch.library")
@@ -178,6 +182,7 @@
         self._lib = get_library_allowing_overwrite(self._namespace, self._name)
         self._register_to_dispatcher()
+        self._disabled_kernel: Set = set()
         OPDEFS[self._qualname] = self
 
     @property
@@ -187,6 +192,55 @@
     def __repr__(self) -> str:
         return f"<CustomOpDef({self._qualname})>"
 
+    @contextmanager
+    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
+        """
+        Disable or re-enable an already registered kernel for this custom operator.
+
+        If the kernel is already disabled/enabled, this is a no-op.
+
+        Note:
+            If a kernel is first disabled and then registered, it stays disabled until it is enabled again.
+
+        Args:
+            device_type (str): The device type to enable/disable the kernel for.
+            enabled (bool): Whether to enable or disable the kernel.
+        """
+        action = "enable" if enabled else "disable"
+        originally_disabled = device_type in self._disabled_kernel
+        if device_type not in self._backend_fns:
+            log.warning(
+                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
+                action,
+                device_type,
+            )
+
+        if not enabled:
+            if originally_disabled:
+                log.warning(
+                    "Attempted to disable kernel for %s but it was already disabled.",
+                    device_type,
+                )
+            else:
+                self._disabled_kernel.add(device_type)
+        else:  # enable the kernel
+            if not originally_disabled:
+                log.warning(
+                    "Attempted to enable kernel for %s but it was already enabled.",
+                    device_type,
+                )
+            else:
+                self._disabled_kernel.remove(device_type)
+
+        try:
+            yield
+        finally:
+            # Restore the state the kernel was in before entering the context manager.
+            if originally_disabled:
+                self._disabled_kernel.add(device_type)
+            else:
+                self._disabled_kernel.discard(device_type)
+
     def register_kernel(
         self, device_types: device_types_t, fn: Optional[Callable] = None, /
     ) -> Callable:
@@ -275,7 +329,16 @@
                             backend_impl,
                             _C._dispatch_key_for_device(device_type),
                         )
-                self._backend_fns[device_type] = fn
+
+                # Wrap the function to choose between the default implementation and the
+                # device-specific implementation, depending on whether the kernel is disabled.
+                def wrapped_fn(*args, **kwargs):
+                    if device_type in self._disabled_kernel:
+                        return self._init_fn(*args, **kwargs)
+                    else:
+                        return fn(*args, **kwargs)
+
+                self._backend_fns[device_type] = wrapped_fn
             return fn
 
         from torch._library.utils import get_device_arg_index, has_tensor_arg