| // RUN: tf-opt %s -canonicalize | FileCheck %s |
| |
| // CHECK-LABEL: func @tfAssertTrue |
| func @tfAssertTrue(%arg0: tensor<1x1x6x2xf32>) { |
| %t = constant dense<true> : tensor<i1> |
| // CHECK-NOT: tf.Assert |
| "tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<1x1x6x2xf32>) -> () |
| return |
| } |
| |
| // CHECK-LABEL: func @tfAssertFalse |
| func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) { |
| %f = constant dense<false> : tensor<i1> |
| // CHECK: tf.Assert |
| "tf.Assert"(%f, %arg0) {summarize = 3} : (tensor<i1>, tensor<1x1x6x2xf32>) -> () |
| return |
| } |
| |
| // CHECK-LABEL: func @testLeakyRelu |
| func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>) { |
| %2 = "tf.LeakyRelu"(%arg0) {alpha = 1.0 : f32} : (tensor<16xf32>) -> tensor<16xf32> |
| // CHECK: return %arg0 |
| return %2 : tensor<16xf32> |
| } |
| |
| // CHECK-LABEL: testSameBitcastType |
| func @testSameBitcastType(%arg0: tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> { |
| %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> |
| return %0: tensor<8x16x32x64xf32> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testDifferentBitcastType |
| func @testDifferentBitcastType(%arg0: tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> { |
| %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| return %0: tensor<8x16x32x64xi32> |
| |
| // CHECK: %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testDoubleBitcast |
| func @testDoubleBitcast(%arg0: tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> { |
| %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64x2xi16> |
| %1 = "tf.Bitcast"(%0) : (tensor<8x16x32x64x2xi16>) -> tensor<8x16x32x64xi32> |
| return %1: tensor<8x16x32x64xi32> |
| |
| // CHECK: %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testDoubleBitcastWithDependentArg |
| func @testDoubleBitcastWithDependentArg(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64xi32>, tensor<8x16x32x64x2xi16>) { |
| %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64x2xi16> |
| %1 = "tf.Bitcast"(%0) : (tensor<8x16x32x64x2xi16>) -> tensor<8x16x32x64xi32> |
| %2 = "tf.Identity"(%0) : (tensor<8x16x32x64x2xi16>) -> tensor<8x16x32x64x2xi16> |
| return %1, %2 : tensor<8x16x32x64xi32>, tensor<8x16x32x64x2xi16> |
| |
| // CHECK: %0 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64x2xi16> |
| // CHECK: %1 = "tf.Bitcast"(%arg0) : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| // CHECK: %2 = "tf.Identity"(%0) : (tensor<8x16x32x64x2xi16>) -> tensor<8x16x32x64x2xi16> |
| // CHECK: return %1, %2 |
| } |
| |
| // CHECK-LABEL: testSameCastType |
| func @testSameCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64xf32>, tensor<8x16x32x64xf32>) { |
| %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> |
| %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> |
| return %0, %1: tensor<8x16x32x64xf32>, tensor<8x16x32x64xf32> |
| |
| // CHECK: return %arg0, %arg0 |
| } |
| |
| // CHECK-LABEL: testDifferentCastType |
| func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64xi32>, tensor<8x16x32x64xi32>) { |
| %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| return %0, %1: tensor<8x16x32x64xi32>, tensor<8x16x32x64xi32> |
| |
| // CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| // CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xi32> |
| // CHECK: return %0, %1 |
| } |
| |
| // CHECK-LABEL: testSameCastTypeAcrossBasicBlocks |
| func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> { |
| ^bb0(%arg0: tensor<8x16x32x64xf32>): |
| %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> |
| br ^bb1 |
| ^bb1: |
| %1 = "tf.Cast"(%0) {Truncate = true} : (tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> |
| br ^exit |
| ^exit: |
| return %1: tensor<8x16x32x64xf32> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testLogOfSoftmax |
| func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Log"(%0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.LogSoftmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testSubOfNeg |
| func @testSubOfNeg(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Neg"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Sub"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testSquareOfSub |
| func @testSquareOfSub(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Square"(%0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.SquaredDifference"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testAddToAddV2 |
| func @testAddToAddV2(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Add"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %0: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testNoAddToAddV2ForStringType |
| func @testNoAddToAddV2ForStringType(%arg0: tensor<8x16x!tf.string>, %arg1: tensor<8x16x!tf.string>) -> tensor<8x16x!tf.string> { |
| %0 = "tf.Add"(%arg0, %arg1) : (tensor<8x16x!tf.string>, tensor<8x16x!tf.string>) -> tensor<8x16x!tf.string> |
| return %0: tensor<8x16x!tf.string> |
| |
| // CHECK: %0 = "tf.Add"(%arg0, %arg1) : (tensor<8x16x!tf.string>, tensor<8x16x!tf.string>) -> tensor<8x16x!tf.string> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testAddOfNegLeft |
| func @testAddOfNegLeft(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Neg"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Add"(%0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Sub"(%arg1, %arg0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testAddOfNegRight |
| func @testAddOfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Neg"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Add"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testAddV2OfNegLeft |
| func @testAddV2OfNegLeft(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Neg"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.AddV2"(%0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Sub"(%arg1, %arg0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testAddV2OfNegRight |
| func @testAddV2OfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Neg"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.AddV2"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testDoubleConj |
| func @testDoubleConj(%arg0: tensor<8x16x32x64x!tf.complex64>) -> tensor<8x16x32x64x!tf.complex64> { |
| %0 = "tf.Conj"(%arg0) : (tensor<8x16x32x64x!tf.complex64>) -> tensor<8x16x32x64x!tf.complex64> |
| %1 = "tf.Conj"(%0) : (tensor<8x16x32x64x!tf.complex64>) -> tensor<8x16x32x64x!tf.complex64> |
| return %1: tensor<8x16x32x64x!tf.complex64> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testDoubleInvert |
| func @testDoubleInvert(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> { |
| %0 = "tf.Invert"(%arg0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| %1 = "tf.Invert"(%0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| return %1: tensor<8x16x32x64xi32> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testDoubleLogicalNot |
| func @testDoubleLogicalNot(%arg0: tensor<8x16x32x64xi1>) -> tensor<8x16x32x64xi1> { |
| %0 = "tf.LogicalNot"(%arg0) : (tensor<8x16x32x64xi1>) -> tensor<8x16x32x64xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16x32x64xi1>) -> tensor<8x16x32x64xi1> |
| return %1: tensor<8x16x32x64xi1> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testDoubleNeg |
| func @testDoubleNeg(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> { |
| %0 = "tf.Neg"(%arg0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| %1 = "tf.Neg"(%0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| return %1: tensor<8x16x32x64xi32> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testDoubleReciprocal |
| func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> { |
| %0 = "tf.Reciprocal"(%arg0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| %1 = "tf.Reciprocal"(%0) : (tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi32> |
| return %1: tensor<8x16x32x64xi32> |
| |
| // CHECK: return %arg0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfEqual |
| func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfNotEqual |
| func @testLogicalNotOfNotEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfGreater |
| func @testLogicalNotOfGreater(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfGreaterEqual |
| func @testLogicalNotOfGreaterEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.Less"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfLess |
| func @testLogicalNotOfLess(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.Less"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testLogicalNotOfLessEqual |
| func @testLogicalNotOfLessEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { |
| %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| %1 = "tf.LogicalNot"(%0) : (tensor<8x16xi1>) -> tensor<8x16xi1> |
| return %1: tensor<8x16xi1> |
| |
| // CHECK: %0 = "tf.Greater"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> |
| // CHECK: return %0 |
| } |
| |
| // CHECK-LABEL: testDivWithSqrtDivisor |
| func @testDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Div"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Rsqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: %1 = "tf.Mul"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %1 |
| } |
| |
| // CHECK-LABEL: testRealDivWithSqrtDivisor |
| func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Rsqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: %1 = "tf.Mul"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %1 |
| } |
| |
| // CHECK-LABEL: testTruncateDivWithSqrtDivisor |
| func @testTruncateDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.TruncateDiv"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Rsqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: %1 = "tf.Mul"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %1 |
| } |
| |
| // CHECK-LABEL: testXdivyWithSqrtDivisor |
| func @testXdivyWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { |
| %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| %1 = "tf.Xdivy"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| return %1: tensor<8x16xf32> |
| |
| // CHECK: %0 = "tf.Rsqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: %1 = "tf.MulNoNan"(%0, %arg0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> |
| // CHECK: return %1 |
| } |