[TensorExpr] scalar factorization of Div (#36154)

Summary:
Add support to the TensorExpr IR Simplifier for factorizing a common scalar factor out of both sides of a Div node, e.g. `(8 * x) / (4 * y) => (2 * x) / y`.
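As a minimal sketch of the new behaviour, exercised through the public entry point and mirroring the setup the new tests below use (`KernelScope`, `VarHandle`, `IRSimplifier::simplify`):

```cpp
KernelScope kernel_scope;
VarHandle x("x", kInt);
VarHandle y("y", kInt);
// 8 and 4 share a GCD of 4, which is divided out of both sides.
ExprHandle body = (x * 8) / (y * 4);
ExprHandle simplified = IRSimplifier::simplify(body);
// simplified is now (2 * x) / y. Floating-point Divs are left untouched,
// since reordering float ops can change the result.
```
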
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36154

Differential Revision: D20910580

Pulled By: nickgg

fbshipit-source-id: ee071d93bc4711b1e710be312de599d18ab506f3
diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp
index 26356be..7be7df1 100644
--- a/test/cpp/tensorexpr/test_simplify.cpp
+++ b/test/cpp/tensorexpr/test_simplify.cpp
@@ -586,6 +586,16 @@
     IS_VAR_WITH_NAME(rhs->lhs(), "y");
     IS_VAR_WITH_NAME(rhs->rhs(), "x");
   }
+
+  {
+    // (x + x + x + x) => 4 * x
+    ExprHandle body = (x + x + x + x);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+
+    IS_NODE_WITH_NAME(Mul, simplified.node(), root);
+    IS_IMM_WITH_VAL(Int, root->lhs(), 4);
+    IS_VAR_WITH_NAME(root->rhs(), "x");
+  }
 }
 
 void testSimplifyMuls() {
@@ -1393,5 +1403,77 @@
   }
 }
 
+void testSimplifyDivisionScalarFactorization() {
+  KernelScope kernel_scope;
+
+  {
+    // Simple factorization of numerator and denominator.
+    // 8x / 4y => 2x / y.
+    VarHandle x("x", kInt);
+    VarHandle y("y", kInt);
+    ExprHandle body = (x * 8) / (y * 4);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Div, simplified.node(), div);
+    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+    IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
+    IS_VAR_WITH_NAME(lhs->rhs(), "x");
+    IS_VAR_WITH_NAME(div->rhs(), "y");
+  }
+
+  {
+    // Don't change anything if we can't factorize.
+    VarHandle x("x", kInt);
+    VarHandle y("y", kInt);
+    ExprHandle body = (x * 7) / (y * 4);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Div, simplified.node(), div);
+    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+    IS_IMM_WITH_VAL(Int, lhs->lhs(), 7);
+    IS_VAR_WITH_NAME(lhs->rhs(), "x");
+    IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
+    IS_IMM_WITH_VAL(Int, rhs->lhs(), 4);
+    IS_VAR_WITH_NAME(rhs->rhs(), "y");
+  }
+
+  {
+    // Don't reorder floats.
+    VarHandle x("x", kFloat);
+    VarHandle y("y", kFloat);
+    ExprHandle body = (x * 8) / (y * 4);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Div, simplified.node(), div);
+    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+    IS_VAR_WITH_NAME(lhs->lhs(), "x");
+    IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f);
+    IS_NODE_WITH_NAME(Mul, div->rhs(), rhs);
+    IS_VAR_WITH_NAME(rhs->lhs(), "y");
+    IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f);
+  }
+
+  {
+    // Sanity check that nothing changes when the scalar factors are trivial.
+    VarHandle x("x", kInt);
+    VarHandle y("y", kInt);
+    ExprHandle body = (x * 1) / (y * 1);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Div, simplified.node(), div);
+    IS_VAR_WITH_NAME(div->lhs(), "x");
+    IS_VAR_WITH_NAME(div->rhs(), "y");
+  }
+
+  {
+    // Can factorize the scalar counts produced by summing variables.
+    VarHandle x("x", kInt);
+    VarHandle y("y", kInt);
+    ExprHandle body = (x + x + x + x) / (y + y);
+    ExprHandle simplified = IRSimplifier::simplify(body);
+    IS_NODE_WITH_NAME(Div, simplified.node(), div);
+    IS_NODE_WITH_NAME(Mul, div->lhs(), lhs);
+    IS_IMM_WITH_VAL(Int, lhs->lhs(), 2);
+    IS_VAR_WITH_NAME(lhs->rhs(), "x");
+    IS_VAR_WITH_NAME(div->rhs(), "y");
+  }
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index c970ffb..d9f4149 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -124,6 +124,7 @@
   _(SimplifyRoundModPattern)              \
   _(SimplifyRoundModPatternFactorization) \
   _(SimplifyRoundModPatternMultivar)      \
+  _(SimplifyDivisionScalarFactorization)  \
   _(StmtClone)
 
 #define TH_FORALL_TENSOREXPR_TESTS_LLVM(_) \
diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
index bd3dbea..63d00ef 100644
--- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp
@@ -4,6 +4,15 @@
 namespace jit {
 namespace tensorexpr {
 
+// Simple recursive GCD (Euclid's algorithm).
+template <typename T>
+T gcd(T a, T b) {
+  if (b == 0) {
+    return a;
+  }
+  return gcd(b, a % b);
+}
+
 SimplifierHashType Term::hashVars() const {
   SimplifierHashType hash;
   for (auto* v : variables_) {
@@ -334,6 +343,13 @@
     return insertTerm(poly, lhsTerm ? lhsTerm : rhsTerm);
   }
 
+  // If both sides are Terms over the same variables (identical hashVars), fold
+  // them into a single Term by adding the scalars, e.g. (2 * x) + (3 * x) => 5 * x.
+  if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
+    return new Term(
+        hasher_,
+        evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())),
+        lhsTerm->variables());
+  }
+
   // If all else fails we have a new Polynomial with two new variable Terms.
   return new Polynomial(
       hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm);
@@ -831,6 +847,79 @@
   return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new);
 }
 
+// If the scalar components of the numerator and denominator share a common
+// factor (their GCD), divide it out of both sides, e.g.
+// (8 * x) / (4 * y) => (2 * x) / y. Returns nullptr if no factorization is
+// possible.
+const Expr* factorizeDivision(const Expr* lhs_new, const Expr* rhs_new) {
+  if (!lhs_new || !rhs_new) {
+    return nullptr;
+  }
+
+  const Expr* leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
+  const Expr* rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
+
+  auto* lhsTerm = dynamic_cast<const Term*>(lhs_new);
+  auto* rhsTerm = dynamic_cast<const Term*>(rhs_new);
+  if (lhsTerm) {
+    leftScalar = lhsTerm->scalar();
+  }
+
+  if (rhsTerm) {
+    rightScalar = rhsTerm->scalar();
+  }
+
+  if (!leftScalar || !rightScalar) {
+    return nullptr;
+  }
+
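+  // Only integer immediates reach this point: mutate(Div) below returns early
+  // for floating-point operands before calling factorizeDivision.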
+  long left = immediateAs<long>(leftScalar);
+  long right = immediateAs<long>(rightScalar);
+
+  long GCD = gcd<long>(left, right);
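+  // A GCD of 1 or less (0 when both scalars are 0, negative when the scalars
+  // are negative) means there is no useful common factor to divide out.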
+  if (GCD <= 1) {
+    return nullptr;
+  }
+
+  leftScalar = evaluateOp(
+      new Div(leftScalar, getImmediateByType(leftScalar->dtype(), GCD)));
+  rightScalar = evaluateOp(
+      new Div(rightScalar, getImmediateByType(rightScalar->dtype(), GCD)));
+
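+  // Rebuild each side with its reduced scalar, keeping the variable part.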
+  if (lhsTerm) {
+    lhs_new = new Term(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
+  } else {
+    lhs_new = leftScalar;
+  }
+
+  if (rhsTerm) {
+    rhs_new = new Term(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
+  } else {
+    rhs_new = rightScalar;
+  }
+
+  return new Div(lhs_new, rhs_new);
+}
+
+const Expr* PolynomialTransformer::mutate(const Div* v) {
+  const Expr* lhs_new = v->lhs()->accept_mutator(this);
+  const Expr* rhs_new = v->rhs()->accept_mutator(this);
+
+  // Constant Folding.
+  if (lhs_new->isConstant() && rhs_new->isConstant()) {
+    return evaluateOp(new Div(lhs_new, rhs_new));
+  }
+
+  // If this is a floating-point Div then the order of operations matters;
+  // we don't want to combine ops.
+  if (lhs_new->dtype().is_floating_point() ||
+      rhs_new->dtype().is_floating_point()) {
+    return new Div(lhs_new, rhs_new);
+  }
+
+  if (auto ret = factorizeDivision(lhs_new, rhs_new)) {
+    return ret;
+  }
+
+  return new Div(lhs_new, rhs_new);
+}
+
 const Expr* PolynomialTransformer::mutate(const Intrinsics* v) {
   std::vector<const Expr*> new_params;
   bool changed = false;
@@ -955,15 +1044,6 @@
   return lastNode;
 }
 
-// Simple recursive GCD.
-template <typename T>
-T gcd(T a, T b) {
-  if (b == 0) {
-    return a;
-  }
-  return gcd(b, a % b);
-}
-
 // Returns an immediate containing the greatest common divisor of all terms
 // (inc. the scalar term) in the polynomial. If the GCD is uninteresting
 // (e.g. 1) then returns nullptr.
diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h
index 7d20a5d..5f3b8fc 100644
--- a/torch/csrc/jit/tensorexpr/ir_simplifier.h
+++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h
@@ -283,10 +283,7 @@
   // Merge and simplify multiplication.
   const Expr* mutate(const Mul* v) override;
 
-  const Expr* mutate(const Div* v) override {
-    // TODO div simplification will require a rational node.
-    return mutateBinaryOp(v, this);
-  }
+  const Expr* mutate(const Div* v) override;
 
   const Expr* mutate(const Mod* v) override {
     return mutateBinaryOp(v, this);