[TensorExpr] Fix order comparisons for unsigned types (#44857)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44857

Test Plan: test_tensorexpr --gtest_filter=TensorExprTest.LLVMCompareSelectByte*_LLVM

Reviewed By: glaringlee

Differential Revision: D23762162

Pulled By: asuhan

fbshipit-source-id: 1553429bd2d5292ccda57910326b8c70e4e6ab88
diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp
index 71005bc..d461b20 100644
--- a/test/cpp/tensorexpr/test_llvm.cpp
+++ b/test/cpp/tensorexpr/test_llvm.cpp
@@ -908,6 +908,180 @@
   assertAllEqual(c_buffer, 1);
 }
 
+void testLLVMCompareSelectByteGT() {
+  // kGT on Byte (uint8_t) inputs. A is 128 for the first half and 0 for
+  // the rest, B is all zeros, so only the first half should select 1.
+  // (128 read as a signed 8-bit value is negative and would select 0.)
+  KernelScope kernel_scope;
+  constexpr int N = 1024;
+  Buffer a(BufHandle("A", {N}, kByte));
+  Buffer b(BufHandle("B", {N}, kByte));
+  Buffer c(BufHandle("C", {N}, kInt));
+  std::vector<uint8_t> a_buffer(N, 0);
+  std::vector<uint8_t> b_buffer(N, 0);
+  std::vector<int> c_buffer(N, 0);
+  std::vector<int> c_ref(N, 0);
+
+  for (int k = 0; k < N / 2; ++k) {
+    a_buffer[k] = 128;
+    c_ref[k] = 1;
+  }
+
+  // IR: for idx in [0, N): C[idx] = CompareSelect(A[idx], B[idx], kGT).
+  auto mask_one = IntImm::make(1);
+  VarHandle idx("i", kInt);
+  auto lhs = Load::make(a, {idx}, mask_one);
+  auto rhs = Load::make(b, {idx}, mask_one);
+  auto selected = CompareSelect::make(lhs, rhs, CompareSelectOperation::kGT);
+  auto body = Store::make(c, {idx}, selected, mask_one);
+  auto loop = For::make(idx, 0, N, body);
+
+  // Hand the loop to LLVMCodeGen and invoke it on the host buffers.
+  LLVMCodeGen cg(loop, {a, b, c});
+  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+  ASSERT_EQ(cg.value<int>(args), 0);
+
+  // The host vectors must still hold N elements and B must be unchanged.
+  ASSERT_EQ(a_buffer.size(), N);
+  ASSERT_EQ(b_buffer.size(), N);
+  ASSERT_EQ(c_buffer.size(), N);
+
+  assertAllEqual(b_buffer, uint8_t(0));
+
+  // Every output element must match the reference.
+  for (int i = 0; i < N; i++) {
+    ASSERT_EQ(c_ref[i], c_buffer[i]);
+  }
+}
+
+void testLLVMCompareSelectByteGE() {
+  // kGE on Byte (uint8_t) inputs. A is all 128 and B is all zeros, so every
+  // element should select 1. 128 has its top bit set: a signed 8-bit compare
+  // would read it as negative and select 0, which equal inputs cannot detect.
+  KernelScope kernel_scope;
+  constexpr int N = 1024;
+  Buffer a(BufHandle("A", {N}, kByte));
+  Buffer b(BufHandle("B", {N}, kByte));
+  Buffer c(BufHandle("C", {N}, kInt));
+  std::vector<uint8_t> a_buffer(N, 128);
+  std::vector<uint8_t> b_buffer(N, 0);
+  std::vector<int> c_buffer(N, 0);
+  std::vector<int> c_ref(N, 1);
+
+  auto mask = IntImm::make(1);
+  VarHandle i("i", kInt);
+  auto expr = For::make(
+      i,
+      0,
+      N,
+      Store::make(
+          c,
+          {i},
+          CompareSelect::make(
+              Load::make(a, {i}, mask),
+              Load::make(b, {i}, mask),
+              CompareSelectOperation::kGE),
+          mask));
+  LLVMCodeGen cg(expr, {a, b, c});
+  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+  ASSERT_EQ(cg.value<int>(args), 0);
+
+  ASSERT_EQ(a_buffer.size(), N);
+  ASSERT_EQ(b_buffer.size(), N);
+  ASSERT_EQ(c_buffer.size(), N);
+  assertAllEqual(b_buffer, uint8_t(0));
+  for (int i = 0; i < N; i++) {
+    ASSERT_EQ(c_ref[i], c_buffer[i]);
+  }
+}
+
+void testLLVMCompareSelectByteLT() {
+  // kLT on Byte (uint8_t) inputs. B is all 128; A is 128 for the first half
+  // and 0 for the rest, so only the second half should select 1.
+  // (128 read as a signed 8-bit value is negative and would select 0.)
+  KernelScope kernel_scope;
+  constexpr int N = 1024;
+  Buffer a(BufHandle("A", {N}, kByte));
+  Buffer b(BufHandle("B", {N}, kByte));
+  Buffer c(BufHandle("C", {N}, kInt));
+  std::vector<uint8_t> a_buffer(N, 0);
+  std::vector<uint8_t> b_buffer(N, 128);
+  std::vector<int> c_buffer(N, 0);
+  std::vector<int> c_ref(N, 1);
+
+  for (int k = 0; k < N / 2; ++k) {
+    a_buffer[k] = 128;
+    c_ref[k] = 0;
+  }
+
+  // IR: for idx in [0, N): C[idx] = CompareSelect(A[idx], B[idx], kLT).
+  auto mask_one = IntImm::make(1);
+  VarHandle idx("i", kInt);
+  auto lhs = Load::make(a, {idx}, mask_one);
+  auto rhs = Load::make(b, {idx}, mask_one);
+  auto selected = CompareSelect::make(lhs, rhs, CompareSelectOperation::kLT);
+  auto body = Store::make(c, {idx}, selected, mask_one);
+  auto loop = For::make(idx, 0, N, body);
+
+  // Hand the loop to LLVMCodeGen and invoke it on the host buffers.
+  LLVMCodeGen cg(loop, {a, b, c});
+  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+  ASSERT_EQ(cg.value<int>(args), 0);
+
+  // The host vectors must still hold N elements and B must be unchanged.
+  ASSERT_EQ(a_buffer.size(), N);
+  ASSERT_EQ(b_buffer.size(), N);
+  ASSERT_EQ(c_buffer.size(), N);
+
+  assertAllEqual(b_buffer, uint8_t(128));
+
+  // Every output element must match the reference.
+  for (int i = 0; i < N; i++) {
+    ASSERT_EQ(c_ref[i], c_buffer[i]);
+  }
+}
+
+void testLLVMCompareSelectByteLE() {
+  // kLE on Byte (uint8_t) inputs. A is all zeros and B is all 128, so every
+  // element should select 1.
+  // (128 read as a signed 8-bit value is negative and would select 0.)
+  KernelScope kernel_scope;
+  constexpr int N = 1024;
+  Buffer a(BufHandle("A", {N}, kByte));
+  Buffer b(BufHandle("B", {N}, kByte));
+  Buffer c(BufHandle("C", {N}, kInt));
+  std::vector<uint8_t> a_buffer(N, 0);
+  std::vector<uint8_t> b_buffer(N, 128);
+  std::vector<int> c_buffer(N, 0);
+  std::vector<int> c_ref(N, 1);
+
+  // IR: for idx in [0, N): C[idx] = CompareSelect(A[idx], B[idx], kLE).
+  auto mask_one = IntImm::make(1);
+  VarHandle idx("i", kInt);
+  auto lhs = Load::make(a, {idx}, mask_one);
+  auto rhs = Load::make(b, {idx}, mask_one);
+  auto selected = CompareSelect::make(lhs, rhs, CompareSelectOperation::kLE);
+  auto body = Store::make(c, {idx}, selected, mask_one);
+  auto loop = For::make(idx, 0, N, body);
+
+  // Hand the loop to LLVMCodeGen and invoke it on the host buffers.
+  LLVMCodeGen cg(loop, {a, b, c});
+  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
+  ASSERT_EQ(cg.value<int>(args), 0);
+
+  // The host vectors must still hold N elements and B must be unchanged.
+  ASSERT_EQ(a_buffer.size(), N);
+  ASSERT_EQ(b_buffer.size(), N);
+  ASSERT_EQ(c_buffer.size(), N);
+
+  assertAllEqual(b_buffer, uint8_t(128));
+
+  // Every output element must match the reference.
+  for (int i = 0; i < N; i++) {
+    ASSERT_EQ(c_ref[i], c_buffer[i]);
+  }
+}
+
 void testLLVMStoreFloat() {
   KernelScope kernel_scope;
   Buffer result(BufHandle("result", {1}, kFloat));
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index 60f97fd..c56725b 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -401,6 +401,10 @@
   _(LLVMElemwiseMinNaNFloat)               \
   _(LLVMCompareSelectIntEQ)                \
   _(LLVMCompareSelectFloatEQ)              \
+  _(LLVMCompareSelectByteGT)               \
+  _(LLVMCompareSelectByteGE)               \
+  _(LLVMCompareSelectByteLT)               \
+  _(LLVMCompareSelectByteLE)               \
   _(LLVMStoreFloat)                        \
   _(LLVMSimpleMath01)                      \
   _(LLVMComputeMul)                        \
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index 8a487cd..be1bee9 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -31,6 +31,48 @@
 namespace torch {
 namespace jit {
 namespace tensorexpr {
+namespace {
+
+// Tells whether `type` is one of the scalar types handled as unsigned
+// integers here: ScalarType::Bool or ScalarType::Byte. All other types,
+// including the remaining integral ones, report false.
+bool is_unsigned_integral(const ScalarType& type) {
+  if (type == ScalarType::Bool) {
+    return true;
+  }
+
+  // Byte is the only other unsigned case.
+  return type == ScalarType::Byte;
+}
+
+// Picks the LLVM integer predicate for `compare_op`: ordered comparisons are
+// unsigned iff `type` is unsigned integral. Throws on an unknown operation.
+llvm::CmpInst::Predicate llvm_comparison_predicate(
+    CompareSelectOperation compare_op,
+    const ScalarType& type) {
+  using Cmp = llvm::ICmpInst;
+  const bool is_unsigned = is_unsigned_integral(type);
+  switch (compare_op) {
+    case CompareSelectOperation::kEQ:
+      return Cmp::ICMP_EQ;
+    case CompareSelectOperation::kNE:
+      return Cmp::ICMP_NE;
+    case CompareSelectOperation::kGT:
+      return is_unsigned ? Cmp::ICMP_UGT : Cmp::ICMP_SGT;
+    case CompareSelectOperation::kGE:
+      return is_unsigned ? Cmp::ICMP_UGE : Cmp::ICMP_SGE;
+    case CompareSelectOperation::kLT:
+      return is_unsigned ? Cmp::ICMP_ULT : Cmp::ICMP_SLT;
+    case CompareSelectOperation::kLE:
+      return is_unsigned ? Cmp::ICMP_ULE : Cmp::ICMP_SLE;
+    default:
+      // TODO: change to a proper error report
+      throw std::runtime_error("invalid operator type");
+  }
+}
+
+} // namespace
+
 class LLVMCodeGenImpl : public IRVisitor {
  private:
   llvm::orc::ThreadSafeContext context_;
@@ -671,29 +713,8 @@
   CompareSelectOperation cmp_op_ = v->compare_select_op();
 
   if (is_integral(type_used)) {
-    switch (cmp_op_) {
-      case CompareSelectOperation::kEQ:
-        cmp_ = irb_.CreateICmpEQ(lhs, rhs);
-        break;
-      case CompareSelectOperation::kNE:
-        cmp_ = irb_.CreateICmpNE(lhs, rhs);
-        break;
-      case CompareSelectOperation::kGT:
-        cmp_ = irb_.CreateICmpSGT(lhs, rhs);
-        break;
-      case CompareSelectOperation::kGE:
-        cmp_ = irb_.CreateICmpSGE(lhs, rhs);
-        break;
-      case CompareSelectOperation::kLT:
-        cmp_ = irb_.CreateICmpSLT(lhs, rhs);
-        break;
-      case CompareSelectOperation::kLE:
-        cmp_ = irb_.CreateICmpSLE(lhs, rhs);
-        break;
-      default:
-        // TODO: change to a proper error report
-        throw std::runtime_error("invalid operator type");
-    }
+    cmp_ = irb_.CreateICmp(
+        llvm_comparison_predicate(cmp_op_, type_used), lhs, rhs);
   } else if (is_floating_point(type_used)) { // FP32
     switch (cmp_op_) {
       case CompareSelectOperation::kEQ: