[inductor] Decompose boolean min/max into all/any (#110311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110311
Approved by: https://github.com/lezcano
ghstack dependencies: #110310
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index eb0d394..883ad4e 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -309,6 +309,20 @@
return torch.where(torch.isnan(other) | (other < self), self, other)
+@register_decomposition(aten.amax)
+def amax(self, dim=None, keepdim=False):
+ if self.dtype == torch.bool:
+ return torch.any(self, dim=dim, keepdim=keepdim)
+ return NotImplemented
+
+
+@register_decomposition(aten.amin)
+def amin(self, dim=None, keepdim=False):
+ if self.dtype == torch.bool:
+ return torch.all(self, dim=dim, keepdim=keepdim)
+ return NotImplemented
+
+
@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
return torch.narrow(self, dim, start, length).clone()