[TensorExpr] Fix order comparisons for unsigned types (#44857)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44857
Test Plan: test_tensorexpr --gtest_filter=TensorExprTest.LLVMCompareSelectByte*_LLVM
Reviewed By: glaringlee
Differential Revision: D23762162
Pulled By: asuhan
fbshipit-source-id: 1553429bd2d5292ccda57910326b8c70e4e6ab88
diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp
index 71005bc..d461b20 100644
--- a/test/cpp/tensorexpr/test_llvm.cpp
+++ b/test/cpp/tensorexpr/test_llvm.cpp
@@ -908,6 +908,180 @@
assertAllEqual(c_buffer, 1);
}
+void testLLVMCompareSelectByteGT() {
+ KernelScope kernel_scope;
+ constexpr int N = 1024;
+ Buffer a(BufHandle("A", {N}, kByte));
+ Buffer b(BufHandle("B", {N}, kByte));
+ Buffer c(BufHandle("C", {N}, kInt));
+ std::vector<uint8_t> a_buffer(N, 0);
+ std::vector<uint8_t> b_buffer(N, 0);
+ std::vector<int> c_buffer(N, 0);
+ std::vector<int> c_ref(N, 0);
+
+ for (int i = 0; i < N / 2; i++) {
+ a_buffer[i] = 128;
+ c_ref[i] = 1;
+ }
+
+ auto mask = IntImm::make(1);
+ VarHandle i("i", kInt);
+ auto expr = For::make(
+ i,
+ 0,
+ N,
+ Store::make(
+ c,
+ {i},
+ CompareSelect::make(
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
+ CompareSelectOperation::kGT),
+ mask));
+
+ LLVMCodeGen cg(expr, {a, b, c});
+
+ std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+ ASSERT_EQ(cg.value<int>(args), 0);
+
+ ASSERT_EQ(a_buffer.size(), N);
+ ASSERT_EQ(b_buffer.size(), N);
+ ASSERT_EQ(c_buffer.size(), N);
+
+ assertAllEqual(b_buffer, uint8_t(0));
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(c_ref[i], c_buffer[i]);
+ }
+}
+
+void testLLVMCompareSelectByteGE() {
+ KernelScope kernel_scope;
+ constexpr int N = 1024;
+ Buffer a(BufHandle("A", {N}, kByte));
+ Buffer b(BufHandle("B", {N}, kByte));
+ Buffer c(BufHandle("C", {N}, kInt));
+ std::vector<uint8_t> a_buffer(N, 0);
+ std::vector<uint8_t> b_buffer(N, 0);
+ std::vector<int> c_buffer(N, 0);
+ std::vector<int> c_ref(N, 1);
+
+ auto mask = IntImm::make(1);
+ VarHandle i("i", kInt);
+ auto expr = For::make(
+ i,
+ 0,
+ N,
+ Store::make(
+ c,
+ {i},
+ CompareSelect::make(
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
+ CompareSelectOperation::kGE),
+ mask));
+
+ LLVMCodeGen cg(expr, {a, b, c});
+
+ std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+ ASSERT_EQ(cg.value<int>(args), 0);
+
+ ASSERT_EQ(a_buffer.size(), N);
+ ASSERT_EQ(b_buffer.size(), N);
+ ASSERT_EQ(c_buffer.size(), N);
+
+ assertAllEqual(b_buffer, uint8_t(0));
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(c_ref[i], c_buffer[i]);
+ }
+}
+
+void testLLVMCompareSelectByteLT() {
+ KernelScope kernel_scope;
+ constexpr int N = 1024;
+ Buffer a(BufHandle("A", {N}, kByte));
+ Buffer b(BufHandle("B", {N}, kByte));
+ Buffer c(BufHandle("C", {N}, kInt));
+ std::vector<uint8_t> a_buffer(N, 0);
+ std::vector<uint8_t> b_buffer(N, 128);
+ std::vector<int> c_buffer(N, 0);
+ std::vector<int> c_ref(N, 1);
+
+ for (int i = 0; i < N / 2; i++) {
+ a_buffer[i] = 128;
+ c_ref[i] = 0;
+ }
+
+ auto mask = IntImm::make(1);
+ VarHandle i("i", kInt);
+ auto expr = For::make(
+ i,
+ 0,
+ N,
+ Store::make(
+ c,
+ {i},
+ CompareSelect::make(
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
+ CompareSelectOperation::kLT),
+ mask));
+
+ LLVMCodeGen cg(expr, {a, b, c});
+
+ std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+ ASSERT_EQ(cg.value<int>(args), 0);
+
+ ASSERT_EQ(a_buffer.size(), N);
+ ASSERT_EQ(b_buffer.size(), N);
+ ASSERT_EQ(c_buffer.size(), N);
+
+ assertAllEqual(b_buffer, uint8_t(128));
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(c_ref[i], c_buffer[i]);
+ }
+}
+
+void testLLVMCompareSelectByteLE() {
+ KernelScope kernel_scope;
+ constexpr int N = 1024;
+ Buffer a(BufHandle("A", {N}, kByte));
+ Buffer b(BufHandle("B", {N}, kByte));
+ Buffer c(BufHandle("C", {N}, kInt));
+ std::vector<uint8_t> a_buffer(N, 0);
+ std::vector<uint8_t> b_buffer(N, 128);
+ std::vector<int> c_buffer(N, 0);
+ std::vector<int> c_ref(N, 1);
+
+ auto mask = IntImm::make(1);
+ VarHandle i("i", kInt);
+ auto expr = For::make(
+ i,
+ 0,
+ N,
+ Store::make(
+ c,
+ {i},
+ CompareSelect::make(
+ Load::make(a, {i}, mask),
+ Load::make(b, {i}, mask),
+ CompareSelectOperation::kLE),
+ mask));
+
+ LLVMCodeGen cg(expr, {a, b, c});
+
+ std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+ ASSERT_EQ(cg.value<int>(args), 0);
+
+ ASSERT_EQ(a_buffer.size(), N);
+ ASSERT_EQ(b_buffer.size(), N);
+ ASSERT_EQ(c_buffer.size(), N);
+
+ assertAllEqual(b_buffer, uint8_t(128));
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(c_ref[i], c_buffer[i]);
+ }
+}
+
void testLLVMStoreFloat() {
KernelScope kernel_scope;
Buffer result(BufHandle("result", {1}, kFloat));
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index 60f97fd..c56725b 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -401,6 +401,10 @@
_(LLVMElemwiseMinNaNFloat) \
_(LLVMCompareSelectIntEQ) \
_(LLVMCompareSelectFloatEQ) \
+ _(LLVMCompareSelectByteGT) \
+ _(LLVMCompareSelectByteGE) \
+ _(LLVMCompareSelectByteLT) \
+ _(LLVMCompareSelectByteLE) \
_(LLVMStoreFloat) \
_(LLVMSimpleMath01) \
_(LLVMComputeMul) \
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index 8a487cd..be1bee9 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -31,6 +31,48 @@
namespace torch {
namespace jit {
namespace tensorexpr {
+namespace {
+
+bool is_unsigned_integral(const ScalarType& type) {
+ switch (type) {
+ case ScalarType::Bool:
+ case ScalarType::Byte:
+ return true;
+ default:
+ return false;
+ }
+
+ return false;
+}
+
+llvm::CmpInst::Predicate llvm_comparison_predicate(
+ CompareSelectOperation compare_op,
+ const ScalarType& type) {
+ switch (compare_op) {
+ case CompareSelectOperation::kEQ:
+ return llvm::ICmpInst::ICMP_EQ;
+ case CompareSelectOperation::kNE:
+ return llvm::ICmpInst::ICMP_NE;
+ case CompareSelectOperation::kGT:
+ return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGT
+ : llvm::ICmpInst::ICMP_SGT;
+ case CompareSelectOperation::kGE:
+ return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_UGE
+ : llvm::ICmpInst::ICMP_SGE;
+ case CompareSelectOperation::kLT:
+ return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULT
+ : llvm::ICmpInst::ICMP_SLT;
+ case CompareSelectOperation::kLE:
+ return is_unsigned_integral(type) ? llvm::ICmpInst::ICMP_ULE
+ : llvm::ICmpInst::ICMP_SLE;
+ default:
+ // TODO: change to a proper error report
+ throw std::runtime_error("invalid operator type");
+ }
+}
+
+} // namespace
+
class LLVMCodeGenImpl : public IRVisitor {
private:
llvm::orc::ThreadSafeContext context_;
@@ -671,29 +713,8 @@
CompareSelectOperation cmp_op_ = v->compare_select_op();
if (is_integral(type_used)) {
- switch (cmp_op_) {
- case CompareSelectOperation::kEQ:
- cmp_ = irb_.CreateICmpEQ(lhs, rhs);
- break;
- case CompareSelectOperation::kNE:
- cmp_ = irb_.CreateICmpNE(lhs, rhs);
- break;
- case CompareSelectOperation::kGT:
- cmp_ = irb_.CreateICmpSGT(lhs, rhs);
- break;
- case CompareSelectOperation::kGE:
- cmp_ = irb_.CreateICmpSGE(lhs, rhs);
- break;
- case CompareSelectOperation::kLT:
- cmp_ = irb_.CreateICmpSLT(lhs, rhs);
- break;
- case CompareSelectOperation::kLE:
- cmp_ = irb_.CreateICmpSLE(lhs, rhs);
- break;
- default:
- // TODO: change to a proper error report
- throw std::runtime_error("invalid operator type");
- }
+ cmp_ = irb_.CreateICmp(
+ llvm_comparison_predicate(cmp_op_, type_used), lhs, rhs);
} else if (is_floating_point(type_used)) { // FP32
switch (cmp_op_) {
case CompareSelectOperation::kEQ: