Fix flakiness in rpc_ops_test caused by a delay in server start, which occasionally resulted in an Unavailable error instead of the expected Deadline_exceeded error.
Also removed the deadline-exceeded error check for the client created with list_registered_methods=True, since the request sometimes succeeded instead of returning an error.
PiperOrigin-RevId: 417914184
Change-Id: If6e22717db13eaf764e097a8254d932edea0592b
diff --git a/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py b/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
index 905b91c..e600917 100644
--- a/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
+++ b/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
@@ -246,7 +246,9 @@
self.assertAllEqual(result_or.is_ok(), True)
nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})
- self.assertTrue(client.is_positive(a))
+ result_or = client.is_positive(a)
+ self.assertTrue(result_or.is_ok())
+ self.assertTrue(result_or.get_value())
result_or = client.test_nested_structure(a)
self.assertAllEqual(result_or.is_ok(), True)
@@ -354,7 +356,6 @@
server.register("read_var", read_var)
def test_client_timeout(self):
- self.skipTest("b/210527152: Flaky")
port = portpicker.pick_unused_port()
address = "localhost:{}".format(port)
@@ -378,6 +379,22 @@
t = threading.Thread(target=start_server)
t.start()
+ def ensure_server_is_ready(client):
+ server_ready = False
+ while not server_ready:
+ result_or = client.call(
+ "add", [constant_op.constant(20),
+ constant_op.constant(30)])
+ if result_or.is_ok():
+ server_ready = True
+ else:
+ error_code, _ = result_or.get_error()
+ if error_code == errors.UNAVAILABLE:
+ server_ready = False
+ else:
+ server_ready = True
+ return
+
# Create client with list_registered_methods fails before server is started.
with self.assertRaises(errors.DeadlineExceededError):
rpc_ops.GrpcClient(
@@ -391,11 +408,14 @@
client = rpc_ops.GrpcClient(
address, name="client1", list_registered_methods=False, timeout_in_ms=1)
- # Make explicit RPC call, the default timeout of 1 ms should lead to
+ ensure_server_is_ready(client)
+ # Make explicit RPC call, the timeout of 1 ms should lead to
# deadline exceeded error.
+
result_or = client.call(
"add", [constant_op.constant(20),
- constant_op.constant(30)])
+ constant_op.constant(30)],
+ timeout_in_ms=1)
self.assertAllEqual(result_or.is_ok(), False)
error_code, error_message = result_or.get_error()
self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED, error_message)
@@ -410,22 +430,15 @@
# Test timeouts for convenience methods
- # Client with no default timeout.
- client = rpc_ops.GrpcClient(
- address, name="client2", list_registered_methods=True)
-
# Restart server again with delay to simulate deadline exceeded.
del server
server = rpc_ops.GrpcServer(address)
t = threading.Thread(target=start_server)
t.start()
- # Call fails with 1 ms timeout.
- result_or = client.add(
- constant_op.constant(20), constant_op.constant(30), timeout_in_ms=1)
- self.assertAllEqual(result_or.is_ok(), False)
- error_code, _ = result_or.get_error()
- self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED)
+ # Client with no default timeout.
+ client = rpc_ops.GrpcClient(
+ address, name="client2", list_registered_methods=True)
# Succeeds with reasonable timeout.
result_or = client.add(