Fix tokenization tests and update testing_utils to transfer state between layer creation.

PiperOrigin-RevId: 317379253
Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 1928588..cceaabe 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -94,7 +94,8 @@
                expected_output_shape=None,
                validate_training=True,
                adapt_data=None,
-               custom_objects=None):
+               custom_objects=None,
+               test_harness=None):
   """Test routine for a layer with a single input and single output.
 
   Arguments:
@@ -114,6 +115,8 @@
       be tested for this layer. This is only relevant for PreprocessingLayers.
     custom_objects: Optional dictionary mapping name strings to custom objects
       in the layer class. This is helpful for testing custom layers.
+    test_harness: The Tensorflow test, if any, that this function is being
+      called in.
 
   Returns:
     The output data (Numpy array) returned by the layer, for additional
@@ -143,9 +146,15 @@
     expected_output_dtype = input_dtype
 
   if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
-    assert_equal = string_test
+    if test_harness:
+      assert_equal = test_harness.assertAllEqual
+    else:
+      assert_equal = string_test
   else:
-    assert_equal = numeric_test
+    if test_harness:
+      assert_equal = test_harness.assertAllClose
+    else:
+      assert_equal = numeric_test
 
   # instantiation
   kwargs = kwargs or {}
@@ -228,6 +237,7 @@
   # test training mode (e.g. useful for dropout tests)
   # Rebuild the model to avoid the graph being reused between predict() and
   # See b/120160788 for more details. This should be mitigated after 2.0.
+  layer_weights = layer.get_weights()  # Get the layer weights BEFORE training.
   if validate_training:
     model = models.Model(x, layer(x))
     if _thread_local_data.run_eagerly is not None:
@@ -252,6 +262,8 @@
   model = models.Sequential()
   model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
   model.add(layer)
+
+  layer.set_weights(layer_weights)
   actual_output = model.predict(input_data)
   actual_output_shape = actual_output.shape
   for expected_dim, actual_dim in zip(computed_output_shape,