Simplify some tests

PiperOrigin-RevId: 311910223
Change-Id: I751aa9344c08a490261822dc8010d1704da95a7c
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 3281595..2119e78 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -302,15 +302,13 @@
   return %0: tensor<2xi32>
 }
 
-func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: RemoveTrivialAdd
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+  // CHECK-NEXT: return %arg0 : tensor<2x2xf32>
 }
 
 func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
@@ -331,26 +329,22 @@
   // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
 }
 
-func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: RemoveTrivialAddV2
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+  // CHECK-NEXT: return %arg0 : tensor<2x2xf32>
 }
 
-func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: RemoveTrivialSub
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+  // CHECK-NEXT: return %arg0 : tensor<2x2xf32>
 }
 
 func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
@@ -362,26 +356,22 @@
   // CHECK-NEXT: return %arg0 : tensor<2x2xi8>
 }
 
-func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+func @RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<1.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: RemoveTrivialMul
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+  // CHECK-NEXT: return %arg0 : tensor<2x2xf32>
 }
 
-func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<1.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: RemoveTrivialDiv
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+  // CHECK-NEXT: return %arg0 : tensor<2x2xf32>
 }
 
 func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
@@ -420,28 +410,24 @@
   // CHECK: tf.Div
 }
 
-func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
+func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
-  %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
 
   // CHECK-LABEL: DontRemoveTrivialAdd
   // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
-  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   // CHECK: return %[[RESULT]] : tensor<2x2xf32>
 }
 
-func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
+func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
-  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-  %1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
-  return %1 :tensor<?x?xf32>
+  %0 = "tf.AddV2"(%arg0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
+  return %0 :tensor<?x?xf32>
 
   // CHECK-LABEL: DontRemoveTrivialAdd2
   // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
-  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   // CHECK: return %[[RESULT]] : tensor<?x?xf32>
 }