Enable MLIR generated Sub GPU kernels for unsigned types.

They can reuse the code generated for the signed types of the same size.
Also remove the template instantiations for uint8 and uint16, since no GPU
kernel is registered for those types.

PiperOrigin-RevId: 386832546
Change-Id: I1946f6b4f3e78c006c45401d9053d2e887a4c5c7
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc
index d51e458..f440919 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc
@@ -20,9 +20,9 @@
 namespace tensorflow {
 namespace functor {
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
-DEFINE_BINARY6(sub, Eigen::half, float, double, int64, complex64, complex128);
+DEFINE_BINARY8(sub, Eigen::half, float, double, int64, uint32, uint64,
+               complex64, complex128);
 #endif
-DEFINE_BINARY4(sub, uint8, uint16, uint32, uint64);
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
index 36a677b..307def2 100644
--- a/tensorflow/core/kernels/cwise_op_sub.cc
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -33,8 +33,6 @@
 #if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
 REGISTER8(BinaryOp, GPU, "Sub", functor::sub, float, Eigen::half, double, int64,
           complex64, complex128, uint32, uint64);
-#else
-REGISTER2(BinaryOp, GPU, "Sub", functor::sub, uint64, uint32);
 #endif
 
 // A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index 87adb56..03912fe 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -940,6 +940,7 @@
         "f16",
         "f32",
         "f64",
+        "i32",
         "i64",
         "c64",
         "c128",
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc
index fa94e574..4a4d70e 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc
@@ -1106,6 +1106,23 @@
                        test::OpsTestConfig().ExpectStrictlyEqual())
 GENERATE_DEFAULT_TESTS(Sub, /*test_name=*/Int64, int64_t, int64_t, baseline_sub,
                        test::OpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TESTS(Sub, /*test_name=*/UInt32, uint32_t, uint32_t,
+                       baseline_sub,
+                       test::OpsTestConfig().ExpectStrictlyEqual())
+GENERATE_DEFAULT_TESTS(Sub, /*test_name=*/UInt64, uint64_t, uint64_t,
+                       baseline_sub,
+                       test::OpsTestConfig().ExpectStrictlyEqual())
+
+TEST_F(BinaryOpsTest, SubUint32SpecialCases) {
+  TestEqualShapes<uint32_t, uint32_t, uint32_t, uint32_t>(
+      "Sub", /*shape=*/{20},
+      test::InputAsVector<uint32_t>(
+          {std::numeric_limits<uint32_t>::max(), 0u, 0u, 2u}),
+      test::InputAsVector<uint32_t>({std::numeric_limits<uint32_t>::max(),
+                                     std::numeric_limits<uint32_t>::max(), 1u,
+                                     1u}),
+      baseline_sub, test::OpsTestConfig().ExpectStrictlyEqual());
+}
 
 /// Test `tf.Xlogy`.
 
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
index 4abe6d6..3cdfa9e 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc
@@ -22,6 +22,10 @@
 GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_FLOAT);
 GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_DOUBLE);
 GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_INT64);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL3(Sub, DT_INT32, DT_INT32, DT_UINT32,
+                                         DT_UINT32);
+GENERATE_AND_REGISTER_BINARY_GPU_KERNEL3(Sub, DT_INT64, DT_INT64, DT_UINT64,
+                                         DT_UINT64);
 GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_COMPLEX64);
 GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(Sub, DT_COMPLEX128);