Fix tokenization tests and update testing_utils to transfer state between layer creation.
PiperOrigin-RevId: 317379253
Change-Id: I786c2eb0506239de0e7f1a5f314a8f1b0bda10d4
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 1928588..cceaabe 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -94,7 +94,8 @@
expected_output_shape=None,
validate_training=True,
adapt_data=None,
- custom_objects=None):
+ custom_objects=None,
+ test_harness=None):
"""Test routine for a layer with a single input and single output.
Arguments:
@@ -114,6 +115,8 @@
be tested for this layer. This is only relevant for PreprocessingLayers.
custom_objects: Optional dictionary mapping name strings to custom objects
in the layer class. This is helpful for testing custom layers.
+ test_harness: The Tensorflow test, if any, that this function is being
+ called in.
Returns:
The output data (Numpy array) returned by the layer, for additional
@@ -143,9 +146,15 @@
expected_output_dtype = input_dtype
if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
- assert_equal = string_test
+ if test_harness:
+ assert_equal = test_harness.assertAllEqual
+ else:
+ assert_equal = string_test
else:
- assert_equal = numeric_test
+ if test_harness:
+ assert_equal = test_harness.assertAllClose
+ else:
+ assert_equal = numeric_test
# instantiation
kwargs = kwargs or {}
@@ -228,6 +237,7 @@
# test training mode (e.g. useful for dropout tests)
# Rebuild the model to avoid the graph being reused between predict() and
# See b/120160788 for more details. This should be mitigated after 2.0.
+ layer_weights = layer.get_weights() # Get the layer weights BEFORE training.
if validate_training:
model = models.Model(x, layer(x))
if _thread_local_data.run_eagerly is not None:
@@ -252,6 +262,8 @@
model = models.Sequential()
model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
model.add(layer)
+
+ layer.set_weights(layer_weights)
actual_output = model.predict(input_data)
actual_output_shape = actual_output.shape
for expected_dim, actual_dim in zip(computed_output_shape,