remove backend keys from FunctionalTensorWrapper, update TensorImpl::is_device methods (#81471)
It's annoying to have wrapper subclass tensors (like `FunctionalTensorWrapper`) include backend dispatch keys in their keyset: when we occasionally write something buggy, we end up sending the wrapper tensor to the backend kernel, which usually segfaults because the wrapper has no real data. By ensuring that wrapper tensors don't get backend keys, we get a nicer error when that happens.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81471
Approved by: https://github.com/ezyang
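
To illustrate the intent of the keyset change (a minimal, self-contained sketch, not part of this PR): after the constructor runs, the wrapper's keyset no longer contains the `Dense` key, so a buggy redispatch lands on a "no kernel found"-style error instead of a backend-kernel segfault, while autograd backend keys are preserved. The `DispatchKey`/`DispatchKeySet` types below are simplified stand-ins for PyTorch's real (much richer) implementations; only the names mirror the real ones.

```cpp
// Simplified stand-ins for c10::DispatchKey / c10::DispatchKeySet, just to
// show the keyset manipulation done in FunctionalTensorWrapper's constructor.
#include <cstdint>
#include <iostream>

enum class DispatchKey : uint64_t {
  Dense = 1ull << 0,          // "run a real backend kernel (CPU/CUDA/...)"
  Functionalize = 1ull << 1,  // functionalization wrapper kernels
  AutogradCPU = 1ull << 2,    // autograd still needs to know the backend
};

struct DispatchKeySet {
  uint64_t bits = 0;
  DispatchKeySet add(DispatchKey k) const { return {bits | static_cast<uint64_t>(k)}; }
  DispatchKeySet remove(DispatchKey k) const { return {bits & ~static_cast<uint64_t>(k)}; }
  bool has(DispatchKey k) const { return (bits & static_cast<uint64_t>(k)) != 0; }
};

int main() {
  // A wrapper tensor starts out with keys inherited from the tensor it wraps.
  DispatchKeySet key_set = DispatchKeySet{}
                               .add(DispatchKey::Dense)
                               .add(DispatchKey::Functionalize)
                               .add(DispatchKey::AutogradCPU);

  // Defensively drop Dense: the wrapper holds no real data, so letting it
  // reach a CPU/CUDA kernel would segfault. Autograd keys are kept so
  // redispatching to AutogradCPU/AutogradXLA still works.
  key_set = key_set.remove(DispatchKey::Dense);

  std::cout << "can hit a backend kernel: " << key_set.has(DispatchKey::Dense) << "\n";       // 0
  std::cout << "still functionalized:     " << key_set.has(DispatchKey::Functionalize) << "\n";  // 1
  std::cout << "autograd backend known:   " << key_set.has(DispatchKey::AutogradCPU) << "\n";    // 1
}
```

A side effect of dropping backend bits is that `TensorImpl::is_cpu()`, `is_cuda()`, etc. can no longer be answered from the keyset alone, which is why the diff below switches them to consult `device_opt_` instead.
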
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index f08690a..a66a42b 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -30,6 +30,19 @@
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks;
+ // For better error handling,
+ // we also don't want our wrapper tensor to be able to dispatch directly
+ // to a backend kernel.
+ // Dispatching directly to e.g. a CPU kernel would always segfault,
+ // because wrapper tensors don't have any real data.
+ // (This should never happen because we should always hit a functionalization kernel,
+ // but can help make bugs less nasty).
+ // Here, we defensively remove any backend keys from the wrapper's keyset.
+ // We don't want to remove actual backend bits though (say we're redispatching to autograd;
+ // we need to know if we're dispatching to AutogradCPU or AutogradXLA).
+ // Instead, it's sufficient to remove the `Dense` dispatch key,
+ // which prevents us from accidentally trying to directly run a CPU/CUDA kernel.
+ key_set_ = key_set_.remove(c10::DispatchKey::Dense);
}
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 26addb8..6a56dec 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -837,8 +837,7 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_meta();
}
- constexpr auto meta_ks = DispatchKeySet(BackendComponent::MetaBit);
- return key_set_.has_all(meta_ks);
+ return device_opt_.has_value() && device_opt_->type() == kMeta;
}
bool is_cpu() const {
@@ -847,9 +846,10 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_cpu();
}
- constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) |
- DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU});
- return key_set_.has_any(cpu_bits_ks);
+ // Note: we cannot rely on dispatch keys to determine the device type
+ // of a tensor, because "wrapper" tensors (like FunctionalTensorWrapper)
+ // don't include backend dispatch keys.
+ return device_opt_.has_value() && device_opt_->type() == kCPU;
}
bool is_cuda() const {
@@ -858,9 +858,7 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_cuda();
}
- constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) |
- DispatchKeySet(DispatchKey::SparseCsrCUDA);
- return key_set_.has_any(cuda_bits_ks);
+ return device_opt_.has_value() && device_opt_->type() == kCUDA;
}
bool is_xpu() const {
@@ -869,40 +867,35 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_xpu();
}
- constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit);
- return key_set_.has_all(xpu_ks);
+ return device_opt_.has_value() && device_opt_->type() == kXPU;
}
bool is_ipu() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_ipu();
}
- constexpr auto ipu_ks = DispatchKeySet(BackendComponent::IPUBit);
- return key_set_.has_all(ipu_ks);
+ return device_opt_.has_value() && device_opt_->type() == kIPU;
}
bool is_xla() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_xla();
}
- constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
- return key_set_.has_all(xla_ks);
+ return device_opt_.has_value() && device_opt_->type() == kXLA;
}
bool is_hpu() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_hpu();
}
- constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit);
- return key_set_.has_all(hpu_ks);
+ return device_opt_.has_value() && device_opt_->type() == kHPU;
}
bool is_lazy() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_lazy();
}
- constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit);
- return key_set_.has_all(lazy_ks);
+ return device_opt_.has_value() && device_opt_->type() == kLazy;
}
bool is_hip() const {
@@ -911,8 +904,7 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_hip();
}
- constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit);
- return key_set_.has_all(hip_ks);
+ return device_opt_.has_value() && device_opt_->type() == kHIP;
}
bool is_ve() const {
@@ -921,8 +913,7 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_ve();
}
- constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit);
- return key_set_.has_all(ve_ks);
+ return device_opt_.has_value() && device_opt_->type() == kVE;
}
bool is_mkldnn() const {
@@ -933,31 +924,28 @@
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_vulkan();
}
- constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan);
- return key_set_.has_all(vulkan_ks);
+ return device_opt_.has_value() && device_opt_->type() == kVulkan;
}
bool is_metal() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_metal();
}
- constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal);
- return key_set_.has_all(metal_ks);
+ return device_opt_.has_value() && device_opt_->type() == kMetal;
}
bool is_mps() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_mps();
}
- return key_set_.has(DispatchKey::MPS);
+ return device_opt_.has_value() && device_opt_->type() == kMPS;
}
bool is_ort() const {
if (C10_UNLIKELY(custom_device_)) {
return device_custom().is_ort();
}
- constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT);
- return key_set_.has_all(ort_ks);
+ return device_opt_.has_value() && device_opt_->type() == kORT;
}
bool is_nested() const {
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index 98385c5..b8d29a0 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -2,12 +2,11 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO
-from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs
+from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
import unittest
-import logging
def are_aliased(x, y):
if x._base is None and y._base is None:
@@ -18,46 +17,6 @@
return y._base is x
return x._base is y._base
-# Just for testing: a logging tensor that also transforms out-of-place ops into inplace ops.
-# That way even if the outer wrapper is functionalized, the inner wrapper will also need functionalization.
-class InplaceLoggingTensor(LoggingTensorReentrant):
- @staticmethod
- def __new__(cls, e):
- r = torch.Tensor._make_wrapper_subclass(cls, e.shape, dtype=e.dtype, requires_grad=False)
- r.elem = e
- return r
-
- __torch_function__ = torch._C._disabled_torch_function_impl
-
- def __str__(self):
- return f'InplaceLoggingTensor({self.elem})'
-
- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- def unwrap(e):
- if isinstance(e, InplaceLoggingTensor):
- return e.elem
- else:
- return e
-
- def wrap(e):
- if isinstance(e, torch.Tensor):
- return InplaceLoggingTensor(e)
- else:
- return e
- f = func
- # this subclass converts all `add()` ops into `add_()` ops
- if f is torch.ops.aten.add.Tensor:
- f = torch.ops.aten.add_.Tensor
-
- with cls.context():
- rs = tree_map(wrap, f(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
- # after running the (potentially transformed) op,
- # log the original op that we saw.
- logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
- return rs
-
-
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase):