[MLIR] Value types for AffineXXXExpr

This CL makes AffineExprRef into a value type.

Notably:
1. drops llvm isa, cast, dyn_cast on pointer type and uses member functions on
the value type. It may be possible to still use classof  (in a followup CL)
2. AffineBaseExprRef aggressively casts constness away: if we mean the type is
immutable then let's jump in with both feet;
3. Drop implicit casts to the underlying pointer type because that always
results in surprising behavior and is not needed in practice once enough
cleanup has been applied.

The remaining negative I see is that we still need to mix operator. and
operator->. There is an ugly solution that forwards the methods but that ends
up duplicating the class hierarchy which I tried to avoid as much as
possible. But maybe it's not that bad anymore since AffineExpr.h would still
contain a single class hierarchy (the duplication would be impl detail in.cpp)

PiperOrigin-RevId: 216188003
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index e494abd..34a6fae 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -26,6 +26,121 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/Support/Casting.h"
+#include <type_traits>
+
+namespace mlir {
+
+class AffineExpr;
+class AffineBinaryOpExpr;
+class AffineDimExpr;
+class AffineSymbolExpr;
+class AffineConstantExpr;
+
+/// Helper structure to build AffineExpr with intuitive operators in order to
+/// operate on chainable, lightweight value types instead of pointer types.
+/// This structure operates on immutable types so it freely casts constness
+/// away.
+/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
+/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
+/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
+/// TODO(ntv): Rename
+/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
+/// TODO(ntv): remove const comment
+/// TODO(ntv): pointer pair
+template <typename AffineExprType> class AffineExprBaseRef {
+public:
+  typedef AffineExprBaseRef TemplateType;
+  typedef AffineExprType ImplType;
+
+  AffineExprBaseRef() : expr(nullptr) {}
+  /* implicit */ AffineExprBaseRef(const AffineExprType *expr)
+      : expr(const_cast<AffineExprType *>(expr)) {}
+
+  AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr) {}
+  AffineExprBaseRef &operator=(AffineExprBaseRef other) {
+    expr = other.expr;
+    return *this;
+  }
+
+  bool operator==(AffineExprBaseRef other) const { return expr == other.expr; }
+
+  AffineExprType *operator->() const { return expr; }
+
+  /* implicit */ operator AffineExprBaseRef<AffineExpr>() const {
+    return const_cast<AffineExpr *>(static_cast<const AffineExpr *>(expr));
+  }
+  explicit operator bool() const { return expr; }
+
+  bool empty() const { return expr == nullptr; }
+  bool operator!() const { return expr == nullptr; }
+
+  template <typename U> bool isa() const {
+    using PtrType = typename U::ImplType;
+    return llvm::isa<PtrType>(const_cast<AffineExprType *>(this->expr));
+  }
+  template <typename U> U dyn_cast() const {
+    using PtrType = typename U::ImplType;
+    return U(llvm::dyn_cast<PtrType>(const_cast<AffineExprType *>(this->expr)));
+  }
+  template <typename U> U cast() const {
+    using PtrType = typename U::ImplType;
+    return U(llvm::cast<PtrType>(const_cast<AffineExprType *>(this->expr)));
+  }
+
+  AffineExprBaseRef operator+(int64_t v) const;
+  AffineExprBaseRef operator+(AffineExprBaseRef other) const;
+  AffineExprBaseRef operator-() const;
+  AffineExprBaseRef operator-(int64_t v) const;
+  AffineExprBaseRef operator-(AffineExprBaseRef other) const;
+  AffineExprBaseRef operator*(int64_t v) const;
+  AffineExprBaseRef operator*(AffineExprBaseRef other) const;
+  AffineExprBaseRef floorDiv(uint64_t v) const;
+  AffineExprBaseRef floorDiv(AffineExprBaseRef other) const;
+  AffineExprBaseRef ceilDiv(uint64_t v) const;
+  AffineExprBaseRef ceilDiv(AffineExprBaseRef other) const;
+  AffineExprBaseRef operator%(uint64_t v) const;
+  AffineExprBaseRef operator%(AffineExprBaseRef other) const;
+
+  friend ::llvm::hash_code hash_value(AffineExprBaseRef arg);
+
+private:
+  AffineExprType *expr;
+};
+
+using AffineExprRef = AffineExprBaseRef<AffineExpr>;
+using AffineBinaryOpExprRef = AffineExprBaseRef<AffineBinaryOpExpr>;
+using AffineDimExprRef = AffineExprBaseRef<AffineDimExpr>;
+using AffineSymbolExprRef = AffineExprBaseRef<AffineSymbolExpr>;
+using AffineConstantExprRef = AffineExprBaseRef<AffineConstantExpr>;
+
+// Make AffineExprRef hashable.
+inline ::llvm::hash_code hash_value(AffineExprRef arg) {
+  return ::llvm::hash_value(static_cast<AffineExpr *>(arg.expr));
+}
+
+} // namespace mlir
+
+namespace llvm {
+
+// AffineExprRef hash just like pointers
+template <> struct DenseMapInfo<mlir::AffineExprRef> {
+  static mlir::AffineExprRef getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
+    return mlir::AffineExprRef(pointer);
+  }
+  static mlir::AffineExprRef getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
+    return mlir::AffineExprRef(pointer);
+  }
+  static unsigned getHashValue(mlir::AffineExprRef val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
+    return LHS == RHS;
+  }
+};
+
+} // namespace llvm
 
 namespace mlir {
 
@@ -99,93 +214,6 @@
   return os;
 }
 
-/// Helper structure to build AffineExpr with intuitive operators in order to
-/// operate on chainable, lightweight value types instead of pointer types.
-/// This structure operates on immutable types so it freely casts constness
-/// away.
-/// TODO(ntv): Remove all redundant MLIRContext* arguments through the API
-/// TODO(ntv): Remove all uses of AffineExpr* in Parser.cpp
-/// TODO(ntv): Add extra out-of-class operators for int op AffineExprBaseRef
-/// TODO(ntv): Rename
-/// TODO(ntv): Drop const everywhere it makes sense in AffineExpr
-/// TODO(ntv): remove const comment
-/// TODO(ntv): pointer pair
-template <typename AffineExprType> class AffineExprBaseRef {
-public:
-  /* implicit */ AffineExprBaseRef(AffineExprType *expr) : expr(expr) {}
-
-  AffineExprBaseRef(const AffineExprBaseRef &other) : expr(other.expr){};
-  AffineExprBaseRef &operator=(AffineExprBaseRef other) {
-    expr = other;
-    return *this;
-  };
-  bool operator==(AffineExprBaseRef other) const { return expr == other.expr; };
-  AffineExprType *operator->() { return expr; }
-  /* implicit */ operator AffineExprType *() { return expr; }
-
-  bool operator!() { return expr == nullptr; }
-
-  AffineExprBaseRef operator+(int64_t v) const;
-  AffineExprBaseRef operator+(AffineExprBaseRef other) const;
-  AffineExprBaseRef operator-() const;
-  AffineExprBaseRef operator-(int64_t v) const;
-  AffineExprBaseRef operator-(AffineExprBaseRef other) const;
-  AffineExprBaseRef operator*(int64_t v) const;
-  AffineExprBaseRef operator*(AffineExprBaseRef other) const;
-  AffineExprBaseRef floorDiv(uint64_t v) const;
-  AffineExprBaseRef floorDiv(AffineExprBaseRef other) const;
-  AffineExprBaseRef ceilDiv(uint64_t v) const;
-  AffineExprBaseRef ceilDiv(AffineExprBaseRef other) const;
-  AffineExprBaseRef operator%(uint64_t v) const;
-  AffineExprBaseRef operator%(AffineExprBaseRef other) const;
-
-private:
-  AffineExprType *expr;
-};
-
-using AffineExprRef = AffineExprBaseRef<AffineExpr>;
-
-inline ::llvm::hash_code hash_value(AffineExprRef arg);
-} // namespace mlir
-
-namespace llvm {
-
-/// This helper structure allows classof/isa/cast/dyn_cast to operate on
-/// AffineExprBaseRef<T>.
-template <typename T> struct simplify_type<mlir::AffineExprBaseRef<T>> {
-  using SimpleType = T *;
-  static SimpleType getSimplifiedValue(mlir::AffineExprBaseRef<T> &input) {
-    return input;
-  }
-};
-
-// AffineExprRef hash just like pointers
-template <> struct DenseMapInfo<mlir::AffineExprRef> {
-  static mlir::AffineExprRef getEmptyKey() {
-    auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getEmptyKey();
-    return mlir::AffineExprRef(pointer);
-  }
-  static mlir::AffineExprRef getTombstoneKey() {
-    auto pointer = llvm::DenseMapInfo<mlir::AffineExpr *>::getTombstoneKey();
-    return mlir::AffineExprRef(pointer);
-  }
-  static unsigned getHashValue(mlir::AffineExprRef val) {
-    return mlir::hash_value(val);
-  }
-  static bool isEqual(mlir::AffineExprRef LHS, mlir::AffineExprRef RHS) {
-    return LHS == RHS;
-  }
-};
-
-} // namespace llvm
-
-namespace mlir {
-
-// Make AffineExprRef hashable.
-inline ::llvm::hash_code hash_value(AffineExprRef arg) {
-  return ::llvm::hash_value(static_cast<AffineExpr *>(arg));
-}
-
 /// Affine binary operation expression. An affine binary operation could be an
 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
 /// represented through a multiply by -1 and add.) These expressions are always
diff --git a/include/mlir/IR/AffineExprVisitor.h b/include/mlir/IR/AffineExprVisitor.h
index a4dcb10..7dff48a 100644
--- a/include/mlir/IR/AffineExprVisitor.h
+++ b/include/mlir/IR/AffineExprVisitor.h
@@ -46,7 +46,7 @@
 ///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
 ///    unsigned numDimExprs;
 ///    DimExprCounter() : numDimExprs(0) {}
-///    void visitAffineDimExpr(AffineDimExpr *expr) { ++numDimExprs; }
+///    void visitAffineDimExpr(AffineDimExprRef expr) { ++numDimExprs; }
 ///  };
 ///
 ///  And this class would be used like this:
@@ -83,39 +83,39 @@
                   "Must instantiate with a derived type of AffineExprVisitor");
     switch (expr->getKind()) {
     case AffineExpr::Kind::Add: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       walkOperandsPostOrder(binOpExpr);
       return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mul: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       walkOperandsPostOrder(binOpExpr);
       return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mod: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       walkOperandsPostOrder(binOpExpr);
       return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
     }
     case AffineExpr::Kind::FloorDiv: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       walkOperandsPostOrder(binOpExpr);
       return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::CeilDiv: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       walkOperandsPostOrder(binOpExpr);
       return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::Constant:
       return static_cast<SubClass *>(this)->visitConstantExpr(
-          cast<AffineConstantExpr>(expr));
+          expr.cast<AffineConstantExprRef>());
     case AffineExpr::Kind::DimId:
       return static_cast<SubClass *>(this)->visitDimExpr(
-          cast<AffineDimExpr>(expr));
+          expr.cast<AffineDimExprRef>());
     case AffineExpr::Kind::SymbolId:
       return static_cast<SubClass *>(this)->visitSymbolExpr(
-          cast<AffineSymbolExpr>(expr));
+          expr.cast<AffineSymbolExprRef>());
     }
   }
 
@@ -125,34 +125,34 @@
                   "Must instantiate with a derived type of AffineExprVisitor");
     switch (expr->getKind()) {
     case AffineExpr::Kind::Add: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mul: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
     }
     case AffineExpr::Kind::Mod: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
     }
     case AffineExpr::Kind::FloorDiv: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::CeilDiv: {
-      auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+      auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
       return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
     }
     case AffineExpr::Kind::Constant:
       return static_cast<SubClass *>(this)->visitConstantExpr(
-          cast<AffineConstantExpr>(expr));
+          expr.cast<AffineConstantExprRef>());
     case AffineExpr::Kind::DimId:
       return static_cast<SubClass *>(this)->visitDimExpr(
-          cast<AffineDimExpr>(expr));
+          expr.cast<AffineDimExprRef>());
     case AffineExpr::Kind::SymbolId:
       return static_cast<SubClass *>(this)->visitSymbolExpr(
-          cast<AffineSymbolExpr>(expr));
+          expr.cast<AffineSymbolExprRef>());
     }
   }
 
@@ -166,29 +166,29 @@
 
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
-  void visitAddExpr(AffineBinaryOpExpr *expr) {
+  void visitAffineBinaryOpExpr(AffineBinaryOpExprRef expr) {}
+  void visitAddExpr(AffineBinaryOpExprRef expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitMulExpr(AffineBinaryOpExpr *expr) {
+  void visitMulExpr(AffineBinaryOpExprRef expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitModExpr(AffineBinaryOpExpr *expr) {
+  void visitModExpr(AffineBinaryOpExprRef expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+  void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+  void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitConstantExpr(AffineConstantExpr *expr) {}
-  void visitAffineDimExpr(AffineDimExpr *expr) {}
-  void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
+  void visitConstantExpr(AffineConstantExprRef expr) {}
+  void visitAffineDimExpr(AffineDimExprRef expr) {}
+  void visitAffineSymbolExpr(AffineSymbolExprRef expr) {}
 
 private:
   // Walk the operands - each operand is itself walked in post order.
-  void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
+  void walkOperandsPostOrder(AffineBinaryOpExprRef expr) {
     walkPostOrder(expr->getLHS());
     walkPostOrder(expr->getRHS());
   }
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index ebcc50a..2f58500 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -139,10 +139,10 @@
     operandExprStack.reserve(8);
   }
 
-  void visitMulExpr(AffineBinaryOpExpr *expr) {
+  void visitMulExpr(AffineBinaryOpExprRef expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(isa<AffineConstantExpr>(expr->getRHS()));
+    assert(expr->getRHS().isa<AffineConstantExprRef>());
     // Get the RHS constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
@@ -153,7 +153,7 @@
     }
   }
 
-  void visitAddExpr(AffineBinaryOpExpr *expr) {
+  void visitAddExpr(AffineBinaryOpExprRef expr) {
     assert(operandExprStack.size() >= 2);
     const auto &rhs = operandExprStack.back();
     auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -166,10 +166,10 @@
     operandExprStack.pop_back();
   }
 
-  void visitModExpr(AffineBinaryOpExpr *expr) {
+  void visitModExpr(AffineBinaryOpExprRef expr) {
     assert(operandExprStack.size() >= 2);
     // This is a pure affine expr; the RHS will be a constant.
-    assert(isa<AffineConstantExpr>(expr->getRHS()));
+    assert(expr->getRHS().isa<AffineConstantExprRef>());
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     operandExprStack.pop_back();
     auto &lhs = operandExprStack.back();
@@ -195,32 +195,32 @@
         AffineConstantExpr::get(rhsConst, context), context));
     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+  void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
     visitDivExpr(expr, /*isCeil=*/true);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+  void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
     visitDivExpr(expr, /*isCeil=*/false);
   }
-  void visitDimExpr(AffineDimExpr *expr) {
+  void visitDimExpr(AffineDimExprRef expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getDimStartIndex() + expr->getPosition()] = 1;
   }
-  void visitSymbolExpr(AffineSymbolExpr *expr) {
+  void visitSymbolExpr(AffineSymbolExprRef expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getSymbolStartIndex() + expr->getPosition()] = 1;
   }
-  void visitConstantExpr(AffineConstantExpr *expr) {
+  void visitConstantExpr(AffineConstantExprRef expr) {
     operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
     auto &eq = operandExprStack.back();
     eq[getConstantIndex()] = expr->getValue();
   }
 
 private:
-  void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
+  void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) {
     assert(operandExprStack.size() >= 2);
-    assert(isa<AffineConstantExpr>(expr->getRHS()));
+    assert(expr->getRHS().isa<AffineConstantExprRef>());
     // This is a pure affine expr; the RHS is a positive constant.
     auto rhsConst = operandExprStack.back()[getConstantIndex()];
     // TODO(bondhugula): handle division by zero at the same time the issue is
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 772ec85..4d72808 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,8 +38,7 @@
     unsigned j = 0;
     AffineBoundExprList::const_iterator it, e;
     for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
-      if (auto *cExpr = const_cast<AffineConstantExpr *>(
-              dyn_cast<AffineConstantExpr>(*it))) {
+      if (auto cExpr = it->dyn_cast<AffineConstantExprRef>()) {
         if (val == None) {
           val = cExpr->getValue();
           *idx = j;
@@ -69,7 +68,7 @@
     }
     if (it == lhsList.end()) {
       // There can only be one constant affine expr in this bound list.
-      if (auto cExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      if (auto cExpr = expr.dyn_cast<AffineConstantExprRef>()) {
         unsigned idx;
         if (lb) {
           auto cb = getReducedConstBound(
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index babe95f..0b50494 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -61,7 +61,7 @@
     auto loopSpanExpr = simplifyAffineExpr(
         ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
         std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
-    auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
+    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
     if (!cExpr)
       return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
     loopSpan = cExpr->getValue();
@@ -81,7 +81,10 @@
 llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
   auto tripCountExpr = getTripCountExpr(forStmt);
 
-  if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
+  if (!tripCountExpr)
+    return None;
+
+  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>())
     return constExpr->getValue();
 
   return None;
@@ -96,7 +99,7 @@
   if (!tripCountExpr)
     return 1;
 
-  if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>()) {
     uint64_t tripCount = constExpr->getValue();
 
     // 0 iteration loops (greatest divisor is 2^64 - 1).
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 1ddf697..8f00fb1 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -27,10 +27,10 @@
   // We verify affine op expr forms at construction time.
   switch (kind) {
   case Kind::Add:
-    assert(!isa<AffineConstantExpr>(lhs));
+    assert(!lhs.isa<AffineConstantExprRef>());
     break;
   case Kind::Mul:
-    assert(!isa<AffineConstantExpr>(lhs));
+    assert(!lhs.isa<AffineConstantExprRef>());
     assert(rhs->isSymbolicOrConstant());
     break;
   case Kind::FloorDiv:
@@ -124,15 +124,15 @@
     // possible, allowing this to merge into the next case.
     auto *op = cast<AffineBinaryOpExpr>(this);
     return op->getLHS()->isPureAffine() && op->getRHS()->isPureAffine() &&
-           (isa<AffineConstantExpr>(op->getLHS()) ||
-            isa<AffineConstantExpr>(op->getRHS()));
+           (op->getLHS().isa<AffineConstantExprRef>() ||
+            op->getRHS().isa<AffineConstantExprRef>());
   }
   case Kind::FloorDiv:
   case Kind::CeilDiv:
   case Kind::Mod: {
     auto *op = cast<AffineBinaryOpExpr>(this);
     return op->getLHS()->isPureAffine() &&
-           isa<AffineConstantExpr>(op->getRHS());
+           op->getRHS().isa<AffineConstantExprRef>();
   }
   }
 }
@@ -214,7 +214,7 @@
 }
 // Unary minus, delegate to operator*.
 template <> AffineExprRef AffineExprRef::operator-() const {
-  return *this * (-1);
+  return AffineBinaryOpExpr::getMul(expr, -1, expr->getContext());
 }
 // Delegate to operator+.
 template <> AffineExprRef AffineExprRef::operator-(int64_t v) const {
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index f98df70..4643183 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -55,14 +55,15 @@
       return constantFoldBinExpr(
           expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
     case AffineExpr::Kind::Constant:
-      return IntegerAttr::get(cast<AffineConstantExpr>(expr)->getValue(),
+      return IntegerAttr::get(expr.cast<AffineConstantExprRef>()->getValue(),
                               expr->getContext());
     case AffineExpr::Kind::DimId:
       return dyn_cast_or_null<IntegerAttr>(
-          operandConsts[cast<AffineDimExpr>(expr)->getPosition()]);
+          operandConsts[expr.cast<AffineDimExprRef>()->getPosition()]);
     case AffineExpr::Kind::SymbolId:
       return dyn_cast_or_null<IntegerAttr>(
-          operandConsts[numDims + cast<AffineSymbolExpr>(expr)->getPosition()]);
+          operandConsts[numDims +
+                        expr.cast<AffineSymbolExprRef>()->getPosition()]);
     }
   }
 
@@ -70,7 +71,7 @@
   IntegerAttr *
   constantFoldBinExpr(AffineExprRef expr,
                       std::function<uint64_t(int64_t, uint64_t)> op) {
-    auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+    auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
     auto *lhs = constantFold(binOpExpr->getLHS());
     auto *rhs = constantFold(binOpExpr->getRHS());
     if (!lhs || !rhs)
@@ -104,8 +105,7 @@
     return false;
   ArrayRef<AffineExprRef> results = getResults();
   for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
-    auto *expr =
-        const_cast<AffineDimExpr *>(dyn_cast<AffineDimExpr>(results[i]));
+    auto expr = results[i].dyn_cast<AffineDimExprRef>();
     if (!expr || expr->getPosition() != i)
       return false;
   }
@@ -113,14 +113,12 @@
 }
 
 bool AffineMap::isSingleConstant() {
-  return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
+  return getNumResults() == 1 && getResult(0).isa<AffineConstantExprRef>();
 }
 
 int64_t AffineMap::getSingleConstantResult() {
   assert(isSingleConstant() && "map must have a single constant result");
-  return const_cast<AffineConstantExpr *>(
-             cast<AffineConstantExpr>(getResult(0)))
-      ->getValue();
+  return getResult(0).cast<AffineConstantExprRef>()->getValue();
 }
 
 AffineExprRef AffineMap::getResult(unsigned idx) { return results[idx]; }
@@ -129,8 +127,8 @@
 AffineExprRef AffineBinaryOpExpr::simplifyAdd(AffineExprRef lhs,
                                               AffineExprRef rhs,
                                               MLIRContext *context) {
-  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
-  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+  auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
 
   // Fold if both LHS, RHS are a constant.
   if (lhsConst && rhsConst)
@@ -139,7 +137,7 @@
 
   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
   // If only one of them is a symbolic expressions, make it the RHS.
-  if (isa<AffineConstantExpr>(lhs) ||
+  if (lhs.isa<AffineConstantExprRef>() ||
       (lhs->isSymbolicOrConstant() && !rhs->isSymbolicOrConstant())) {
     return AffineBinaryOpExpr::getAdd(rhs, lhs, context);
   }
@@ -152,19 +150,16 @@
       return lhs;
   }
   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
-  auto *lBin =
-      const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
   if (lBin && rhsConst && lBin->getKind() == Kind::Add) {
-    if (auto *lrhs = const_cast<AffineConstantExpr *>(
-            dyn_cast<AffineConstantExpr>(lBin->getRHS())))
+    if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
       return lBin->getLHS() + (lrhs->getValue() + rhsConst->getValue());
   }
 
   // When doing successive additions, bring constant to the right: turn (d0 + 2)
   // + d1 into (d0 + d1) + 2.
   if (lBin && lBin->getKind() == Kind::Add) {
-    if (auto *lrhs = const_cast<AffineConstantExpr *>(
-            dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
+    if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
       return lBin->getLHS() + rhs + lrhs;
     }
   }
@@ -176,8 +171,8 @@
 AffineExprRef AffineBinaryOpExpr::simplifyMul(AffineExprRef lhs,
                                               AffineExprRef rhs,
                                               MLIRContext *context) {
-  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
-  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+  auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
 
   if (lhsConst && rhsConst)
     return AffineConstantExpr::get(lhsConst->getValue() * rhsConst->getValue(),
@@ -188,7 +183,7 @@
   // Canonicalize the mul expression so that the constant/symbolic term is the
   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
   // constant. (Note that a constant is trivially symbolic).
-  if (!rhs->isSymbolicOrConstant() || isa<AffineConstantExpr>(lhs)) {
+  if (!rhs->isSymbolicOrConstant() || lhs.isa<AffineConstantExprRef>()) {
     // At least one of them has to be symbolic.
     return AffineBinaryOpExpr::getMul(rhs, lhs, context);
   }
@@ -205,19 +200,16 @@
   }
 
   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
-  auto *lBin =
-      const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
   if (lBin && rhsConst && lBin->getKind() == Kind::Mul) {
-    if (auto *lrhs = const_cast<AffineConstantExpr *>(
-            dyn_cast<AffineConstantExpr>(lBin->getRHS())))
+    if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>())
       return lBin->getLHS() * (lrhs->getValue() * rhsConst->getValue());
   }
 
   // When doing successive multiplication, bring constant to the right: turn (d0
   // * 2) * d1 into (d0 * d1) * 2.
   if (lBin && lBin->getKind() == Kind::Mul) {
-    if (auto *lrhs = const_cast<AffineConstantExpr *>(
-            dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
+    if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
       return (lBin->getLHS() * rhs) * lrhs;
     }
   }
@@ -228,8 +220,8 @@
 AffineExprRef AffineBinaryOpExpr::simplifyFloorDiv(AffineExprRef lhs,
                                                    AffineExprRef rhs,
                                                    MLIRContext *context) {
-  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
-  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+  auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
 
   if (lhsConst && rhsConst)
     return AffineConstantExpr::get(
@@ -241,11 +233,9 @@
     if (rhsConst->getValue() == 1)
       return lhs;
 
-    auto *lBin =
-        const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
+    auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
     if (lBin && lBin->getKind() == Kind::Mul) {
-      if (auto *lrhs = const_cast<AffineConstantExpr *>(
-              dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
+      if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
         // rhsConst is known to be positive if a constant.
         if (lrhs->getValue() % rhsConst->getValue() == 0)
           return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
@@ -259,8 +249,8 @@
 AffineExprRef AffineBinaryOpExpr::simplifyCeilDiv(AffineExprRef lhs,
                                                   AffineExprRef rhs,
                                                   MLIRContext *context) {
-  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
-  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+  auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
 
   if (lhsConst && rhsConst)
     return AffineConstantExpr::get(
@@ -272,11 +262,9 @@
     if (rhsConst->getValue() == 1)
       return lhs;
 
-    auto *lBin =
-        const_cast<AffineBinaryOpExpr *>(dyn_cast<AffineBinaryOpExpr>(lhs));
+    auto lBin = lhs.dyn_cast<AffineBinaryOpExprRef>();
     if (lBin && lBin->getKind() == Kind::Mul) {
-      if (auto *lrhs = const_cast<AffineConstantExpr *>(
-              dyn_cast<AffineConstantExpr>(lBin->getRHS()))) {
+      if (auto lrhs = lBin->getRHS().dyn_cast<AffineConstantExprRef>()) {
         // rhsConst is known to be positive if a constant.
         if (lrhs->getValue() % rhsConst->getValue() == 0)
           return lBin->getLHS() * (lrhs->getValue() / rhsConst->getValue());
@@ -290,8 +278,8 @@
 AffineExprRef AffineBinaryOpExpr::simplifyMod(AffineExprRef lhs,
                                               AffineExprRef rhs,
                                               MLIRContext *context) {
-  auto *lhsConst = dyn_cast<AffineConstantExpr>(lhs);
-  auto *rhsConst = dyn_cast<AffineConstantExpr>(rhs);
+  auto lhsConst = lhs.dyn_cast<AffineConstantExprRef>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExprRef>();
 
   if (lhsConst && rhsConst)
     return AffineConstantExpr::get(
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index aa55cea..ed87d7f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -107,8 +107,8 @@
     // Check if the affine map is single dim id or single symbol identity -
     // (i)->(i) or ()[s]->(i)
     return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 &&
-           (isa<AffineDimExpr>(boundMap->getResult(0)) ||
-            isa<AffineSymbolExpr>(boundMap->getResult(0)));
+           (boundMap->getResult(0).isa<AffineDimExprRef>() ||
+            boundMap->getResult(0).isa<AffineSymbolExprRef>());
   }
 
   // Visit functions.
@@ -579,13 +579,13 @@
   const char *binopSpelling = nullptr;
   switch (expr->getKind()) {
   case AffineExpr::Kind::SymbolId:
-    os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
+    os << 's' << expr.cast<AffineSymbolExprRef>()->getPosition();
     return;
   case AffineExpr::Kind::DimId:
-    os << 'd' << cast<AffineDimExpr>(expr)->getPosition();
+    os << 'd' << expr.cast<AffineDimExprRef>()->getPosition();
     return;
   case AffineExpr::Kind::Constant:
-    os << cast<AffineConstantExpr>(expr)->getValue();
+    os << expr.cast<AffineConstantExprRef>()->getValue();
     return;
   case AffineExpr::Kind::Add:
     binopSpelling = " + ";
@@ -604,7 +604,7 @@
     break;
   }
 
-  auto *binOp = cast<AffineBinaryOpExpr>(expr);
+  auto binOp = expr.cast<AffineBinaryOpExprRef>();
 
   // Handle tightly binding binary operators.
   if (binOp->getKind() != AffineExpr::Kind::Add) {
@@ -627,10 +627,10 @@
   // Pretty print addition to a product that has a negative operand as a
   // subtraction.
   AffineExprRef rhsExpr = binOp->getRHS();
-  if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
+  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExprRef>()) {
     if (rhs->getKind() == AffineExpr::Kind::Mul) {
       AffineExprRef rrhsExpr = rhs->getRHS();
-      if (auto *rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
+      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExprRef>()) {
         if (rrhs->getValue() == -1) {
           printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
           os << " - ";
@@ -655,7 +655,7 @@
   }
 
   // Pretty print addition to a negative number as a subtraction.
-  if (auto *rhs = dyn_cast<AffineConstantExpr>(rhsExpr)) {
+  if (auto rhs = rhsExpr.dyn_cast<AffineConstantExprRef>()) {
     if (rhs->getValue() < 0) {
       printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
       os << " - " << -rhs->getValue();
@@ -1435,7 +1435,7 @@
 
     // Print constant bound.
     if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
-      if (auto *constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      if (auto constExpr = expr.dyn_cast<AffineConstantExprRef>()) {
         os << constExpr->getValue();
         return;
       }
@@ -1444,7 +1444,7 @@
     // Print bound that consists of a single SSA symbol if the map is over a
     // single symbol.
     if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
-      if (auto *symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
+      if (auto symExpr = expr.dyn_cast<AffineSymbolExprRef>()) {
         printOperand(bound.getOperand(0));
         return;
       }
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7affe7e..9ef2eed 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -831,37 +831,38 @@
   // Identifier lists for polyhedral structures.
   ParseResult parseDimIdList(unsigned &numDims);
   ParseResult parseSymbolIdList(unsigned &numSymbols);
-  ParseResult parseIdentifierDefinition(AffineExpr *idExpr);
+  ParseResult parseIdentifierDefinition(AffineExprRef idExpr);
 
-  AffineExpr *parseAffineExpr();
-  AffineExpr *parseParentheticalExpr();
-  AffineExpr *parseNegateExpression(AffineExpr *lhs);
-  AffineExpr *parseIntegerExpr();
-  AffineExpr *parseBareIdExpr();
+  AffineExprRef parseAffineExpr();
+  AffineExprRef parseParentheticalExpr();
+  AffineExprRef parseNegateExpression(AffineExprRef lhs);
+  AffineExprRef parseIntegerExpr();
+  AffineExprRef parseBareIdExpr();
 
-  AffineExpr *getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExpr *lhs,
-                                    AffineExpr *rhs, SMLoc opLoc);
-  AffineExpr *getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExpr *lhs,
-                                    AffineExpr *rhs);
-  AffineExpr *parseAffineOperandExpr(AffineExpr *lhs);
-  AffineExpr *parseAffineLowPrecOpExpr(AffineExpr *llhs,
-                                       AffineLowPrecOp llhsOp);
-  AffineExpr *parseAffineHighPrecOpExpr(AffineExpr *llhs,
-                                        AffineHighPrecOp llhsOp,
-                                        SMLoc llhsOpLoc);
-  AffineExpr *parseAffineConstraint(bool *isEq);
+  AffineExprRef getBinaryAffineOpExpr(AffineHighPrecOp op, AffineExprRef lhs,
+                                      AffineExprRef rhs, SMLoc opLoc);
+  AffineExprRef getBinaryAffineOpExpr(AffineLowPrecOp op, AffineExprRef lhs,
+                                      AffineExprRef rhs);
+  AffineExprRef parseAffineOperandExpr(AffineExprRef lhs);
+  AffineExprRef parseAffineLowPrecOpExpr(AffineExprRef llhs,
+                                         AffineLowPrecOp llhsOp);
+  AffineExprRef parseAffineHighPrecOpExpr(AffineExprRef llhs,
+                                          AffineHighPrecOp llhsOp,
+                                          SMLoc llhsOpLoc);
+  AffineExprRef parseAffineConstraint(bool *isEq);
 
 private:
-  SmallVector<std::pair<StringRef, AffineExpr *>, 4> dimsAndSymbols;
+  SmallVector<std::pair<StringRef, AffineExprRef>, 4> dimsAndSymbols;
 };
 } // end anonymous namespace
 
 /// Create an affine binary high precedence op expression (mul's, div's, mod).
 /// opLoc is the location of the op token to be used to report errors
 /// for non-conforming expressions.
-AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
-                                                AffineExpr *lhs,
-                                                AffineExpr *rhs, SMLoc opLoc) {
+AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
+                                                  AffineExprRef lhs,
+                                                  AffineExprRef rhs,
+                                                  SMLoc opLoc) {
   // TODO: make the error location info accurate.
   switch (op) {
   case Mul:
@@ -899,9 +900,9 @@
 }
 
 /// Create an affine binary low precedence op expression (add, sub).
-AffineExpr *AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
-                                                AffineExpr *lhs,
-                                                AffineExpr *rhs) {
+AffineExprRef AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
+                                                  AffineExprRef lhs,
+                                                  AffineExprRef rhs) {
   switch (op) {
   case AffineLowPrecOp::Add:
     return builder.getAddExpr(lhs, rhs);
@@ -959,10 +960,10 @@
 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
 /// report an error for non-conforming expressions.
-AffineExpr *AffineParser::parseAffineHighPrecOpExpr(AffineExpr *llhs,
-                                                    AffineHighPrecOp llhsOp,
-                                                    SMLoc llhsOpLoc) {
-  AffineExpr *lhs = parseAffineOperandExpr(llhs);
+AffineExprRef AffineParser::parseAffineHighPrecOpExpr(AffineExprRef llhs,
+                                                      AffineHighPrecOp llhsOp,
+                                                      SMLoc llhsOpLoc) {
+  AffineExprRef lhs = parseAffineOperandExpr(llhs);
   if (!lhs)
     return nullptr;
 
@@ -970,7 +971,7 @@
   auto opLoc = getToken().getLoc();
   if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
     if (llhs) {
-      AffineExpr *expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc);
+      AffineExprRef expr = getBinaryAffineOpExpr(llhsOp, llhs, lhs, opLoc);
       if (!expr)
         return nullptr;
       return parseAffineHighPrecOpExpr(expr, op, opLoc);
@@ -990,13 +991,13 @@
 /// Parse an affine expression inside parentheses.
 ///
 ///   affine-expr ::= `(` affine-expr `)`
-AffineExpr *AffineParser::parseParentheticalExpr() {
+AffineExprRef AffineParser::parseParentheticalExpr() {
   if (parseToken(Token::l_paren, "expected '('"))
     return nullptr;
   if (getToken().is(Token::r_paren))
     return (emitError("no expression inside parentheses"), nullptr);
 
-  auto *expr = parseAffineExpr();
+  auto expr = parseAffineExpr();
   if (!expr)
     return nullptr;
   if (parseToken(Token::r_paren, "expected ')'"))
@@ -1008,11 +1009,11 @@
 /// Parse the negation expression.
 ///
 ///   affine-expr ::= `-` affine-expr
-AffineExpr *AffineParser::parseNegateExpression(AffineExpr *lhs) {
+AffineExprRef AffineParser::parseNegateExpression(AffineExprRef lhs) {
   if (parseToken(Token::minus, "expected '-'"))
     return nullptr;
 
-  AffineExpr *operand = parseAffineOperandExpr(lhs);
+  AffineExprRef operand = parseAffineOperandExpr(lhs);
   // Since negation has the highest precedence of all ops (including high
   // precedence ops) but lower than parentheses, we are only going to use
   // parseAffineOperandExpr instead of parseAffineExpr here.
@@ -1027,7 +1028,7 @@
 /// Parse a bare id that may appear in an affine expression.
 ///
 ///   affine-expr ::= bare-id
-AffineExpr *AffineParser::parseBareIdExpr() {
+AffineExprRef AffineParser::parseBareIdExpr() {
   if (getToken().isNot(Token::bare_identifier))
     return (emitError("expected bare identifier"), nullptr);
 
@@ -1045,7 +1046,7 @@
 /// Parse a positive integral constant appearing in an affine expression.
 ///
 ///   affine-expr ::= integer-literal
-AffineExpr *AffineParser::parseIntegerExpr() {
+AffineExprRef AffineParser::parseIntegerExpr() {
   auto val = getToken().getUInt64IntegerValue();
   if (!val.hasValue() || (int64_t)val.getValue() < 0)
     return (emitError("constant too large for index"), nullptr);
@@ -1063,7 +1064,7 @@
 //  operand expression, it's an op expression and will be parsed via
 //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l
 //  are valid operands that will be parsed by this function.
-AffineExpr *AffineParser::parseAffineOperandExpr(AffineExpr *lhs) {
+AffineExprRef AffineParser::parseAffineOperandExpr(AffineExprRef lhs) {
   switch (getToken().getKind()) {
   case Token::bare_identifier:
     return parseBareIdExpr();
@@ -1113,16 +1114,16 @@
 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3)
 /// will be parsed using parseAffineHighPrecOpExpr().
-AffineExpr *AffineParser::parseAffineLowPrecOpExpr(AffineExpr *llhs,
-                                                   AffineLowPrecOp llhsOp) {
-  AffineExpr *lhs;
+AffineExprRef AffineParser::parseAffineLowPrecOpExpr(AffineExprRef llhs,
+                                                     AffineLowPrecOp llhsOp) {
+  AffineExprRef lhs;
   if (!(lhs = parseAffineOperandExpr(llhs)))
     return nullptr;
 
   // Found an LHS. Deal with the ops.
   if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
     if (llhs) {
-      AffineExpr *sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
+      AffineExprRef sum = getBinaryAffineOpExpr(llhsOp, llhs, lhs);
       return parseAffineLowPrecOpExpr(sum, lOp);
     }
     // No LLHS, get RHS and form the expression.
@@ -1132,13 +1133,13 @@
   if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
     // We have a higher precedence op here. Get the rhs operand for the llhs
     // through parseAffineHighPrecOpExpr.
-    AffineExpr *highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+    AffineExprRef highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
     if (!highRes)
       return nullptr;
 
     // If llhs is null, the product forms the first operand of the yet to be
     // found expression. If non-null, the op to associate with llhs is llhsOp.
-    AffineExpr *expr =
+    AffineExprRef expr =
         llhs ? getBinaryAffineOpExpr(llhsOp, llhs, highRes) : highRes;
 
     // Recurse for subsequent low prec op's after the affine high prec op
@@ -1169,14 +1170,14 @@
 /// Additional conditions are checked depending on the production. For eg., one
 /// of the operands for `*` has to be either constant/symbolic; the second
 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
-AffineExpr *AffineParser::parseAffineExpr() {
+AffineExprRef AffineParser::parseAffineExpr() {
   return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
 }
 
 /// Parse a dim or symbol from the lists appearing before the actual expressions
 /// of the affine map. Update our state to store the dimensional/symbolic
 /// identifier.
-ParseResult AffineParser::parseIdentifierDefinition(AffineExpr *idExpr) {
+ParseResult AffineParser::parseIdentifierDefinition(AffineExprRef idExpr) {
   if (getToken().isNot(Token::bare_identifier))
     return emitError("expected bare identifier");
 
@@ -1240,7 +1241,7 @@
 
   SmallVector<AffineExprRef, 4> exprs;
   auto parseElt = [&]() -> ParseResult {
-    auto *elt = parseAffineExpr();
+    auto elt = parseAffineExpr();
     ParseResult res = elt ? ParseSuccess : ParseFailure;
     exprs.push_back(elt);
     return res;
@@ -1266,7 +1267,7 @@
 
     auto parseRangeSize = [&]() -> ParseResult {
       auto loc = getToken().getLoc();
-      auto *elt = parseAffineExpr();
+      auto elt = parseAffineExpr();
       if (!elt)
         return ParseFailure;
 
@@ -2445,8 +2446,8 @@
 /// isEq is set to true if the parsed constraint is an equality, false if it is
 /// an inequality (greater than or equal).
 ///
-AffineExpr *AffineParser::parseAffineConstraint(bool *isEq) {
-  AffineExpr *expr = parseAffineExpr();
+AffineExprRef AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExprRef expr = parseAffineExpr();
   if (!expr)
     return nullptr;
 
@@ -2504,7 +2505,7 @@
   SmallVector<bool, 4> isEqs;
   auto parseElt = [&]() -> ParseResult {
     bool isEq;
-    auto *elt = parseAffineConstraint(&isEq);
+    auto elt = parseAffineConstraint(&isEq);
     ParseResult res = elt ? ParseSuccess : ParseFailure;
     if (elt) {
       constraints.push_back(elt);
diff --git a/lib/Transforms/Utils.cpp b/lib/Transforms/Utils.cpp
index 008a364..cc1c797 100644
--- a/lib/Transforms/Utils.cpp
+++ b/lib/Transforms/Utils.cpp
@@ -53,6 +53,7 @@
                                     ArrayRef<SSAValue *> extraIndices,
                                     AffineMap *indexRemap) {
   unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
   (void)newMemRefRank;
   if (indexRemap) {