[TensorExpr] Fix min and max for integral inputs in CUDA backend (#44984)
Summary:
For integral types, isnan is meaningless: integers have no NaN, but the
CUDA maximum and minimum helpers call it to propagate NaNs. For integral
inputs, emit the built-in max and min instead, which never call isnan.
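For context, the distinction looks roughly like this (illustrative sketch;
the real helpers live in the preamble the codegen prepends to each kernel,
and imax below is a hypothetical wrapper, not code from this PR):

    // Float path: must propagate NaN, hence the isnan() check.
    __device__ float maximum(float a, float b) {
      return isnan(a) ? a : (a > b ? a : b);
    }

    // Integral path: there is no NaN, so CUDA's built-in max suffices.
    __device__ int imax(int a, int b) {
      return max(a, b);  // no isnan() call
    }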
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44984
Test Plan: python test/test_jit_fuser_te.py -k TestTEFuser.test_minmax_int_ops
Reviewed By: ezyang
Differential Revision: D23885259
Pulled By: asuhan
fbshipit-source-id: 2e6da2c43c0ed18f0b648a2383d510894c574437
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index dc7e67a..6fab650 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -537,6 +537,45 @@
         )
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_minmax_int_ops(self):
+        def apply(fn):
+            return lambda x, y, z: fn(fn(x, y), z)
+
+        dtypes = [
+            torch.int8,
+            torch.uint8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.bool,
+        ]
+        binary_ops = [
+            torch.min,
+            torch.max
+        ]
+        devices = ["cuda"]
+        for dtype, op, device in product(dtypes, binary_ops, devices):
+            try:
+                x = self.data_for(dtype, device)
+                y = self.data_for(dtype, device)
+                z = self.data_for(dtype, device)
+                fn = apply(op)
+                ref = fn(x, y, z)
+            except Exception:
+                # If eager mode doesn't support a dtype/op/device combo,
+                # neither does the fuser. Catch everything to avoid needing to
+                # guess what errors might be thrown by eager.
+                continue
+            try:
+                t = torch.jit.trace(fn, (x, y, z))
+                self.assertEqual(ref, t(x, y, z))
+                self.assertAllFused(t.graph_for(x, y, z))
+            except Exception as e:
+                raise RuntimeError(
+                    " ".join(["Failed:", str(dtype), op.__name__, device])
+                ) from e
+
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_comparison_eq_ne(self):
         def f(x, y):
             mask = (x == 0).type_as(x)
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index a64e413..06e6703 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -453,7 +453,11 @@
 }
 
 void CudaPrinter::visit(const Max* v) {
-  os() << "maximum(";
+  if (is_integral(v->dtype().scalar_type())) {
+    os() << "max(";
+  } else {
+    os() << "maximum(";
+  }
   v->lhs()->accept(this);
   os() << ",";
   v->rhs()->accept(this);
@@ -461,7 +465,11 @@
 }
 
 void CudaPrinter::visit(const Min* v) {
-  os() << "minimum(";
+  if (is_integral(v->dtype().scalar_type())) {
+    os() << "min(";
+  } else {
+    os() << "minimum(";
+  }
   v->lhs()->accept(this);
   os() << ",";
   v->rhs()->accept(this);
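With the printer change above, an integral Max lowers to CUDA's built-in
integer max overload instead of the NaN-aware maximum helper. The emitted
code has roughly this shape (kernel name and buffer layout are hypothetical):

    __global__ void fused_int_max(int* a, int* b, int* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = max(a[i], b[i]);  // built-in integer max; never touches isnan
      }
    }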
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 293ea78..833881c 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -816,14 +816,14 @@
     case aten::min: {
       return computeTwoOperand(
           "aten_min", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
-            return Min::make(lhs, rhs, false);
+            return Min::make(boolToInteger(lhs), boolToInteger(rhs), false);
           });
     } break;
     case aten::max: {
       return computeTwoOperand(
           "aten_max", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
-            return Max::make(lhs, rhs, false);
+            return Max::make(boolToInteger(lhs), boolToInteger(rhs), false);
           });
     } break;
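The boolToInteger wrapping matters because the emitted max(...)/min(...)
have no bool overload in CUDA device code; promoting bool operands to int
keeps the generated call well-formed (which is also why torch.bool appears
in the test's dtype list). A plausible shape for the helper, assuming it
simply casts kBool expressions to kInt and passes everything else through
(a sketch, not the verbatim implementation):

    ExprHandle boolToInteger(const ExprHandle& e) {
      if (e.dtype().scalar_type() == ScalarType::Bool) {
        // Promote to int, preserving the vector width.
        return Cast::make(Dtype(ScalarType::Int, e.dtype().lanes()), e);
      }
      return e;  // non-bool dtypes pass through unchanged
    }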