[composite compliance] istft (#82955)
Ref #69991
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82955
Approved by: https://github.com/zou3519
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index d638960..627a96c 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -1100,8 +1100,8 @@
y = y.slice(2, start, end, 1);
window_envelop = window_envelop.slice(2, start, end, 1);
- const auto window_envelop_lowest = window_envelop.abs().min().item().toDouble();
- if (window_envelop_lowest < 1e-11) {
+ const auto window_envelop_lowest = window_envelop.abs().min().lt(1e-11);
+ if (at::equal(window_envelop_lowest, window_envelop_lowest.new_ones({}))) {
std::ostringstream ss;
REPR(ss) << "window overlap add min: " << window_envelop_lowest;
AT_ERROR(ss.str());
@@ -1121,7 +1121,7 @@
}
return y;
- #undef REPR
+#undef REPR
}
Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 87174ba..d0f96f7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10344,10 +10344,6 @@
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
- # Pre-existing condition (calls .item); needs to be fixed
- DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
- # Pre-existing condition (calls .item); needs to be fixed
- DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
)),
UnaryUfuncInfo('floor',
ref=np.floor,