[TensorExpr] scalar factorization of Div (#36154)
Summary:
Add support to the TensorExpr IRSimplifier for factorizing out the greatest common divisor of the scalar factors on either side of a Div node, e.g. `(8 * x) / (4 * y) => (2 * x) / y`.
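For intuition, here is a minimal standalone sketch of the idea in plain C++ (the `coeffLhs`/`coeffRhs` names are illustrative only, not part of the TensorExpr API): divide the scalar factor on each side of the division by their greatest common divisor.

```cpp
#include <cstdio>

// Euclid's algorithm, the same helper the simplifier uses.
long gcd(long a, long b) {
  return b == 0 ? a : gcd(b, a % b);
}

int main() {
  // (8 * x) / (4 * y): factor the GCD out of both scalar coefficients.
  long coeffLhs = 8;
  long coeffRhs = 4;
  long g = gcd(coeffLhs, coeffRhs);
  if (g > 1) { // a GCD <= 1 is uninteresting, mirroring the patch's early-out
    coeffLhs /= g;
    coeffRhs /= g;
  }
  // Prints "(2 * x) / (1 * y)"; the simplifier further drops the
  // denominator's unit coefficient, giving (2 * x) / y.
  std::printf("(%ld * x) / (%ld * y)\n", coeffLhs, coeffRhs);
  return 0;
}
```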
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36154
Differential Revision: D20910580
Pulled By: nickgg
fbshipit-source-id: ee071d93bc4711b1e710be312de599d18ab506f3
diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp
index 26356be..7be7df1 100644
--- a/test/cpp/tensorexpr/test_simplify.cpp
+++ b/test/cpp/tensorexpr/test_simplify.cpp
@@ -586,6 +586,16 @@
IS_VAR_WITH_NAME(rhs->lhs(), "y");
IS_VAR_WITH_NAME(rhs->rhs(), "x");
}
+
+ {
+ // (x + x + x + x) => 4 * x
+ ExprHandle body = (x + x + x + x);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+
+ IS_NODE_WITH_NAME(Mul, simplified.node(), root);
+ IS_IMM_WITH_VAL(Int, root->lhs(), 4);
+ IS_VAR_WITH_NAME(root->rhs(), "x");
+ }
}
void testSimplifyMuls() {
@@ -1393,5 +1403,77 @@
}
}
+void testSimplifyDivisionScalarFactorization() {
+ KernelScope kernel_scope;
+
+ {
+ // Simple factorization of numerator and denominator.
+ // 8x / 4y => 2x / y.
+ VarHandle x("x", kInt);
+ VarHandle y("y", kInt);
+ ExprHandle body = (x * 8) / (y * 4);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+ IS_NODE_WITH_NAME(Div, simplified.node(), div);
+ IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+ IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
+ IS_VAR_WITH_NAME(lhs->rhs(), "x");
+ IS_VAR_WITH_NAME(div->rhs(), "y");
+ }
+
+ {
+ // Don't change anything if we can't factorize.
+ VarHandle x("x", kInt);
+ VarHandle y("y", kInt);
+ ExprHandle body = (x * 7) / (y * 4);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+ IS_NODE_WITH_NAME(Div, simplified.node(), div);
+ IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+ IS_IMM_WITH_VAL(Int, lhs->lhs(), 7);
+ IS_VAR_WITH_NAME(lhs->rhs(), "x");
+ IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
+ IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
+ IS_VAR_WITH_NAME(rhs->rhs(), "y");
+ }
+
+ {
+ // Don't reorder floats.
+ VarHandle x("x", kFloat);
+ VarHandle y("y", kFloat);
+ ExprHandle body = (x * 8) / (y * 4);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+ IS_NODE_WITH_NAME(Div, simplified.node(), div);
+ IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+ IS_VAR_WITH_NAME(lhs->lhs(), "x");
+ IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f);
+ IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
+ IS_VAR_WITH_NAME(rhs->lhs(), "y");
+ IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f);
+ }
+
+ {
+ // Sanity check that we do nothing when the scalar factors are trivial (GCD of 1).
+ VarHandle x("x", kInt);
+ VarHandle y("y", kInt);
+ ExprHandle body = (x * 1) / (y * 1);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+ IS_NODE_WITH_NAME(Div, simplified.node(), div);
+ IS_VAR_WITH_NAME(div->lhs(), "x");
+ IS_VAR_WITH_NAME(div->rhs(), "y");
+ }
+
+ {
+ // Can factorize counts of repeated variables: (x+x+x+x) / (y+y) => 4x / 2y => 2x / y.
+ VarHandle x("x", kInt);
+ VarHandle y("y", kInt);
+ ExprHandle body = (x + x + x + x) / (y + y);
+ ExprHandle simplified = IRSimplifier::simplify(body);
+ IS_NODE_WITH_NAME(Div, simplified.node(), div);
+ IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+ IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
+ IS_VAR_WITH_NAME(lhs->rhs(), "x");
+ IS_VAR_WITH_NAME(div->rhs(), "y");
+ }
+}
+
} // namespace jit
} // namespace torch
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index c970ffb..d9f4149 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -124,6 +124,7 @@
_(SimplifyRoundModPattern) \
_(SimplifyRoundModPatternFactorization) \
_(SimplifyRoundModPatternMultivar) \
+ _(SimplifyDivisionScalarFactorization) \
_(StmtClone)
#define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
index bd3dbea..63d00ef 100644
--- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
@@ -4,6 +4,16 @@
namespace jit {
namespace tensorexpr {
+// Simple recursive GCD.
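+// Note: the result can be negative when an input is negative.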
+template <typename T>
+T gcd(T a, T b) {
+ if (b == 0) {
+ return a;
+ }
+ return gcd(b, a % b);
+}
+
SimplifierHashType Term::hashVars() const {
SimplifierHashType hash;
for (auto* v : variables_) {
@@ -334,6 +343,15 @@
return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm);
}
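+ // Two Terms with the same variables (by hash) merge into one by adding their
+ // scalars, e.g. (2 * x) + (3 * x) => 5 * x.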
+ if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
+ return new Term(
+ hasher_,
+ evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())),
+ lhsTerm->variables());
+ }
+
// If all else fails we have a new Polynomial with two new variable Terms.
return new Polynomial(
hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm);
@@ -831,6 +847,82 @@
return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new);
}
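+// Attempts to simplify a Div by factorizing the greatest common divisor out of
+// the scalar components on both sides, e.g. (8 * x) / (4 * y) => (2 * x) / y.
+// Returns nullptr if no factorization is possible.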
+const Expr* factorizeDivision(const Expr* lhs_new, const Expr* rhs_new) {
+ if (!lhs_new || !rhs_new) {
+ return nullptr;
+ }
+
+ const Expr* leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
+ const Expr* rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
+
+ auto* lhsTerm = dynamic_cast<const Term*>(lhs_new);
+ auto* rhsTerm = dynamic_cast<const Term*>(rhs_new);
+ if (lhsTerm) {
+ leftScalar = lhsTerm->scalar();
+ }
+
+ if (rhsTerm) {
+ rightScalar = rhsTerm->scalar();
+ }
+
+ if (!leftScalar || !rightScalar) {
+ return nullptr;
+ }
+
+ long left = immediateAs<long>(leftScalar);
+ long right = immediateAs<long>(rightScalar);
+
+ long GCD = gcd<long>(left, right);
+ if (GCD <= 1) {
+ return nullptr;
+ }
+
+ leftScalar = evaluateOp(
+ new Div(leftScalar, getImmediateByType(leftScalar->dtype(), GCD)));
+ rightScalar = evaluateOp(
+ new Div(rightScalar, getImmediateByType(rightScalar->dtype(), GCD)));
+
+ if (lhsTerm) {
+ lhs_new = new Term(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
+ } else {
+ lhs_new = leftScalar;
+ }
+
+ if (rhsTerm) {
+ rhs_new = new Term(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
+ } else {
+ rhs_new = rightScalar;
+ }
+
+ return new Div(lhs_new, rhs_new);
+}
+
+const Expr* PolynomialTransformer::mutate(const Div* v) {
+ const Expr* lhs_new = v->lhs()->accept_mutator(this);
+ const Expr* rhs_new = v->rhs()->accept_mutator(this);
+
+ // Constant Folding.
+ if (lhs_new->isConstant() && rhs_new->isConstant()) {
+ return evaluateOp(new Div(lhs_new, rhs_new));
+ }
+
+ // If this is a floating point Div then the order of operations is important;
+ // we don't want to combine ops.
+ if (lhs_new->dtype().is_floating_point() ||
+ rhs_new->dtype().is_floating_point()) {
+ return new Div(lhs_new, rhs_new);
+ }
+
+ if (auto ret = factorizeDivision(lhs_new, rhs_new)) {
+ return ret;
+ }
+
+ return new Div(lhs_new, rhs_new);
+}
+
const Expr* PolynomialTransformer::mutate(const Intrinsics* v) {
std::vector<const Expr*> new_params;
bool changed = false;
@@ -955,15 +1044,6 @@
return lastNode;
}
-// Simple recursive GCD.
-template <typename T>
-T gcd(T a, T b) {
- if (b == 0) {
- return a;
- }
- return gcd(b, a % b);
-}
-
// Returns an immediate containing the greatest common divisor of all terms
// (inc. the scalar term) in the polynomial. If the GCD is uninteresting
// (e.g. 1) then returns nullptr.
diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h
index 7d20a5d..5f3b8fc 100644
--- a/torch/csrc/jit/tensorexpr/ir_simplifier.h
+++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h
@@ -283,10 +283,7 @@
// Merge and simplify multiplication.
const Expr* mutate(const Mul* v) override;
- const Expr* mutate(const Div* v) override {
- // TODO div simplification will require a rational node.
- return mutateBinaryOp(v, this);
- }
+ const Expr* mutate(const Div* v) override;
const Expr* mutate(const Mod* v) override {
return mutateBinaryOp(v, this);