Only set the device assignment if it is non-empty.

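ResolveDeviceAssignment now returns xla::StatusOr<absl::optional<xla::DeviceAssignment>>
instead of xla::StatusOr<xla::DeviceAssignment>: when no collective ops are present it
returns absl::nullopt, and the launch op attaches a device assignment to
xla::ExecutableRunOptions only when one was actually produced.

A minimal, self-contained sketch of that pattern (DeviceAssignment, RunOptions, and
Resolve below are simplified stand-ins for the XLA types and for
ResolveDeviceAssignment, not the real implementation):

  #include <iostream>

  #include "absl/status/statusor.h"
  #include "absl/types/optional.h"

  struct DeviceAssignment { int num_replicas = 0; };

  struct RunOptions {
    const DeviceAssignment* assignment = nullptr;
    void set_device_assignment(const DeviceAssignment* a) { assignment = a; }
  };

  // Mirrors the new contract: an empty optional means "no device assignment
  // needed", e.g. when the computation contains no collectives.
  absl::StatusOr<absl::optional<DeviceAssignment>> Resolve(bool has_collectives) {
    if (!has_collectives) return absl::optional<DeviceAssignment>();
    return absl::optional<DeviceAssignment>(DeviceAssignment{/*num_replicas=*/2});
  }

  int main() {
    absl::StatusOr<absl::optional<DeviceAssignment>> assignment =
        Resolve(/*has_collectives=*/false);
    if (!assignment.ok()) return 1;

    RunOptions run_options;
    // Only set the device assignment when one is present; otherwise leave
    // run_options without one, as the updated caller in xla_ops.cc does.
    if (*assignment) {
      run_options.set_device_assignment(&**assignment);
    }
    std::cout << "has assignment: " << (run_options.assignment != nullptr) << "\n";
    return 0;
  }
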
PiperOrigin-RevId: 371709740
Change-Id: I9c8b8031c74fa0950f026df0d05c70f94243910e
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 960e7e7..2ba2c4f 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -38,6 +38,7 @@
     "//tensorflow/core/platform:stream_executor_no_cuda",
     "//tensorflow/core/profiler/lib:traceme",
     "//tensorflow/stream_executor:tf_allocator_adapter",
+    "@com_google_absl//absl/types:optional",
 ]
 
 # Linked by tensorflow core, without registration of jit compilation passes.
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 487c183..7a7013e 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -18,6 +18,7 @@
 #include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
 #include "absl/synchronization/notification.h"
+#include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/jit/flags.h"
@@ -221,7 +222,8 @@
 
 // Resolve the device assignment for the TF single-host MirroredStrategy by
 // calling into TF runtime which in turn would start a rendezvous.
-static xla::StatusOr<xla::DeviceAssignment> ResolveDeviceAssignment(
+static xla::StatusOr<absl::optional<xla::DeviceAssignment>>
+ResolveDeviceAssignment(
     OpKernelContext* ctx,
     const absl::optional<
         XlaCompiler::CompilationResult::CollectiveReduceV2OpInfo>&
@@ -230,7 +232,7 @@
   if (!collective_reduce_info) {
     // An empty device assignment is sufficient for the case where no
     // collectives are present.
-    return xla::DeviceAssignment{};
+    return {{absl::nullopt}};
   }
 
   CollectiveParams params;
@@ -279,7 +281,7 @@
     out(device_idx, 0) = gpu_device_info->stream->parent()->device_ordinal();
   }
 
-  return out;
+  return {{out}};
 }
 
 void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
@@ -333,12 +335,14 @@
 
   // Execute the computation.
   VLOG(2) << "Executing computation.";
-  xla::StatusOr<xla::DeviceAssignment> device_assignment =
+  xla::StatusOr<absl::optional<xla::DeviceAssignment>> device_assignment =
       ResolveDeviceAssignment(ctx, compilation_result->collective_reduce_info);
   OP_REQUIRES_OK(ctx, device_assignment.status());
 
   xla::ExecutableRunOptions run_options;
-  run_options.set_device_assignment(&*device_assignment);
+  if (*device_assignment) {
+    run_options.set_device_assignment(&**device_assignment);
+  }
   run_options.set_stream(stream);
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());