[tfdbg2] Reinstate a test in distributed_callbacks_test

- CL/296090972 changed internal tf.function names associated with
  tf.keras.Model.fit() calls. A test in distributed_callbacks_test
  was disabled temporarily to unblock the CL.
- This CL updates the logic for determining which op_types
  in tfdbg2's debug data are tf.functions, uses the updated logic
  from the said test, and un-disables it.

PiperOrigin-RevId: 298212480
Change-Id: I4481ef5344b3e72848bfd38c6e316c5d6e348213
diff --git a/tensorflow/python/debug/lib/distributed_callbacks_test.py b/tensorflow/python/debug/lib/distributed_callbacks_test.py
index 606f14b..0a09512 100644
--- a/tensorflow/python/debug/lib/distributed_callbacks_test.py
+++ b/tensorflow/python/debug/lib/distributed_callbacks_test.py
@@ -195,7 +195,6 @@
           self.assertAllClose(device_1_matmul_values[0], [[10.0]])
           self.assertAllClose(device_1_bias_add_values[0], [[11.0]])
 
-  # TODO(b/148461691): Fix for new Keras internals.
   @combinations.generate(
       combinations.combine(
           distribution=[
@@ -207,8 +206,7 @@
           mode=["eager"],
           tensor_debug_mode=["NO_TENSOR", "FULL_TENSOR"],
       ))
-  def DISABLED_testKerasModelFitOnOneOrTwoDevices(self, distribution,
-                                                  tensor_debug_mode):
+  def testKerasModelFitOnOneOrTwoDevices(self, distribution, tensor_debug_mode):
     writer = dumping_callback.enable_dump_debug_info(
         self.dump_root, tensor_debug_mode=tensor_debug_mode)
 
@@ -235,7 +233,7 @@
       fit_executions = [
           execution.op_type
           for execution in executions
-          if "_distributed_function" in execution.op_type
+          if dumping_callback.is_op_type_function(execution.op_type)
       ]
       self.assertLen(fit_executions, epochs)
 
diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py
index 706181d..9218910 100644
--- a/tensorflow/python/debug/lib/dumping_callback.py
+++ b/tensorflow/python/debug/lib/dumping_callback.py
@@ -49,6 +49,17 @@
 _state = threading.local()
 DEFAULT_TENSOR_DEBUG_MODE = "NO_TENSOR"
 
+# pylint:disable=protected-access
+_FUNCTION_PREFIXES = (
+    compat.as_bytes(function_lib._FORWARD_PREFIX),
+    compat.as_bytes(function_lib._BACKWARD_PREFIX),
+    compat.as_bytes(function_lib._INFERENCE_PREFIX))
+# pylint:enable=protected-access
+
+
+def is_op_type_function(op_type):
+  return compat.as_bytes(op_type).startswith(_FUNCTION_PREFIXES)
+
 
 @ops.RegisterGradient("DebugIdentityV2")
 def _debug_identity_v2_grad(op, dy):
@@ -89,12 +100,6 @@
     # Mapping op context to unique ID.
     self._context_to_id = dict()
     self._function_to_graph_id = dict()
-    # pylint:disable=protected-access
-    self._function_prefixes = (
-        compat.as_bytes(function_lib._FORWARD_PREFIX),
-        compat.as_bytes(function_lib._BACKWARD_PREFIX),
-        compat.as_bytes(function_lib._INFERENCE_PREFIX))
-    # pylint:enable=protected-access
     self._op_type_to_context_id = dict()
     # Keeps track of counter for symbolic tensors output by in-graph ops.
     self._symbolic_tensor_counter = 0
@@ -583,7 +588,7 @@
       Else, `None` is returned.
     """
     op_type = compat.as_bytes(op_type)
-    if op_type.startswith(self._function_prefixes):
+    if is_op_type_function(op_type):
       # op_type for eagerly-executed FuncGraphs have the prefixed and suffixed
       # form such as "__inference_my_function_13579", wherein the middle part
       # "my_function" is the name of the Python function from which the