[te] Fix bugs with shift operators (#49396)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49396
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49271
Two things:
1. These ops throw exceptions from their constructors, which causes a segfault
(*), so move the checks into the static ::make methods.
2. They technically support FP types, but the casting rules are complicated,
so don't bother fusing them for floating-point inputs.
(*) The reason for the segfault: all Exprs, including these, inherit from
KernelScopedObject, whose constructor adds the object to a list that the
containing KernelArena frees at the end of its own lifetime. If the
derived-class constructor then throws, the object's storage is released even
though its pointer is still in the KernelArena's list, so when the KernelArena
is itself destroyed it frees that pointer a second time and dies. And, Or, and
Xor had the same bug, so this diff fixes them too.
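
To make the mechanism concrete, here is a minimal, self-contained sketch of
the failure and the fix. `Arena`, `ScopedObject`, `BadShift`, and `GoodShift`
are hypothetical stand-ins for `KernelArena`, `KernelScopedObject`, and the
shift nodes, with simplified signatures; this illustrates the pattern, not the
real classes.

```cpp
#include <stdexcept>
#include <vector>

struct Arena;
Arena* g_arena = nullptr; // stand-in for the thread-local "current arena"

struct ScopedObject {
  ScopedObject(); // registers `this` with g_arena (defined below)
  virtual ~ScopedObject() = default;
};

struct Arena {
  std::vector<ScopedObject*> owned;
  ~Arena() {
    for (ScopedObject* p : owned) {
      delete p; // frees every pointer ever registered
    }
  }
};

ScopedObject::ScopedObject() {
  g_arena->owned.push_back(this);
}

// BROKEN: checks in the constructor. By the time the check runs, the base
// ScopedObject constructor has already pushed `this` onto `owned`. If the
// check throws, `new` releases the storage, but the stale pointer stays in
// `owned`, and ~Arena() frees it a second time: a double-free.
struct BadShift : ScopedObject {
  explicit BadShift(bool dtype_ok) {
    if (!dtype_ok) {
      throw std::runtime_error("unsupported dtype");
    }
  }
};

// FIXED: checks in a static factory, before construction. An object is only
// registered with the arena once it is known to be valid.
struct GoodShift : ScopedObject {
  static GoodShift* make(bool dtype_ok) {
    if (!dtype_ok) {
      throw std::runtime_error("unsupported dtype");
    }
    return new GoodShift();
  }

 private:
  GoodShift() = default;
};

int main() {
  Arena arena;
  g_arena = &arena;
  try {
    GoodShift::make(/*dtype_ok=*/false); // throws; nothing was registered
  } catch (const std::exception&) {
  }
  // `new BadShift(false)` here would register `this`, throw, release the
  // storage, and leave ~Arena() to free the same pointer again at scope exit.
}
```

The fix below follows the `GoodShift` shape: `BitwiseOpNode<Op>::make` does
the dtype checks before any node exists, so nothing is ever registered with
the arena unless construction fully succeeds.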
ghstack-source-id: 118594998
Test Plan: `buck test //caffe2/test:jit`
Reviewed By: bwasti
Differential Revision: D25512052
fbshipit-source-id: 42670b3be0cc1600dc5cda6811f7f270a2c88bba
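
One note on the eval.h hunk below: it replaces the single `Int` case with an
X-macro expanded over all integral scalar types. As a self-contained sketch of
that dispatch pattern: `FORALL_INT_TYPES` and `lshift_as` are hypothetical
stand-ins mirroring the shape of c10's `AT_FORALL_INT_TYPES` (which applies a
macro to each `(ctype, Name)` pair) and of `shift_binary_op<T>`.

```cpp
#include <cstdint>
#include <iostream>

enum class ScalarType { Byte, Char, Short, Int, Long, Bool };

// Hypothetical stand-in with the same shape as c10's AT_FORALL_INT_TYPES.
#define FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)          \
  _(int8_t, Char)           \
  _(int16_t, Short)         \
  _(int, Int)               \
  _(int64_t, Long)

// Shift in type T and truncate the result back to T, roughly what a
// per-dtype interpreter case does (stand-in for shift_binary_op<T>).
template <typename T>
int64_t lshift_as(int64_t a, int64_t b) {
  return static_cast<int64_t>(
      static_cast<T>(static_cast<T>(a) << static_cast<T>(b)));
}

int64_t dispatch_lshift(ScalarType st, int64_t a, int64_t b) {
  switch (st) {
// Each expansion of TYPE_CASE emits one `case` for one (ctype, Name) pair.
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return lshift_as<Type>(a, b);
    FORALL_INT_TYPES(TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::Bool: // Bool is handled separately, as in the hunk
      return lshift_as<unsigned char>(a, b);
  }
  return 0; // unreachable: all enumerators are covered above
}

int main() {
  // The dtype matters: 200 << 1 wraps to 144 in Byte but is 400 in Int.
  std::cout << dispatch_lshift(ScalarType::Byte, 200, 1) << "\n"; // 144
  std::cout << dispatch_lshift(ScalarType::Int, 200, 1) << "\n";  // 400
}
```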
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 8b04418..a9a3aee 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -477,7 +477,9 @@
binary_ops = [
operator.__and__,
operator.__or__,
- operator.__xor__
+ operator.__xor__,
+ operator.__lshift__,
+ operator.__rshift__,
]
devices = self.devices
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
@@ -1292,11 +1294,6 @@
torch.lt,
torch.fmod,
torch.remainder,
-
- # FIXME: segfaults on CPU backend
- # operator.__rshift__,
- # operator.__lshift__,
-
lambda x, y: y.type_as(x),
]
fp_only = [
@@ -1343,10 +1340,6 @@
torch.ge,
torch.lt,
torch.gt,
-
- # FIXME: segfaults on CPU backend
- # operator.__rshift__,
- # operator.__lshift__,
]
devices = self.devices
# Maybe we should split this into separate tests to speed it up by
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 8a71e52..6f98b88 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -737,6 +737,12 @@
"aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
"aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
};
+ static const OperatorSet int_only_operator_set{
+ "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+ "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+ "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
+ "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
+ };
// clang-format on
for (const Value* v : node->inputs()) {
@@ -759,11 +765,20 @@
if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
return false;
}
+
+ // These operators have complicated casting rules for floats.
+ if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
+ return false;
+ }
} else if (node->isMemberOf(float_only_operator_set)) {
// Check scalar operands of float-only ops.
if (!v->type()->cast<FloatType>()) {
return false;
}
+ } else if (node->isMemberOf(int_only_operator_set)) {
+ if (!v->type()->cast<IntType>()) {
+ return false;
+ }
}
}
diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h
index e7fbd37..b3fd7c6 100644
--- a/torch/csrc/jit/tensorexpr/eval.h
+++ b/torch/csrc/jit/tensorexpr/eval.h
@@ -422,8 +422,14 @@
if (expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kRshift) {
switch (lhs_v.dtype().scalar_type()) {
- case ScalarType::Int:
- value_ = shift_binary_op<int>(lhs_v, rhs_v, expr_type);
+#define TYPE_CASE(Type, Name) \
+ case ScalarType::Name: \
+ value_ = shift_binary_op<Type>(lhs_v, rhs_v, expr_type); \
+ break;
+ AT_FORALL_INT_TYPES(TYPE_CASE);
+#undef TYPE_CASE
+ case ScalarType::Bool:
+ value_ = shift_binary_op<unsigned char>(lhs_v, rhs_v, expr_type);
break;
default:
throw unsupported_dtype();
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 6fe4bf0..8aa2d11 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -179,69 +179,51 @@
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};
-class And : public BinaryOpNode<And> {
+template <typename Op>
+class BitwiseOpNode : public BinaryOpNode<Op> {
+ public:
+ BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
+ : BinaryOpNode<Op>(lhs, rhs, type) {}
+
+ static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
+ if (!lhs.dtype().is_integral()) {
+ throw unsupported_dtype();
+ }
+ if (lhs.dtype() != rhs.dtype()) {
+ throw malformed_input("lhs/rhs dtype mismatch");
+ }
+ return BinaryOpNode<Op>::make(lhs, rhs);
+ }
+};
+
+class And : public BitwiseOpNode<And> {
public:
And(const Expr* lhs, const Expr* rhs)
- : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
- if (!lhs->dtype().is_integral()) {
- throw unsupported_dtype();
- }
- if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input("bad dtype in And");
- }
- }
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
};
-class Or : public BinaryOpNode<Or> {
+class Or : public BitwiseOpNode<Or> {
public:
Or(const Expr* lhs, const Expr* rhs)
- : BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
- if (!lhs->dtype().is_integral()) {
- throw unsupported_dtype();
- }
- if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input("bad dtype in Or");
- }
- }
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
};
-class Xor : public BinaryOpNode<Xor> {
+class Xor : public BitwiseOpNode<Xor> {
public:
Xor(const Expr* lhs, const Expr* rhs)
- : BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
- if (!lhs->dtype().is_integral()) {
- throw unsupported_dtype();
- }
- if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input("bad dtype in Xor");
- }
- }
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
};
-class Lshift : public BinaryOpNode<Lshift> {
+class Lshift : public BitwiseOpNode<Lshift> {
public:
Lshift(const Expr* lhs, const Expr* rhs)
- : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
- if (lhs->dtype().scalar_type() != ScalarType::Int) {
- throw unsupported_dtype();
- }
- if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input("bad dtype in Lshift");
- }
- }
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
};
-class Rshift : public BinaryOpNode<Rshift> {
+class Rshift : public BitwiseOpNode<Rshift> {
public:
Rshift(const Expr* lhs, const Expr* rhs)
- : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
- if (lhs->dtype().scalar_type() != ScalarType::Int) {
- throw unsupported_dtype();
- }
- if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input("bad dtype in Rshift");
- }
- }
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
};
class Max : public BinaryOpNode<Max> {