[tfdbg2] Reinstate a test in distributed_callbacks_test
- CL/296090972 changed internal tf.function names associated with
tf.keras.Model.fit() calls. A test in distributed_callbacks_test
was disabled temporarily to unblock the CL.
- This CL updates the logic for determining which op_types
  in tfdbg2's debug data are tf.functions, uses the updated logic
  in the aforementioned test, and re-enables it (see the sketch below).
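  For reference, the new check is a plain prefix match against the internal
  eager-function name prefixes. A minimal standalone sketch, assuming the
  prefix values "__forward_", "__backward_", and "__inference_" (these mirror
  function_lib's private constants and may differ across TF versions):

    # Sketch only: the prefix byte strings below are assumptions mirroring
    # function_lib._FORWARD_PREFIX, _BACKWARD_PREFIX and _INFERENCE_PREFIX.
    _FUNCTION_PREFIXES = (b"__forward_", b"__backward_", b"__inference_")

    def is_op_type_function(op_type):
      # Normalize to bytes, then test against all known function prefixes.
      if isinstance(op_type, str):
        op_type = op_type.encode("utf-8")
      return op_type.startswith(_FUNCTION_PREFIXES)

    # Example: an eagerly-executed FuncGraph shows up with an op_type such as
    # "__inference_my_function_13579" and is classified as a tf.function,
    # whereas a primitive op like "MatMul" is not.
    assert is_op_type_function("__inference_my_function_13579")
    assert not is_op_type_function("MatMul")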
PiperOrigin-RevId: 298212480
Change-Id: I4481ef5344b3e72848bfd38c6e316c5d6e348213
diff --git a/tensorflow/python/debug/lib/distributed_callbacks_test.py b/tensorflow/python/debug/lib/distributed_callbacks_test.py
index 606f14b..0a09512 100644
--- a/tensorflow/python/debug/lib/distributed_callbacks_test.py
+++ b/tensorflow/python/debug/lib/distributed_callbacks_test.py
@@ -195,7 +195,6 @@
self.assertAllClose(device_1_matmul_values[0], [[10.0]])
self.assertAllClose(device_1_bias_add_values[0], [[11.0]])
- # TODO(b/148461691): Fix for new Keras internals.
@combinations.generate(
combinations.combine(
distribution=[
@@ -207,8 +206,7 @@
mode=["eager"],
tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
))
- def DISABLED_testKerasModelFitOnOneOrTwoDevices(self, distribution,
- tensor_debug_mode):
+ def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
@@ -235,7 +233,7 @@
fit_executions = [
execution.op_type
for execution in executions
- if "_distributed_function" in execution.op_type
+ if dumping_callback.is_op_type_function(execution.op_type)
]
self.assertLen(fit_executions, epochs)
diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py
index 706181d..9218910 100644
--- a/tensorflow/python/debug/lib/dumping_callback.py
+++ b/tensorflow/python/debug/lib/dumping_callback.py
@@ -49,6 +49,17 @@
_state = threading.local()
DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
+# pylint:disable=protected-access
+_FUNCTION_PREFIXES = (
+ compat.as_bytes(function_lib._FORWARD_PREFIX),
+ compat.as_bytes(function_lib._BACKWARD_PREFIX),
+ compat.as_bytes(function_lib._INFERENCE_PREFIX))
+# pylint:enable=protected-access
+
+
+def is_op_type_function(op_type):
+ return compat.as_bytes(op_type).startswith(_FUNCTION_PREFIXES)
+
@ops.RegisterGradient("DebugIdentityV2")
def _debug_identity_v2_grad(op, dy):
@@ -89,12 +100,6 @@
# Mapping op context to unique ID.
self._context_to_id = dict()
self._function_to_graph_id = dict()
- # pylint:disable=protected-access
- self._function_prefixes = (
- compat.as_bytes(function_lib._FORWARD_PREFIX),
- compat.as_bytes(function_lib._BACKWARD_PREFIX),
- compat.as_bytes(function_lib._INFERENCE_PREFIX))
- # pylint:enable=protected-access
self._op_type_to_context_id = dict()
# Keeps track of counter for symbolic tensors output by in-graph ops.
self._symbolic_tensor_counter = 0
@@ -583,7 +588,7 @@
Else, `None` is returned.
"""
op_type = compat.as_bytes(op_type)
- if op_type.startswith(self._function_prefixes):
+ if is_op_type_function(op_type):
# op_type for eagerly-executed FuncGraphs have the prefixed and suffixed
# form such as "__inference_my_function_13579", wherein the middle part
# "my_function" is the name of the Python function from which the