Fixing rpc_ops_test for when we start wrapping RPC eager ops in tf.functions. The issue was that the RPCClient op would omit returning an output in some cases which seemed to work in eager mode but when executed within a function, it complains. We ensure we return an empty string instead.

Also, marking the RpcServerRegister op as Input colocation exempt as it runs a function - ops that run functions go through the multi device backend and can handle inputs on different devices. So we don't need the placer constraint to force colocation of input resources.

PiperOrigin-RevId: 406857024
Change-Id: Id413a358244c6c42848c198023594b14806dc351
diff --git a/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc b/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc
index e6812ca..fc6e867 100644
--- a/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc
+++ b/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc
@@ -28,6 +28,7 @@
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 // Needed for encoding and decoding ResourceDeleter Variant.
+#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
@@ -573,6 +574,10 @@
   ctx->set_output(0, handle);
 
   if (!list_registered_methods_) {
+    Tensor* method_output_t;
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->allocate_output(1, TensorShape({}), &method_output_t), done);
+    method_output_t->scalar<tstring>()() = "";
     done();
     return;
   }
@@ -895,5 +900,6 @@
 REGISTER_KERNEL_BUILDER(Name("DeleteRpcFutureResource").Device(DEVICE_CPU),
                         DeleteRpcFutureResourceOp);
 
+REGISTER_INPUT_COLOCATION_EXEMPTION("RpcServerRegister");
 }  // namespace rpc
 }  // namespace tensorflow
diff --git a/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py b/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
index caf763b..d703a9c 100644
--- a/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
+++ b/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py
@@ -28,6 +28,7 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -36,6 +37,7 @@
 from tensorflow.python.util import nest
 
 
+@test_util.with_eager_op_as_function
 class RpcOpsTest(test.TestCase):
 
   def setUp(self):