[te] Fix bugs with shift operators (#49396)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49396

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49271

Two things:

1. The shift operators throw exceptions in their constructors, which causes a
   segfault (*), so move those checks into ::make (sketched below).
2. They technically support floating-point types, but the casting rules are
   complicated, so don't bother fusing them for FP inputs.

(*) The reason for the segfault: all Exprs, including these, inherit from
KernelScopedObject, whose constructor adds the object to a list for destruction
at the end of the containing KernelArena's lifetime. But if the derived-class
constructor throws, the partially constructed object's storage is freed even
though its pointer is still in the KernelArena's list. So when the KernelArena
is itself destroyed, it double-frees that pointer and crashes. I've also fixed
And, Or, and Xor in this diff.
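
A minimal standalone sketch of the failure mode (Arena and Node here are
stand-ins for KernelArena and KernelScopedObject, not the real classes):

    #include <stdexcept>
    #include <vector>

    struct Arena;

    struct Node {
      explicit Node(Arena& a);  // registers `this` with the arena
      virtual ~Node() = default;
    };

    struct Arena {
      std::vector<Node*> owned;
      ~Arena() {
        for (Node* n : owned) delete n;  // second delete if a derived ctor threw
      }
    };

    Node::Node(Arena& a) { a.owned.push_back(this); }

    struct Shift : Node {
      Shift(Arena& a, bool ok) : Node(a) {
        if (!ok) throw std::runtime_error("bad dtype");  // `this` is already in a.owned
      }
    };

    int main() {
      Arena a;
      try {
        new Shift(a, /*ok=*/false);  // ctor throws; the new-expression frees the storage
      } catch (...) {
      }
      // When `a` goes out of scope it deletes the same pointer again: double-free.
    }
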
ghstack-source-id: 118594998

Test Plan: `buck test //caffe2/test:jit`

Reviewed By: bwasti

Differential Revision: D25512052

fbshipit-source-id: 42670b3be0cc1600dc5cda6811f7f270a2c88bba
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 8b04418..a9a3aee 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -477,7 +477,9 @@
         binary_ops = [
             operator.__and__,
             operator.__or__,
-            operator.__xor__
+            operator.__xor__,
+            operator.__lshift__,
+            operator.__rshift__,
         ]
         devices = self.devices
         for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
@@ -1292,11 +1294,6 @@
             torch.lt,
             torch.fmod,
             torch.remainder,
-
-            # FIXME: segfaults on CPU backend
-            # operator.__rshift__,
-            # operator.__lshift__,
-
             lambda x, y: y.type_as(x),
         ]
         fp_only = [
@@ -1343,10 +1340,6 @@
             torch.ge,
             torch.lt,
             torch.gt,
-
-            # FIXME: segfaults on CPU backend
-            # operator.__rshift__,
-            # operator.__lshift__,
         ]
         devices = self.devices
         # Maybe we should split this into separate tests to speed it up by
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 8a71e52..6f98b88 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -737,6 +737,12 @@
       "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
       "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
     };
+    static const OperatorSet int_only_operator_set{
+      "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+      "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+      "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+      "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+    };
     // clang-format on
 
     for (const Value* v : node->inputs()) {
@@ -759,11 +765,20 @@
         if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
           return false;
         }
+
+        // These operators have complicated casting rules for floats.
+        if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
+          return false;
+        }
       } else if (node->isMemberOf(float_only_operator_set)) {
         // Check scalar operands of float-only ops.
         if (!v->type()->cast<FloatType>()) {
           return false;
         }
+      } else if (node->isMemberOf(int_only_operator_set)) {
+        if (!v->type()->cast<IntType>()) {
+          return false;
+        }
       }
     }
 
diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h
index e7fbd37..b3fd7c6 100644
--- a/torch/csrc/jit/tensorexpr/eval.h
+++ b/torch/csrc/jit/tensorexpr/eval.h
@@ -422,8 +422,14 @@
 
     if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
       switch (lhs_v.dtype().scalar_type()) {
-        case ScalarType::Int:
-          value_ = shift_binary_op<int>(lhs_v, rhs_v, expr_type);
+#define TYPE_CASE(Type, Name)                                \
+  case ScalarType::Name:                                     \
+    value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
+    break;
+        AT_FORALL_INT_TYPES(TYPE_CASE);
+#undef TYPE_CASE
+        case ScalarType::Bool:
+          value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
           break;
         default:
           throw unsupported_dtype();
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 6fe4bf0..8aa2d11 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -179,69 +179,51 @@
       : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
 };
 
-class And : public BinaryOpNode<And> {
+template <typename Op>
+class BitwiseOpNode : public BinaryOpNode<Op> {
+ public:
+  BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
+      : BinaryOpNode<Op>(lhs, rhs, type) {}
+
+  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
+    if (!lhs.dtype().is_integral()) {
+      throw unsupported_dtype();
+    }
+    if (lhs.dtype() != rhs.dtype()) {
+      throw malformed_input("lhs/rhs dtype mismatch");
+    }
+    return BinaryOpNode<Op>::make(lhs, rhs);
+  }
+};
+
+class And : public BitwiseOpNode<And> {
  public:
   And(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
-    if (!lhs->dtype().is_integral()) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in And");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
 };
 
-class Or : public BinaryOpNode<Or> {
+class Or : public BitwiseOpNode<Or> {
  public:
   Or(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
-    if (!lhs->dtype().is_integral()) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Or");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
 };
 
-class Xor : public BinaryOpNode<Xor> {
+class Xor : public BitwiseOpNode<Xor> {
  public:
   Xor(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
-    if (!lhs->dtype().is_integral()) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Xor");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
 };
 
-class Lshift : public BinaryOpNode<Lshift> {
+class Lshift : public BitwiseOpNode<Lshift> {
  public:
   Lshift(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
-    if (lhs->dtype().scalar_type() != ScalarType::Int) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Lshift");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
 };
 
-class Rshift : public BinaryOpNode<Rshift> {
+class Rshift : public BitwiseOpNode<Rshift> {
  public:
   Rshift(const Expr* lhs, const Expr* rhs)
-      : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
-    if (lhs->dtype().scalar_type() != ScalarType::Int) {
-      throw unsupported_dtype();
-    }
-    if (lhs->dtype() != rhs->dtype()) {
-      throw malformed_input("bad dtype in Rshift");
-    }
-  }
+      : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
 };
 
 class Max : public BinaryOpNode<Max> {