Lower linalg.indexed_generic to loops.

PiperOrigin-RevId: 281169885
Change-Id: I3951c74ff03faebc335fcf3a084411a81f58cde4
diff --git a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 867b8ab..b137764 100644
--- a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -108,12 +108,11 @@
 }
 
 template <typename GenericOpType>
-LogicalResult verifyBlockArgs(GenericOpType op, Block &block, unsigned nViews,
-                              unsigned nLoops, unsigned nInputViews);
+LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
 
-template <>
-LogicalResult verifyBlockArgs(GenericOp op, Block &block, unsigned nViews,
-                              unsigned nLoops, unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+  auto nViews = op.getNumInputsAndOutputs();
+  auto nInputViews = op.getNumInputs();
   if (block.getNumArguments() != nViews)
     return op.emitError(
         "op expected number of block arguments to match number of views");
@@ -129,10 +128,10 @@
   return success();
 }
 
-template <>
-LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block,
-                              unsigned nViews, unsigned nLoops,
-                              unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+  auto nInputViews = op.getNumInputs();
+  auto nLoops = op.getNumLoops();
+  auto nViews = op.getNumInputsAndOutputs();
   if (block.getNumArguments() != nViews + nLoops)
     return op.emitError(
         "op expected number of block arguments to match number of views + "
@@ -158,6 +157,76 @@
 }
 
 template <typename GenericOpType>
+LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
+
+template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
+  auto nViews = op.getNumInputsAndOutputs();
+  auto nInputViews = op.getNumInputs();
+  if (funType.getNumInputs() != nViews)
+    return op.emitError("op expected fun arguments to match number of views");
+  if (funType.getNumResults() != op.getNumOutputs())
+    return op.emitError(
+        "op expected fun results to match number of output views");
+
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+    if (funType.getInput(idx) != view.getElementType())
+      return op.emitError("op expected fun argument ")
+             << idx << " of the same type as elemental type "
+             << view.getElementType() << " of view " << idx;
+
+    if (idx >= nInputViews) {
+      auto resultIdx = idx - nInputViews;
+      if (funType.getResult(resultIdx) != view.getElementType())
+        return op.emitError("op expected fun result ")
+               << resultIdx << " of the same type as elemental type "
+               << view.getElementType() << " of view " << idx;
+    }
+  }
+  return success();
+}
+
+template <>
+LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
+  auto nLoops = op.getNumLoops();
+  auto nInputViews = op.getNumInputs();
+  auto nOutputs = op.getNumOutputs();
+  auto nViews = op.getNumInputsAndOutputs();
+  if (funType.getNumInputs() != nViews + nLoops)
+    return op.emitError(
+        "op expected fun arguments to match number of views + number of loops");
+  if (funType.getNumResults() != nOutputs)
+    return op.emitError(
+        "op expected fun results to match number of output views");
+  for (unsigned i = 0; i < nLoops; ++i) {
+    if (!funType.getInput(i).isIndex())
+      return op.emitError("op expected fun argument ")
+             << i << " to be of IndexType";
+  }
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto funIdx = nLoops + idx;
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+    if (funType.getInput(funIdx) != view.getElementType())
+      return op.emitError("op expected fun argument ")
+             << funIdx << " of the same type as elemental type "
+             << view.getElementType() << " of view " << idx;
+
+    if (idx >= nInputViews) {
+      auto resultIdx = idx - nInputViews;
+      if (funType.getResult(resultIdx) != view.getElementType())
+        return op.emitError("op expected fun result ")
+               << resultIdx << " of the same type as elemental type "
+               << view.getElementType() << " of view " << idx;
+    }
+  }
+  return success();
+}
+
+template <typename GenericOpType>
 LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
@@ -171,20 +240,14 @@
   if (!region.empty()) {
     if (region.getBlocks().size() != 1)
       return op.emitError("op expected region with 1 block");
-
-    auto &block = region.getBlocks().front();
-    if (failed(verifyBlockArgs(op, block, nViews, nLoops, nInputViews))) {
+    if (failed(verifyBlockArgs(op, region.getBlocks().front())))
       return failure();
-    }
   } else {
     if (!funOp || !funOp.getType())
       return op.emitError(
           "op expected fun attribute to refer to a defined symbol");
-    if (funType.getNumInputs() != nViews)
-      return op.emitError("op expected fun arguments to match number of views");
-    if (funType.getNumResults() != op.getNumOutputs())
-      return op.emitError(
-          "op expected fun results to match number of output views");
+    if (failed(verifyFuncArgs(op, funType)))
+      return failure();
   }
 
   SmallVector<AffineMap, 4> indexingMaps;
@@ -215,19 +278,6 @@
     if (m.getNumResults() != view.getRank())
       return op.emitError("op expected indexing_map #")
              << idx << " results to match view rank: " << view;
-
-    if (funType) {
-      if (funType.getInput(idx) != view.getElementType())
-        return op.emitError("op expected fun argument ")
-               << idx
-               << " to match view element type: " << view.getElementType();
-
-      if (idx >= nInputViews)
-        if (funType.getResult(idx - nInputViews) != view.getElementType())
-          return op.emitError("op expected fun result ")
-                 << idx << " to match output view element type: "
-                 << view.getElementType();
-    }
   }
 
   auto concatMap = concatAffineMaps(indexingMaps);
@@ -718,6 +768,13 @@
       res.push_back(genericOp.getIndexingMap(i));
     }
     return res;
+  } else if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
+    SmallVector<AffineMap, 4> res;
+    unsigned nViews = indexedGenericOp.getNumInputsAndOutputs();
+    res.reserve(nViews);
+    for (unsigned i = 0, e = nViews; i < e; ++i)
+      res.push_back(indexedGenericOp.getIndexingMap(i));
+    return res;
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 6aace80..6e97a7a 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -474,10 +474,11 @@
                                            MLIRContext *ctx) {
   // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
-  patterns.insert<CopyTransposeConversion, LinalgOpConversion<CopyOp>,
-                  LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
-                  LinalgOpConversion<MatvecOp>, LinalgOpConversion<MatmulOp>,
-                  LinalgOpConversion<ConvOp>, LinalgOpConversion<GenericOp>>(
+  patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
+                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
+                  LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
+                  LinalgOpConversion<IndexedGenericOp>,
+                  LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
       ctx);
 }
 
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
index 058dc07..0bf4cea 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
@@ -244,14 +244,14 @@
     SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
 
     // 1.a. Emit std_load from input views.
-    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+    for (unsigned i = 0; i < nInputs; ++i) {
       ValueHandleArray indexing(foldedAffineApplies(
           b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
       indexedValues[i] = std_load(genericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+    for (unsigned i = 0; i < nOutputs; ++i) {
       ValueHandleArray indexing(foldedAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
       indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
@@ -264,49 +264,138 @@
       assert(callOp->getNumResults() == genericOp.getNumOutputs());
 
       // 3. Emit std_store.
-      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      for (unsigned i = 0; i < nOutputs; ++i) {
         ValueHandleArray indexing(foldedAffineApplies(
             b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
         std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
       }
-    } else {
-      // TODO(ntv): When a region inliner exists, use it.
-      // 2. Inline region, currently only works for a single basic block.
-      BlockAndValueMapping map;
-      auto &block = genericOp.region().front();
-      for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      return;
+    }
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    BlockAndValueMapping map;
+    auto &block = genericOp.region().front();
+    for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      map.map(std::get<0>(it), std::get<1>(it));
+    for (auto &op : block.without_terminator()) {
+      assert(op.getNumRegions() == 0);
+      auto *newOp = b.clone(op, map);
+      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
         map.map(std::get<0>(it), std::get<1>(it));
-      for (auto &op : block) {
-        // Skip terminator.
-        if (&op == &block.back())
-          continue;
-        assert(op.getNumRegions() == 0);
-        auto *newOp = b.clone(op, map);
-        for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
-          map.map(std::get<0>(it), std::get<1>(it));
-      }
+    }
 
-      // 3. Emit std_store.
-      auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-      assert(yieldOp->getNumOperands() == nOutputs);
-      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-        ValueHandleArray indexing(foldedAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-        std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
-                  indexing);
-      }
+    // 3. Emit std_store.
+    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+    assert(yieldOp->getNumOperands() == nOutputs);
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
+                indexing);
     }
   }
 };
 
+// Emits the MLIR for the scalar part of the indexed generic op by:
+//   1. Emitting std_load and std_store ops for each input and output view in
+//      order. This is achieved by applying the appropriate input or output map
+//      to the enclosing induction variables.
+//   2. Emitting a call to `op.fun()` that takes as arguments the induction
+//      variables and the scalars from point 1. above.
+//   3. Emitting std_store to store the results of 2. to the output views.
+//
+// An example output may resemble:
+//
+// ```
+//    loop.for %i = %c0 to %0 step %c1 {
+//      loop.for %j = %c0 to %1 step %c1 {
+//        loop.for %k = %c0 to %4 step %c1 {
+//          %11 = load %arg0[%i, %j] :
+//            memref<?x?xf32, stride_specification>
+//          %12 = load %arg1[%i, %j, %k] :
+//            memref<?x?x?xf32, stride_specification>
+//          %13 = load %arg2[%i, %k, %j] :
+//            memref<?x?x?xf32, stride_specification>
+//          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+//            (index, index, index, f32, f32, f32) -> (f32, f32)
+//          store %14#0, %arg1[%i, %j, %k] :
+//            memref<?x?x?Xf32, stride_specification>
+//          store %14#1, %arg2[%i, %k, %j] :
+//            memref<?x?x?Xf32, stride_specification>
+//       }
+//      }
+//    }
+// ```
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       IndexedGenericOp genericOp,
+                                       IndexedGenericOp indexedGenericOp,
                                        OperationFolder *folder) {
-    // This is just a shim to make Linalg compile.
-    // TODO(pifon): Implement lowering after IndexedGenericOp def is submitted.
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = indexedGenericOp.getNumInputs();
+    unsigned nOutputs = indexedGenericOp.getNumOutputs();
+    unsigned nLoops = allIvs.size();
+    SmallVector<Value *, 4> indexedValues(nLoops + nInputs + nOutputs);
+
+    for (unsigned i = 0; i < nLoops; ++i) {
+      indexedValues[i] = allIvs[i];
+    }
+
+    // 1.a. Emit std_load from input views.
+    for (unsigned i = 0; i < nInputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[nLoops + i] =
+          std_load(indexedGenericOp.getInput(i), indexing);
+    }
+
+    // 1.b. Emit std_load from output views.
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nLoops + nInputs + i] =
+          std_load(indexedGenericOp.getOutput(i), indexing);
+    }
+
+    if (auto funcOp = indexedGenericOp.getFunction()) {
+      // 2. Emit call.
+      Operation *callOp = call(funcOp, indexedValues);
+      assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
+
+      // 3. Emit std_store.
+      for (unsigned i = 0; i < nOutputs; ++i) {
+        ValueHandleArray indexing(foldedAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+        std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+                  indexing);
+      }
+      return;
+    }
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    BlockAndValueMapping map;
+    auto &block = indexedGenericOp.region().front();
+    for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      map.map(std::get<0>(it), std::get<1>(it));
+    for (auto &op : block.without_terminator()) {
+      assert(op.getNumRegions() == 0);
+      auto *newOp = b.clone(op, map);
+      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+        map.map(std::get<0>(it), std::get<1>(it));
+    }
+
+    // 3. Emit std_store.
+    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+    assert(yieldOp->getNumOperands() == nOutputs);
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      std_store(map.lookup(yieldOp->getOperand(i)),
+                indexedGenericOp.getOutput(i), indexing);
+    }
   }
 };