Fix regression in `torch.equal` behavior for NaNs (#111699)
`torch.equal(x, x)` should return false if `x` is a tensor of floats at least one of which is NaN.
So, this renders some of the optimizations proposed in https://github.com/pytorch/pytorch/pull/100024 invalid; as a result, `torch.equal` will become much slower for identical floating-point tensors.
Add regression test that calls torch.equal for tensor containing NaN
Fixes https://github.com/pytorch/pytorch/issues/111251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111699
Approved by: https://github.com/Skylion007, https://github.com/albanD
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index d84e7ad..adf23e5 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -2096,7 +2096,27 @@
&& self.layout() == other.layout()
&& self.is_neg() == other.is_neg()
&& self.is_conj() == other.is_conj()) {
- return true;
+ if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
+ return true;
+ }
+ std::atomic<bool> result{true};
+ auto iter = TensorIteratorConfig().add_input(self).build();
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
+ iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
+ if (!result) {
+ return;
+ }
+ char* self_data = data[0];
+ for (C10_UNUSED const auto i : c10::irange(dim_size)) {
+ if (isnan_(c10::load<scalar_t>(self_data))) {
+ result = false;
+ return;
+ }
+ self_data += strides[0];
+ }
+ });
+ });
+ return result.load();
}
std::atomic<bool> result{true};
diff --git a/test/test_torch.py b/test/test_torch.py
index f090b6b..2798fee 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6351,6 +6351,11 @@
self.assertNotEqual(t_0.size(), t_1.size())
self.assertFalse(torch.equal(t_0, t_1))
+ # Fast path: tensor containing `nan` is not equal to self
+ for dtype in floating_and_complex_types():
+ t = torch.tensor([1., float('nan')], dtype=dtype)
+ self.assertFalse(torch.equal(t, t))
+
def test_element_size(self):
byte = torch.ByteStorage().element_size()
char = torch.CharStorage().element_size()