Make op creation order explicit in unrolling batch matmul

Replace CHECK-DAG with plain CHECK in the test now that the op creation order is deterministic: the operand-creating rewriter calls were previously nested inside argument lists, where C++ leaves the evaluation order unspecified, and are now hoisted into named locals that are evaluated in an explicit order. NFC change while here: drop the optional spaces in argument comments; bugprone-argument-comment doesn't care about the spaces, but we don't use them elsewhere, so make this consistent.
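For illustration only (not part of this change), a minimal standalone C++ sketch of the underlying issue, with hypothetical stand-in functions: creating the ops inside an argument list leaves their creation order up to the compiler, while hoisting them into named locals pins the order.

  #include <iostream>

  // Stand-ins for the nested rewriter.create<ConstantOp>(...) calls.
  static int makeBegin() { std::cout << "begin\n"; return 0; }
  static int makeSize()  { std::cout << "size\n";  return 1; }

  static void createSlice(int /*begin*/, int /*size*/) {}

  int main() {
    // Before: both constants are created inside the argument list; whether
    // "begin" or "size" prints first is unspecified in C++.
    createSlice(makeBegin(), makeSize());

    // After: hoist each call into a named local, which fixes the order
    // (this mirrors what the patch below does with the ConstantOp operands).
    int begin = makeBegin();
    int size = makeSize();
    createSlice(begin, size);
    return 0;
  }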

PiperOrigin-RevId: 268284706
diff --git a/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
index cc61caa..09f1dfc 100644
--- a/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
@@ -5,19 +5,19 @@
   return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2TwoDim
-  // CHECK-DAG: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
+  // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+  // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+  // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+  // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
 
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
   // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@@ -65,16 +65,16 @@
   return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulV2FlatInput
-  // CHECK-DAG: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
+  // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+  // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+  // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+  // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
 
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
   // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@@ -116,19 +116,19 @@
   return %0 : tensor<2x3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulTwoDim
-  // CHECK-DAG: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
+  // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+  // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+  // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+  // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
 
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
   // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@@ -176,16 +176,16 @@
   return %0 : tensor<3x4x6xf32>
 
   // CHECK-LABEL: batchMatMulFlatInput
-  // CHECK-DAG: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
-  // CHECK-DAG: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
-  // CHECK-DAG: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
+  // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
+  // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
+  // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
+  // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
+  // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
+  // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
+  // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
 
   // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
   // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
index 1fde6ac..80a2a3e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
@@ -92,8 +92,8 @@
   auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
   auto shapeTensor =
       rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
-  return rewriter.create<TF::ReshapeOp>(loc, resultType, /* tensor = */ value,
-                                        /* shape = */ shapeTensor);
+  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
+                                        /*shape=*/shapeTensor);
 }
 
 template <typename BatchMatMulOpType>
@@ -124,13 +124,11 @@
     auto begin_attr =
         DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
     auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
-    auto sliceOp = rewriter.create<TF::SliceOp>(
-        loc, sliceResultType,
-        /* input = */ reshapeOp.output(),
-        /* begin = */
-        rewriter.create<ConstantOp>(loc, vector3Type, begin_attr),
-        /* size = */
-        rewriter.create<ConstantOp>(loc, vector3Type, size_attr));
+    auto begin = rewriter.create<ConstantOp>(loc, vector3Type, begin_attr);
+    auto size = rewriter.create<ConstantOp>(loc, vector3Type, size_attr);
+    auto sliceOp =
+        rewriter.create<TF::SliceOp>(loc, sliceResultType,
+                                     /*input=*/reshapeOp.output(), begin, size);
 
     // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
     // num_cols]
@@ -192,12 +190,12 @@
       lhs_batch_idx = batch_idx;
       rhs_batch_idx = batch_idx;
     }
-    auto matmul = rewriter.create<TF::MatMulOp>(
-        loc, matmulType,
-        /* a = */ sliced_lhs[lhs_batch_idx],
-        /* b = */ sliced_rhs[rhs_batch_idx],
-        /* transpose_a = */ rewriter.getBoolAttr(false),
-        /* transpose_b = */ rewriter.getBoolAttr(false));
+    auto false_attr = rewriter.getBoolAttr(false);
+    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmulType,
+                                                /*a=*/sliced_lhs[lhs_batch_idx],
+                                                /*b=*/sliced_rhs[rhs_batch_idx],
+                                                /*transpose_a=*/false_attr,
+                                                /*transpose_b=*/false_attr);
     matmuls.emplace_back(matmul.product());
   }
 
@@ -205,11 +203,10 @@
   Type packedType = rewriter.getTensorType(
       {bcast.output_batch_size(), rows, cols}, elementType);
 
-  return rewriter.create<TF::PackOp>(
-      loc, packedType,
-      /* values = */ matmuls,
-      /* N = */ rewriter.getI64IntegerAttr(matmuls.size()),
-      /* axis = */ rewriter.getI64IntegerAttr(0));
+  auto N = rewriter.getI64IntegerAttr(matmuls.size());
+  auto axis = rewriter.getI64IntegerAttr(0);
+  return rewriter.create<TF::PackOp>(loc, packedType,
+                                     /*values=*/matmuls, N, axis);
 }
 
 template <typename BatchMatMulOpType>
@@ -276,12 +273,12 @@
     // When both inputs are matrices, just replace the op to a matmul op.
     Type resultType =
         rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
-    rewriter.replaceOpWithNewOp<TF::MatMulOp>(
-        op, resultType,
-        /* a = */ input_lhs,
-        /* b = */ input_rhs,
-        /* transpose_a = */ rewriter.getBoolAttr(false),
-        /* transpose_b = */ rewriter.getBoolAttr(false));
+    auto false_attr = rewriter.getBoolAttr(false);
+    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, resultType,
+                                              /*a=*/input_lhs,
+                                              /*b=*/input_rhs,
+                                              /*transpose_a=*/false_attr,
+                                              /*transpose_b=*/false_attr);
     return this->matchSuccess();
   }