Add patterns to lower gml_st.parallel and gml_st.for to SCF loops.

For now, only test the case without input and output arguments. This requires
adjusting the parser to make the result type list optional.
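
As a sketch (mirroring the tests added below), a bufferized loop now
parses without a trailing result type list

  gml_st.parallel (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) {
    linalg.fill ins(%cst : f32) outs(%A : memref<192x192xf32>)
    gml_st.subset_yield
  }

and lowers to scf.parallel (gml_st.for lowers to a nest of scf.for ops).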

PiperOrigin-RevId: 454826173
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td
index 22d46c8..17f2a52 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td
@@ -100,6 +100,11 @@
     ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getNumLoops());
     }
+
+    /// Returns true if the op has buffer semantics, i.e. no tensor results.
+    bool hasBufferSemantics() {
+      return this->getOperation()->getNumResults() == 0;
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc
index cbfef8b..4b6b4f9 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc
@@ -605,8 +605,7 @@
                                 static_cast<int32_t>(subsets.size())}));
 
-  // Parser result types.
-  if (parser.parseColon() || parser.parseTypeList(result.types))
-    return failure();
+  // Parse result types.
+  if (parser.parseOptionalColonTypeList(result.types)) return failure();
 
   return success();
 }
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc
index 2f391f6..09488c9 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc
@@ -93,11 +93,45 @@
   }
 };
 
+/// Converts gml_st.parallel or gml_st.for to an SCF loop nest.
+template <typename LoopTy>
+struct LoopLikeToSCFPattern : public OpRewritePattern<LoopTy> {
+  using OpRewritePattern<LoopTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoopTy loop,
+                                PatternRewriter &rewriter) const override {
+    // Fail conversion if the loop has not been bufferized.
+    if (!loop.hasBufferSemantics()) return failure();
+
+    auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) {
+      BlockAndValueMapping bvm;
+      bvm.map(loop.getInductionVars(), ivs);
+      bvm.map(loop.getBody()->getArguments().take_back(loop.outputs().size()),
+              loop.outputs());
+
+      for (auto &op : loop.getBody()->without_terminator())
+        builder.clone(op, bvm);
+    };
+
+    Location loc = loop.getLoc();
+    if (std::is_same<LoopTy, ParallelOp>::value) {
+      rewriter.create<scf::ParallelOp>(
+          loc, loop.lowerBound(), loop.upperBound(), loop.step(), cloneBody);
+    } else {
+      scf::buildLoopNest(rewriter, loc, loop.lowerBound(), loop.upperBound(),
+                         loop.step(), cloneBody);
+    }
+    rewriter.eraseOp(loop);
+    return success();
+  }
+};
+
 struct GmlStToScfPass : public GmlStToScfBase<GmlStToScfPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<LoopToSCFPattern>(patterns.getContext());
+    patterns.add<LoopToSCFPattern, LoopLikeToSCFPattern<ForOp>,
+                 LoopLikeToSCFPattern<ParallelOp>>(patterns.getContext());
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       signalPassFailure();
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir
index c44cf1b..3dec5a4 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir
@@ -50,6 +50,33 @@
 
 // -----
 
+
+func.func @parallel(%A: memref<192x192xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c24 = arith.constant 24 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c192 = arith.constant 192 : index
+
+  gml_st.parallel (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) {
+    linalg.fill ins(%cst : f32) outs(%A : memref<192x192xf32>)
+    gml_st.subset_yield
+  }
+  func.return
+}
+
+// CHECK-LABEL: @parallel
+// CHECK-SAME:  %[[A:.*]]: memref<192x192xf32>
+// CHECK-DAG:   %[[C24:.*]] = arith.constant 24 : index
+// CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C192:.*]] = arith.constant 192 : index
+// CHECK:       scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME:      to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) {
+// CHECK:         linalg.fill
+
+// -----
+
 func.func @loop_reduction(%A: memref<192x192xf32>,
                            %B: memref<192x192xf32>,
                            %C: memref<f32>) {
@@ -80,6 +107,31 @@
 
 // -----
 
+func.func @for(%A: memref<192x192xf32>) {
+  %c24 = arith.constant 24 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c192 = arith.constant 192 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  gml_st.for (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) {
+    linalg.fill ins(%cst : f32) outs(%A : memref<192x192xf32>)
+    gml_st.subset_yield
+  }
+  func.return
+}
+
+// CHECK-LABEL: @for
+// CHECK-DAG:   %[[C24:.*]] = arith.constant 24 : index
+// CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C192:.*]] = arith.constant 192 : index
+// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
+// CHECK:           linalg.fill
+
+// -----
+
 #strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
 #strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>