Extend/improve getSliceBounds() / complete TODO + update unionBoundingBox

- compute slice bounds precisely where a destination iteration depends on
  multiple source iterations, instead of over-approximating to the whole
  source loop extent (see the example below)
- update unionBoundingBox to handle inputs with non-matching symbols
- re-enable the disabled backend test case
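
As an illustration (from the slice_tile test case added in this change), the
dependence constraints relating a source iteration %i0 to a destination
iteration %i are effectively 16 * %i <= %i0 <= 16 * %i + 15, so
getSliceBounds() now produces the slice bound maps

    lb: (d0) -> (d0 * 16)
    ub: (d0) -> (d0 * 16 + 16)    (exclusive upper bound)

rather than falling back to the source loop's constant extent [0, 32).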

PiperOrigin-RevId: 234714069
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index d7eab0f..a652ff6 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -222,6 +222,15 @@
 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
                                  AffineExpr rhs);
 
+/// Constructs an affine expression from a flat ArrayRef. If there are local
+/// identifiers (neither dimensional nor symbolic) that appear in the sum of
+/// products expression, 'localExprs' is expected to have the AffineExpr
+/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
+/// format [dims, symbols, locals, constant term].
+AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                        unsigned numSymbols, ArrayRef<AffineExpr> localExprs,
+                        MLIRContext *context);
+
 raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
 
 template <typename U> bool AffineExpr::isa() const {
diff --git a/include/mlir/IR/AffineStructures.h b/include/mlir/IR/AffineStructures.h
index c90731c..20ca7d7 100644
--- a/include/mlir/IR/AffineStructures.h
+++ b/include/mlir/IR/AffineStructures.h
@@ -424,7 +424,8 @@
   bool findId(const Value &id, unsigned *pos) const;
 
   // Add identifiers of the specified kind - specified positions are relative to
-  // the kind of identifier. 'id' is the Value corresponding to the
+  // the kind of identifier. The coefficient column corresponding to the added
+  // identifier is initialized to zero. 'id' is the Value corresponding to the
   // identifier that can optionally be provided.
   void addDimId(unsigned pos, Value *id = nullptr);
   void addSymbolId(unsigned pos, Value *id = nullptr);
@@ -579,6 +580,17 @@
   /// one; None otherwise.
   Optional<int64_t> getConstantUpperBound(unsigned pos) const;
 
+  /// Gets the lower and upper bound of the pos^th identifier treating
+  /// [dimStartPos, symStartPos) as dimensions and [symStartPos,
+  /// getNumDimAndSymbolIds()) as symbols. The returned multi-dimensional maps
+  /// in the pair represent the max and min of potentially multiple affine
+  /// expressions. The upper bound is exclusive. 'localExprs' holds pre-computed
+  /// AffineExpr's for all local identifiers in the system.
+  std::pair<AffineMap, AffineMap>
+  getLowerAndUpperBound(unsigned pos, unsigned dimStartPos,
+                        unsigned symStartPos, ArrayRef<AffineExpr> localExprs,
+                        MLIRContext *context);
+
   /// Returns true if the set can be trivially detected as being
   /// hyper-rectangular on the specified contiguous set of identifiers.
   bool isHyperRectangular(unsigned pos, unsigned num) const;
@@ -588,6 +600,9 @@
   /// constraint.
   void removeTrivialRedundancy();
 
+  /// A more expensive check to detect and remove redundant inequalities.
+  void removeRedundantInequalities();
+
   // Removes all equalities and inequalities.
   void clearConstraints();
 
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index c029ef3..5cfb146 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -301,11 +301,10 @@
 /// products expression, 'localExprs' is expected to have the AffineExpr
 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
 /// format [dims, symbols, locals, constant term].
-//  TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
-                               unsigned numSymbols,
-                               ArrayRef<AffineExpr> localExprs,
-                               MLIRContext *context) {
+AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+                              unsigned numSymbols,
+                              ArrayRef<AffineExpr> localExprs,
+                              MLIRContext *context) {
   // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
   assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
          "unexpected number of local expressions");
diff --git a/lib/IR/AffineStructures.cpp b/lib/IR/AffineStructures.cpp
index d043e78..5114f56 100644
--- a/lib/IR/AffineStructures.cpp
+++ b/lib/IR/AffineStructures.cpp
@@ -809,9 +809,6 @@
   if (posStart >= posLimit)
     return 0;
 
-  LLVM_DEBUG(llvm::dbgs() << "Eliminating by Gaussian [" << posStart << ", "
-                          << posLimit << ")\n");
-
   GCDTightenInequalities();
 
   unsigned pivotCol = 0;
@@ -909,25 +906,36 @@
   return false;
 }
 
+// Gather the positions of lower and upper bounds of the pos^th identifier.
+static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
+                                         unsigned pos,
+                                         SmallVectorImpl<unsigned> *lbIndices,
+                                         SmallVectorImpl<unsigned> *ubIndices) {
+  assert(pos < cst.getNumIds() && "invalid position");
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    if (cst.atIneq(r, pos) >= 1) {
+      // Lower bound.
+      lbIndices->push_back(r);
+    } else if (cst.atIneq(r, pos) <= -1) {
+      // Upper bound.
+      ubIndices->push_back(r);
+    }
+  }
+}
+
 // Check if the pos^th identifier can be expressed as a floordiv of an affine
 // function of other identifiers (where the divisor is a positive constant).
 // For eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4.
 bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
                       SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
   assert(pos < cst.getNumIds() && "invalid position");
-  SmallVector<unsigned, 4> lbIndices, ubIndices;
 
-  // Gather all lower bounds and upper bound constraints of this identifier.
-  // Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint
-  // is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
-    if (cst.atIneq(r, pos) >= 1)
-      // Lower bound.
-      lbIndices.push_back(r);
-    else if (cst.atIneq(r, pos) <= -1)
-      // Upper bound.
-      ubIndices.push_back(r);
-  }
+  SmallVector<unsigned, 4> lbIndices, ubIndices;
+  getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
 
   // Check if any lower bound, upper bound pair is of the form:
   // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
@@ -993,6 +1001,107 @@
   return false;
 }
 
+// Fills an inequality row with the value 'val'.
+static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
+                                  int64_t val) {
+  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+    cst->atIneq(r, c) = val;
+  }
+}
+
+// Negates an inequality.
+static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
+  for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+    cst->atIneq(r, c) = -cst->atIneq(r, c);
+  }
+}
+
+// A more complex check to eliminate redundant inequalities.
+void FlatAffineConstraints::removeRedundantInequalities() {
+  SmallVector<bool, 32> redun(getNumInequalities(), false);
+  // To check if an inequality is redundant, we replace the inequality by its
+  // complement (e.g., i - 1 >= 0 by i <= 0), and check if the resulting
+  // system is empty. If it is, the inequality is redundant.
+  FlatAffineConstraints tmpCst(*this);
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    // Change the inequality to its complement.
+    negateInequality(&tmpCst, r);
+    tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
+    if (tmpCst.isEmpty()) {
+      redun[r] = true;
+      // Zero fill the redundant inequality.
+      fillInequality(this, r, /*val=*/0);
+      fillInequality(&tmpCst, r, /*val=*/0);
+    } else {
+      // Reverse the change (to avoid recreating tmpCst each time).
+      tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
+      negateInequality(&tmpCst, r);
+    }
+  }
+
+  // Scan to get rid of all rows marked redundant, in-place.
+  auto copyRow = [&](unsigned src, unsigned dest) {
+    if (src == dest)
+      return;
+    for (unsigned c = 0, e = getNumCols(); c < e; c++) {
+      atIneq(dest, c) = atIneq(src, c);
+    }
+  };
+  unsigned pos = 0;
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (!redun[r])
+      copyRow(r, pos++);
+  }
+  inequalities.resize(numReservedCols * pos);
+}
+
+std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
+    unsigned pos, unsigned dimStartPos, unsigned symStartPos,
+    ArrayRef<AffineExpr> localExprs, MLIRContext *context) {
+  assert(pos < dimStartPos && "invalid dim start pos");
+  assert(symStartPos >= dimStartPos && "invalid sym start pos");
+  assert(getNumLocalIds() == localExprs.size() &&
+         "incorrect local exprs count");
+
+  SmallVector<unsigned, 4> lbIndices, ubIndices;
+  getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices);
+
+  SmallVector<int64_t, 8> lb, ub;
+  SmallVector<AffineExpr, 4> exprs;
+  unsigned dimCount = symStartPos - dimStartPos;
+  unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
+  exprs.reserve(lbIndices.size());
+  // Lower bound expressions.
+  for (auto idx : lbIndices) {
+    auto ineq = getInequality(idx);
+    // Extract the lower bound (in terms of other coeff's + const), i.e., if
+    // i - j + 1 >= 0 is the constraint and 'pos' is for i, then the lower
+    // bound is j - 1.
+    lb.assign(ineq.begin() + dimStartPos, ineq.end());
+    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
+    auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
+    exprs.push_back(expr);
+  }
+  auto lbMap = exprs.empty() ? AffineMap()
+                             : AffineMap::get(dimCount, symCount, exprs, {});
+
+  exprs.clear();
+  exprs.reserve(ubIndices.size());
+  // Upper bound expressions.
+  for (auto idx : ubIndices) {
+    auto ineq = getInequality(idx);
+    // Extract the upper bound (in terms of other coeff's + const).
+    ub.assign(ineq.begin() + dimStartPos, ineq.end());
+    auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
+    // Upper bound is exclusive.
+    exprs.push_back(expr + 1);
+  }
+  auto ubMap = exprs.empty() ? AffineMap()
+                             : AffineMap::get(dimCount, symCount, exprs, {});
+
+  return {lbMap, ubMap};
+}
+
 /// Computes the lower and upper bounds of the first 'num' dimensional
 /// identifiers as affine maps of the remaining identifiers (dimensional and
 /// symbolic identifiers). Local identifiers are themselves explicitly computed
@@ -1097,6 +1206,7 @@
   // Set the lower and upper bound maps for all the identifiers that were
   // computed as affine expressions of the rest as the "detected expr" and
   // "detected expr + 1" respectively; set the undetected ones to Null().
+  Optional<FlatAffineConstraints> tmpClone;
   for (unsigned pos = 0; pos < num; pos++) {
     unsigned numMapDims = getNumDimIds() - num;
     unsigned numMapSymbols = getNumSymbolIds();
@@ -1108,24 +1218,49 @@
       (*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
       (*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
     } else {
-      // TODO(andydavis, bondhugula) Add support for computing slice bounds
-      // symbolic in the identifies [num, numIds).
-      auto lbConst = getConstantLowerBound(pos);
-      auto ubConst = getConstantUpperBound(pos);
-      if (lbConst.hasValue() && ubConst.hasValue()) {
-        (*lbMaps)[pos] = AffineMap::get(
-            numMapDims, numMapSymbols,
-            getAffineConstantExpr(lbConst.getValue(), context), {});
-        (*ubMaps)[pos] = AffineMap::get(
-            numMapDims, numMapSymbols,
-            getAffineConstantExpr(ubConst.getValue() + 1, context), {});
-      } else {
-        (*lbMaps)[pos] = AffineMap();
-        (*ubMaps)[pos] = AffineMap();
+      // TODO(bondhugula): Whenever there are local identifiers in the
+      // dependence constraints, we'll conservatively over-approximate, since we
+      // don't always explicitly compute them above (in the while loop).
+      if (getNumLocalIds() == 0) {
+        // Work on a copy so that we don't update this constraint system.
+        if (!tmpClone) {
+          tmpClone.emplace(FlatAffineConstraints(*this));
+          // Removing redundant inequalities is necessary so that we don't get
+          // redundant loop bounds.
+          tmpClone->removeRedundantInequalities();
+        }
+        std::tie((*lbMaps)[pos], (*ubMaps)[pos]) =
+            tmpClone->getLowerAndUpperBound(pos, num, getNumDimIds(), {},
+                                            context);
+      }
+
+      // If the above fails, we'll just use the constant lower bound and the
+      // constant upper bound (if they exist) as the slice bounds.
+      if (!(*lbMaps)[pos]) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice lb\n");
+        auto lbConst = getConstantLowerBound(pos);
+        if (lbConst.hasValue()) {
+          (*lbMaps)[pos] = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(lbConst.getValue(), context), {});
+        }
+      }
+      if (!(*ubMaps)[pos]) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "WARNING: Potentially over-approximating slice ub\n");
+        auto ubConst = getConstantUpperBound(pos);
+        if (ubConst.hasValue()) {
+          (*ubMaps)[pos] = AffineMap::get(
+              numMapDims, numMapSymbols,
+              getAffineConstantExpr(ubConst.getValue() + 1, context), {});
+        }
       }
     }
     LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
-    LLVM_DEBUG(expr.dump(););
+    LLVM_DEBUG((*lbMaps)[pos].dump(););
+    LLVM_DEBUG(llvm::dbgs() << "ub map for pos = " << Twine(pos) << ", expr: ");
+    LLVM_DEBUG((*ubMaps)[pos].dump(););
   }
 }
 
@@ -1454,6 +1589,7 @@
         break;
     }
     if (c < getNumDimIds())
+      // Not a purely symbolic bound.
       continue;
     if (atIneq(r, pos) >= 1)
       // Lower bound.
@@ -2037,14 +2173,53 @@
 }
 }; // namespace
 
+// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of
+// the symbols. Assumes the common symbols appear in the same order (the
+// current/common use case).
+static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) {
+  SmallVector<Value *, 4> symbolsA, symbolsB;
+  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA);
+  B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB);
+
+  // Both symbol lists typically have a handful of symbols each (3-4); a
+  // merge quadratic in complexity with a linear search is fine.
+  for (auto *symbolB : symbolsB) {
+    if (!llvm::is_contained(symbolsA, symbolB)) {
+      A->addSymbolId(symbolsA.size(), symbolB);
+      symbolsA.push_back(symbolB);
+    }
+  }
+  // symbolsA now holds the merged symbol list.
+  symbolsB.reserve(symbolsA.size());
+  unsigned iB = 0;
+  for (auto *symbolA : symbolsA) {
+    assert(iB <= symbolsB.size() && "position out of range");
+    if (iB == symbolsB.size() || symbolA != symbolsB[iB]) {
+      symbolsB.insert(symbolsB.begin() + iB, symbolA);
+      B->addSymbolId(iB, symbolA);
+    }
+    ++iB;
+  }
+}
+
 // Compute the bounding box with respect to 'other' by finding the min of the
 // lower bounds and the max of the upper bounds along each of the dimensions.
 bool FlatAffineConstraints::unionBoundingBox(
-    const FlatAffineConstraints &other) {
-  assert(other.getNumDimIds() == numDims);
-  assert(other.getNumSymbolIds() == getNumSymbolIds());
-  assert(other.getNumLocalIds() == 0);
-  assert(getNumLocalIds() == 0);
+    const FlatAffineConstraints &otherArg) {
+  assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
+
+  Optional<FlatAffineConstraints> copy;
+  if (!otherArg.getIds().equals(getIds())) {
+    copy.emplace(FlatAffineConstraints(otherArg));
+    mergeSymbols(this, &copy.getValue());
+    assert(getIds().equals(copy->getIds()) && "merge failed");
+  }
+
+  const auto &other = copy ? *copy : otherArg;
+
+  assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
+  assert(getNumLocalIds() == 0 && "local ids not eliminated");
+
   std::vector<SmallVector<int64_t, 8>> boundingLbs;
   std::vector<SmallVector<int64_t, 8>> boundingUbs;
   boundingLbs.reserve(2 * getNumDimIds());
@@ -2082,7 +2257,11 @@
       minLb = otherLb;
     } else {
       // Uncomparable.
-      return false;
+      auto constLb = getConstantLowerBound(d);
+      auto constOtherLb = other.getConstantLowerBound(d);
+      if (!constLb.hasValue() || !constOtherLb.hasValue())
+        return false;
+      minLb = std::min(constLb.getValue(), constOtherLb.getValue());
     }
 
     // Do the same for ub's but max of upper bounds.
@@ -2098,7 +2277,11 @@
       maxUb = otherUb;
     } else {
       // Uncomparable.
-      return false;
+      auto constUb = getConstantUpperBound(d);
+      auto constOtherUb = other.getConstantUpperBound(d);
+      if (!constUb.hasValue() || !constOtherUb.hasValue())
+        return false;
+      maxUb = std::max(constUb.getValue(), constOtherUb.getValue());
     }
 
     SmallVector<int64_t, 8> newLb(getNumCols(), 0);
diff --git a/lib/Transforms/LoopFusion.cpp b/lib/Transforms/LoopFusion.cpp
index 63a681a..524b34b 100644
--- a/lib/Transforms/LoopFusion.cpp
+++ b/lib/Transforms/LoopFusion.cpp
@@ -58,8 +58,8 @@
 /// A threshold in percent of additional computation allowed when fusing.
 static llvm::cl::opt<double> clFusionAddlComputeTolerance(
     "fusion-compute-tolerance", llvm::cl::Hidden,
-    llvm::cl::desc("Fractional increase in additional"
-                   " computation tolerated while fusing"),
+    llvm::cl::desc("Fractional increase in additional "
+                   "computation tolerated while fusing"),
     llvm::cl::cat(clOptionsCategory));
 
 static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
@@ -1260,12 +1260,9 @@
                                unsigned *dstLoopDepth) {
   LLVM_DEBUG({
     llvm::dbgs() << "Checking whether fusion is profitable between:\n";
-    llvm::dbgs() << " ";
-    srcOpInst->dump();
-    llvm::dbgs() << " and \n";
+    llvm::dbgs() << " " << *srcOpInst << " and \n";
     for (auto dstOpInst : dstLoadOpInsts) {
-      llvm::dbgs() << " ";
-      dstOpInst->dump();
+      llvm::dbgs() << " " << *dstOpInst << "\n";
     };
   });
 
@@ -1423,7 +1420,10 @@
           << 100.0 * additionalComputeFraction << "%\n"
           << "   storage reduction factor: " << storageReduction << "x\n"
           << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
-          << "   slice iteration count: " << sliceIterationCount << "\n";
+          << "   slice iteration count: " << sliceIterationCount << "\n"
+          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
+          << "   slice write region size: " << sliceWriteRegionSizeBytes
+          << "\n";
       llvm::dbgs() << msg.str();
     });
 
@@ -1450,9 +1450,10 @@
   // -maximal-fusion is set, fuse nevertheless.
 
   if (!clMaximalLoopFusion && !bestDstLoopDepth.hasValue()) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "All fusion choices involve more than the threshold amount of"
-                  "redundant computation; NOT fusing.\n");
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "All fusion choices involve more than the threshold amount of "
+           "redundant computation; NOT fusing.\n");
     return false;
   }
 
@@ -1694,6 +1695,9 @@
           auto sliceLoopNest = mlir::insertBackwardComputationSlice(
               srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
           if (sliceLoopNest != nullptr) {
+            LLVM_DEBUG(llvm::dbgs()
+                       << "\tslice loop nest:\n"
+                       << *sliceLoopNest->getInstruction() << "\n");
             // Move 'dstAffineForOp' before 'insertPointInst' if needed.
             auto dstAffineForOp = dstNode->inst->cast<AffineForOp>();
             if (insertPointInst != dstAffineForOp->getInstruction()) {
diff --git a/test/Transforms/loop-fusion.mlir b/test/Transforms/loop-fusion.mlir
index c671adc..2458049 100644
--- a/test/Transforms/loop-fusion.mlir
+++ b/test/Transforms/loop-fusion.mlir
@@ -1810,3 +1810,61 @@
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return %arg0 : memref<10xf32>
 }
+
+// -----
+
+// The fused slice has 16 iterations from the source loop along %i0.
+
+// CHECK-DAG: [[MAP_LB:#map[0-9]+]] = (d0) -> (d0 * 16)
+// CHECK-DAG: [[MAP_UB:#map[0-9]+]] = (d0) -> (d0 * 16 + 16)
+
+#map = (d0, d1) -> (d0 * 16 + d1)
+
+// CHECK-LABEL: slice_tile
+func @slice_tile(%arg1: memref<32x8xf32>, %arg2: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> {
+  for %i0 = 0 to 32 {
+    for %i1 = 0 to 8 {
+      store %0, %arg2[%i0, %i1] : memref<32x8xf32>
+    }
+  }
+  for %i = 0 to 2 {
+    for %j = 0 to 8 {
+      for %k = 0 to 8 {
+        for %kk = 0 to 16 {
+          %1 = affine.apply #map(%k, %kk)
+          %2 = load %arg1[%1, %j] : memref<32x8xf32>
+          %3 = "foo"(%2) : (f32) -> f32
+        }
+        for %ii = 0 to 16 {
+          %6 = affine.apply #map(%i, %ii)
+          %7 = load %arg2[%6, %j] : memref<32x8xf32>
+          %8 = addf %7, %7 : f32
+          store %8, %arg2[%6, %j] : memref<32x8xf32>
+        }
+      }
+    }
+  }
+  return %arg2 : memref<32x8xf32>
+}
+// CHECK:       for %i0 = 0 to 2 {
+// CHECK-NEXT:    for %i1 = 0 to 8 {
+// CHECK-NEXT:      for %i2 = [[MAP_LB]](%i0) to [[MAP_UB]](%i0) {
+// CHECK-NEXT:        store %arg2, %arg1[%i2, %i1] : memref<32x8xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:      for %i3 = 0 to 8 {
+// CHECK-NEXT:        for %i4 = 0 to 16 {
+// CHECK-NEXT:          %0 = affine.apply #map{{[0-9]+}}(%i3, %i4)
+// CHECK-NEXT:          %1 = load %arg0[%0, %i1] : memref<32x8xf32>
+// CHECK-NEXT:          %2 = "foo"(%1) : (f32) -> f32
+// CHECK-NEXT:        }
+// CHECK-NEXT:        for %i5 = 0 to 16 {
+// CHECK-NEXT:          %3 = affine.apply #map{{[0-9]+}}(%i0, %i5)
+// CHECK-NEXT:          %4 = load %arg1[%3, %i1] : memref<32x8xf32>
+// CHECK-NEXT:          %5 = addf %4, %4 : f32
+// CHECK-NEXT:          store %5, %arg1[%3, %i1] : memref<32x8xf32>
+// CHECK-NEXT:        }
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %arg1 : memref<32x8xf32>
+// CHECK-NEXT:}