Fix misc bugs / TODOs / other improvements to analysis utils

- fix for getConstantBoundOnDimSize: floordiv -> ceildiv for extent
- make getConstantBoundOnDimSize also return the identifier upper bound
- fix unionBoundingBox to correctly use the divisor and upper bound identified by
  getConstantBoundOnDimSize
- deal with loop step correctly in addAffineForOpDomain (covers most cases now)
- fully compose bound map / operands and simplify/canonicalize before adding
  dim/symbol to FlatAffineConstraints; fixes false positives in -memref-bound-check; add
  test case there
- expose mlir::isTopLevelSymbol from AffineOps

PiperOrigin-RevId: 238050395
diff --git a/include/mlir/AffineOps/AffineOps.h b/include/mlir/AffineOps/AffineOps.h
index f4d0338..487b0a3 100644
--- a/include/mlir/AffineOps/AffineOps.h
+++ b/include/mlir/AffineOps/AffineOps.h
@@ -33,6 +33,10 @@
 class FlatAffineConstraints;
 class FuncBuilder;
 
+/// A utility function to check if a value is defined at the top level of a
+/// function. A value defined at the top level is always a valid symbol.
+bool isTopLevelSymbol(const Value *value);
+
 class AffineOpsDialect : public Dialect {
 public:
   AffineOpsDialect(MLIRContext *context);
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 13c4ccc..81481cc 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -464,11 +464,12 @@
   void addId(IdKind kind, unsigned pos, Value *id = nullptr);
 
   /// Add the specified values as a dim or symbol id depending on its nature, if
-  /// it already doesn't exist in the system. The identifier is added to the end
-  /// of the existing dims or symbols. Additional information on the identifier
-  /// is extracted from the IR (if it's a loop IV or a symbol, and added to the
-  /// constraint system).
-  void addDimOrSymbolId(Value *id);
+  /// it already doesn't exist in the system. `id' has to be either a terminal
+  /// symbol or a loop IV, i.e., it cannot be the result of affine.apply of any
+  /// symbols or loop IVs. The identifier is added to the end of the existing
+  /// dims or symbols. Additional information on the identifier is extracted
+  /// from the IR and added to the constraint system.
+  void addInductionVarOrTerminalSymbol(Value *id);
 
   /// Composes the affine value map with this FlatAffineConstrains, adding the
   /// results of the map as dimensions at the front [0, vMap->getNumResults())
@@ -619,7 +620,8 @@
   Optional<int64_t>
   getConstantBoundOnDimSize(unsigned pos,
                             SmallVectorImpl<int64_t> *lb = nullptr,
-                            int64_t *lbFloorDivisor = nullptr) const;
+                            int64_t *lbFloorDivisor = nullptr,
+                            SmallVectorImpl<int64_t> *ub = nullptr) const;
 
   /// Returns the constant lower bound for the pos^th identifier if there is
   /// one; None otherwise.
diff --git a/lib/AffineOps/AffineOps.cpp b/lib/AffineOps/AffineOps.cpp
index c19565c..a9de42e 100644
--- a/lib/AffineOps/AffineOps.cpp
+++ b/lib/AffineOps/AffineOps.cpp
@@ -40,8 +40,8 @@
 }
 
 /// A utility function to check if a value is defined at the top level of a
-/// function.
-static bool isDefinedAtTopLevel(const Value *value) {
+/// function. A value defined at the top level is always a valid symbol.
+bool mlir::isTopLevelSymbol(const Value *value) {
   if (auto *arg = dyn_cast<BlockArgument>(value))
     return arg->getOwner()->getParent()->getContainingFunction();
   return value->getDefiningInst()->getParentInst() == nullptr;
@@ -65,7 +65,7 @@
     // The dim op is okay if its operand memref/tensor is defined at the top
     // level.
     if (auto dimOp = inst->dyn_cast<DimOp>())
-      return isDefinedAtTopLevel(dimOp->getOperand());
+      return isTopLevelSymbol(dimOp->getOperand());
     return false;
   }
   // This value is a block argument (which also includes 'for' loop IVs).
@@ -90,7 +90,7 @@
     // The dim op is okay if its operand memref/tensor is defined at the top
     // level.
     if (auto dimOp = inst->dyn_cast<DimOp>())
-      return isDefinedAtTopLevel(dimOp->getOperand());
+      return isTopLevelSymbol(dimOp->getOperand());
     return false;
   }
   // Otherwise, the only valid symbol is a top level block argument.
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 4978574..8b845af 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -675,6 +675,7 @@
                                                memref->getType().getContext());
   SmallVector<Value *, 8> operands(indices.begin(), indices.end());
   fullyComposeAffineMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
   canonicalizeMapAndOperands(&map, &operands);
   accessMap->reset(map, operands);
 }
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index e7b7961..68fccf7 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -709,24 +709,27 @@
   }
 }
 
-void FlatAffineConstraints::addDimOrSymbolId(Value *id) {
+void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
   if (containsId(*id))
     return;
-  if (isValidSymbol(id)) {
-    addSymbolId(getNumSymbolIds(), id);
-    // Check if the symbol is a constant.
-    if (auto *opInst = id->getDefiningInst()) {
-      if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
-        setIdToConstant(*id, constOp->getValue());
-      }
-    }
-  } else {
+
+  // Caller is expected to fully compose map/operands if necessary.
+  assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
+         "non-terminal symbol / loop IV expected");
+  // Outer loop IVs could be used in forOp's bounds.
+  if (auto loop = getForInductionVarOwner(id)) {
     addDimId(getNumDimIds(), id);
-    if (auto loop = getForInductionVarOwner(id)) {
-      // Outer loop IVs could be used in forOp's bounds.
-      if (failed(this->addAffineForOpDomain(loop)))
-        LLVM_DEBUG(loop->emitWarning(
-            "failed to add domain info to constraint system"));
+    if (failed(this->addAffineForOpDomain(loop)))
+      LLVM_DEBUG(
+          loop->emitWarning("failed to add domain info to constraint system"));
+    return;
+  }
+  // Add top level symbol.
+  addSymbolId(getNumSymbolIds(), id);
+  // Check if the symbol is a constant.
+  if (auto *opInst = id->getDefiningInst()) {
+    if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
+      setIdToConstant(*id, constOp->getValue());
     }
   }
 }
@@ -740,11 +743,30 @@
     return failure();
   }
 
-  if (forOp->getStep() != 1)
-    LLVM_DEBUG(
-        forOp->emitWarning("Domain conservative: non-unit stride not handled"));
-
   int64_t step = forOp->getStep();
+  if (step != 1) {
+    if (!forOp->hasConstantLowerBound())
+      forOp->emitWarning("domain conservatively approximated");
+    else {
+      // Add constraints for the stride.
+      // (iv - lb) % step = 0 can be written as:
+      // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
+      // Add local variable 'q' and add the above equality.
+      // The first constraint is q = (iv - lb) floordiv step
+      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
+      int64_t lb = forOp->getConstantLowerBound();
+      dividend[pos] = 1;
+      dividend.back() -= lb;
+      addLocalFloorDiv(dividend, step);
+      // Second constraint: (iv - lb) - step * q = 0.
+      SmallVector<int64_t, 8> eq(getNumCols(), 0);
+      eq[pos] = 1;
+      eq.back() -= lb;
+      // For the local var just added above.
+      eq[getNumCols() - 2] = -step;
+      addEquality(eq);
+    }
+  }
 
   if (forOp->hasConstantLowerBound()) {
     addConstantLowerBound(pos, forOp->getConstantLowerBound());
@@ -760,7 +782,7 @@
   }
 
   if (forOp->hasConstantUpperBound()) {
-    addConstantUpperBound(pos, forOp->getConstantUpperBound() - step);
+    addConstantUpperBound(pos, forOp->getConstantUpperBound() - 1);
     return success();
   }
   // Non-constant upper bound case.
@@ -1617,8 +1639,8 @@
 
 LogicalResult
 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
-                                            ArrayRef<Value *> operands, bool eq,
-                                            bool lower) {
+                                            ArrayRef<Value *> boundOperands,
+                                            bool eq, bool lower) {
   assert(pos < getNumDimAndSymbolIds() && "invalid position");
   // Equality follows the logic of lower bound except that we add an equality
   // instead of an inequality.
@@ -1626,12 +1648,19 @@
   if (eq)
     lower = true;
 
+  // Fully compose map and operands; canonicalize and simplify so that we
+  // transitively get to terminal symbols or loop IVs.
+  auto map = boundMap;
+  SmallVector<Value *, 4> operands(boundOperands.begin(), boundOperands.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  map = simplifyAffineMap(map);
+  canonicalizeMapAndOperands(&map, &operands);
   for (auto *operand : operands)
-    addDimOrSymbolId(operand);
+    addInductionVarOrTerminalSymbol(operand);
 
   FlatAffineConstraints localVarCst;
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localVarCst))) {
+  if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
     LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
     return failure();
   }
@@ -1671,7 +1700,7 @@
     SmallVector<int64_t, 4> ineq(getNumCols(), 0);
     ineq[pos] = lower ? 1 : -1;
     // Dims and symbols.
-    for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
+    for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
       ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
     }
     // Copy over the local id coefficients.
@@ -1961,7 +1990,8 @@
 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
-    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor) const {
+    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor,
+    SmallVectorImpl<int64_t> *ub) const {
   assert(pos < getNumDimIds() && "Invalid identifier position");
   assert(getNumLocalIds() == 0);
 
@@ -1977,12 +2007,17 @@
     if (lb) {
       // Set lb to the symbolic value.
       lb->resize(getNumSymbolIds() + 1);
+      if (ub)
+        ub->resize(getNumSymbolIds() + 1);
       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
         int64_t v = atEq(eqRow, pos);
         // atEq(eqRow, pos) is either -1 or 1.
         assert(v * v == 1);
         (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v
                          : -atEq(eqRow, getNumDimIds() + c) / v;
+        // Since this is an equality, ub = lb.
+        if (ub)
+          (*ub)[c] = (*lb)[c];
       }
       assert(lbFloorDivisor &&
              "both lb and divisor or none should be provided");
@@ -2028,7 +2063,7 @@
   // powerful. Not needed for hyper-rectangular iteration spaces.
 
   Optional<int64_t> minDiff = None;
-  unsigned minLbPosition;
+  unsigned minLbPosition, minUbPosition;
   for (auto ubPos : ubIndices) {
     for (auto lbPos : lbIndices) {
       // Look for a lower bound and an upper bound that only differ by a
@@ -2044,28 +2079,38 @@
         }
       if (j < getNumCols() - 1)
         continue;
-      int64_t diff = floorDiv(atIneq(ubPos, getNumCols() - 1) +
-                                  atIneq(lbPos, getNumCols() - 1) + 1,
-                              atIneq(lbPos, pos));
+      int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
+                                 atIneq(lbPos, getNumCols() - 1) + 1,
+                             atIneq(lbPos, pos));
       if (minDiff == None || diff < minDiff) {
         minDiff = diff;
         minLbPosition = lbPos;
+        minUbPosition = ubPos;
       }
     }
   }
   if (lb && minDiff.hasValue()) {
     // Set lb to the symbolic lower bound.
     lb->resize(getNumSymbolIds() + 1);
+    if (ub)
+      ub->resize(getNumSymbolIds() + 1);
     // The lower bound is the ceildiv of the lb constraint over the coefficient
     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
     *lbFloorDivisor = atIneq(minLbPosition, pos);
+    assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
     }
-    // ceildiv (val / d) = floordiv (val + d - 1 / d); hence, the addition of
-    // 'atIneq(minLbPosition, pos) - 1' to the constant term.
+    if (ub) {
+      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
+        (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
+    }
+    // The lower bound leads to a ceildiv while the upper bound is a floordiv
+    // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val
+    // + d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
+    // the constant term for the lower bound.
     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
   }
   return minDiff;
@@ -2664,28 +2709,33 @@
   // To compute final new lower and upper bounds for the union.
   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
 
-  int64_t lbDivisor, otherLbDivisor;
+  int64_t lbFloorDivisor, otherLbFloorDivisor;
   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
-    auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor);
+    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
     if (!extent.hasValue())
       // TODO(bondhugula): symbolic extents when necessary.
       // TODO(bondhugula): handle union if a dimension is unbounded.
       return failure();
 
-    auto otherExtent =
-        other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor);
-    if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor)
+    auto otherExtent = other.getConstantBoundOnDimSize(
+        d, &otherLb, &otherLbFloorDivisor, &otherUb);
+    if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
       // TODO(bondhugula): symbolic extents when necessary.
       return failure();
 
-    assert(lbDivisor > 0 && "divisor always expected to be positive");
+    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
 
     auto res = compareBounds(lb, otherLb);
     // Identify min.
     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
       minLb = lb;
+      // Since the divisor is for a floordiv, we need to convert to ceildiv,
+      // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
+      // div * i >= expr - div + 1.
+      minLb.back() -= lbFloorDivisor - 1;
     } else if (res == BoundCmpResult::Greater) {
       minLb = otherLb;
+      minLb.back() -= otherLbFloorDivisor - 1;
     } else {
       // Uncomparable - check for constant lower/upper bounds.
       auto constLb = getConstantLowerBound(d);
@@ -2696,13 +2746,7 @@
       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
     }
 
-    // Do the same for ub's but max of upper bounds.
-    ub = lb;
-    otherUb = otherLb;
-    ub.back() += extent.getValue() - 1;
-    otherUb.back() += otherExtent.getValue() - 1;
-
-    // Identify max.
+    // Do the same for ub's but max of upper bounds. Identify max.
     auto uRes = compareBounds(ub, otherUb);
     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
       maxUb = ub;
@@ -2723,8 +2767,8 @@
 
     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
     // and so it's the divisor for newLb and newUb as well.
-    newLb[d] = lbDivisor;
-    newUb[d] = -lbDivisor;
+    newLb[d] = lbFloorDivisor;
+    newUb[d] = -lbFloorDivisor;
     // Copy over the symbolic part + constant term.
     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
@@ -2741,6 +2785,9 @@
     addInequality(boundingLbs[d]);
     addInequality(boundingUbs[d]);
   }
+  // TODO(mlir-team): copy over pure symbolic constraints from this and 'other'
+  // over to the union (since the above are just the union along dimensions); we
+  // shouldn't be discarding any other constraints on the symbols.
 
   return success();
 }
diff --git a/lib/Analysis/Utils.cpp b/lib/Analysis/Utils.cpp
index b09f84c..0b5e6ab 100644
--- a/lib/Analysis/Utils.cpp
+++ b/lib/Analysis/Utils.cpp
@@ -255,7 +255,7 @@
   if (sliceState != nullptr) {
     // Add dim and symbol slice operands.
     for (const auto &operand : sliceState->lbOperands[0]) {
-      cst.addDimOrSymbolId(const_cast<Value *>(operand));
+      cst.addInductionVarOrTerminalSymbol(const_cast<Value *>(operand));
     }
     // Add upper/lower bounds from 'sliceState' to 'cst'.
     LogicalResult ret =
diff --git a/lib/Transforms/PipelineDataTransfer.cpp b/lib/Transforms/PipelineDataTransfer.cpp
index fa9c4dc..97532fd 100644
--- a/lib/Transforms/PipelineDataTransfer.cpp
+++ b/lib/Transforms/PipelineDataTransfer.cpp
@@ -105,8 +105,6 @@
   }
 
   // Create and place the alloc right before the 'for' instruction.
-  // TODO(mlir-team): we are assuming scoped allocation here, and aren't
-  // inserting a dealloc -- this isn't the right thing.
   Value *newMemRef =
       bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
 
diff --git a/test/Transforms/dma-generate.mlir b/test/Transforms/dma-generate.mlir
index 1799631..1b3d35e 100644
--- a/test/Transforms/dma-generate.mlir
+++ b/test/Transforms/dma-generate.mlir
@@ -487,6 +487,38 @@
 
 // ----
 
+// This should create a buffer of size 2 for %arg2.
+
+#map_lb = (d0) -> (d0)
+#map_ub = (d0) -> (d0 + 3)
+#map_acc = (d0) -> (d0 floordiv 8)
+// CHECK-LABEL: func @test_analysis_util
+func @test_analysis_util(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, %arg2: memref<2xf32>) -> (memref<144x9xf32>, memref<2xf32>) {
+  %c0 = constant 0 : index
+  %0 = alloc() : memref<64x1xf32>
+  %1 = alloc() : memref<144x4xf32>
+  %2 =  constant 0.0 : f32
+  for %i8 = 0 to 9 step 3 {
+    for %i9 = #map_lb(%i8) to #map_ub(%i8) {
+      for %i17 = 0 to 64 {
+        %23 = affine.apply #map_acc(%i9)
+        %25 = load %arg2[%23] : memref<2xf32>
+        %26 = affine.apply #map_lb(%i17)
+        %27 = load %0[%26, %c0] : memref<64x1xf32>
+        store %27, %arg2[%23] : memref<2xf32>
+      }
+    }
+  }
+  return %arg1, %arg2 : memref<144x9xf32>, memref<2xf32>
+}
+// CHECK:       for %i0 = 0 to 9 step 3 {
+// CHECK:         [[BUF:%[0-9]+]] = alloc() : memref<2xf32, 2>
+// CHECK:         dma_start %arg2[%4], [[BUF]]
+// CHECK:         dma_wait %6[%c0], %c2_0 : memref<1xi32>
+// CHECK:         for %i1 =
+
+// -----
+
 // Since the fast memory size is 4 KB, DMA generation will happen right under
 // %i0.
 
@@ -515,7 +547,7 @@
   return
 }
 
-// ----
+// -----
 
 // This a 3-d loop nest tiled by 4 x 4 x 4. Under %i, %j, %k, the size of a
 // tile of arg0, arg1, and arg2 accessed is 4 KB (each), i.e., 12 KB in total.
@@ -557,9 +589,9 @@
 // FAST-MEM-16KB:       dma_wait
 // FAST-MEM-16KB:       dma_start %arg1
 // FAST-MEM-16KB:       dma_wait
-// FAST-MEM-16KB:       for %i3 = #map2(%i0) to #map3(%i0) {
-// FAST-MEM-16KB-NEXT:    for %i4 = #map2(%i1) to #map3(%i1) {
-// FAST-MEM-16KB-NEXT:      for %i5 = #map2(%i2) to #map3(%i2) {
+// FAST-MEM-16KB:       for %i3 = #map{{[0-9]+}}(%i0) to #map{{[0-9]+}}(%i0) {
+// FAST-MEM-16KB-NEXT:    for %i4 = #map{{[0-9]+}}(%i1) to #map{{[0-9]+}}(%i1) {
+// FAST-MEM-16KB-NEXT:      for %i5 = #map{{[0-9]+}}(%i2) to #map{{[0-9]+}}(%i2) {
 // FAST-MEM-16KB:           }
 // FAST-MEM-16KB:         }
 // FAST-MEM-16KB:       }
diff --git a/test/Transforms/memref-bound-check.mlir b/test/Transforms/memref-bound-check.mlir
index d78261c..8a276d6 100644
--- a/test/Transforms/memref-bound-check.mlir
+++ b/test/Transforms/memref-bound-check.mlir
@@ -266,3 +266,22 @@
   }
   return
 }
+
+// -----
+
+// This should not give an out of bounds error. The result of the affine.apply
+// is composed into the bound map during analysis.
+
+#map_lb = (d0) -> (d0)
+#map_ub = (d0) -> (d0 + 4)
+
+// CHECK-LABEL: func @non_composed_bound_operand
+func @non_composed_bound_operand(%arg0: memref<1024xf32>) {
+  for %i0 = 4 to 1028 step 4 {
+    %i1 = affine.apply (d0) -> (d0 - 4) (%i0)
+    for %i2 = #map_lb(%i1) to #map_ub(%i1) {
+        %0 = load %arg0[%i2] : memref<1024xf32>
+    }
+  }
+  return
+}