[XLA:GPU] Enable SPMD passes even when num_partitions == 1.
- With a single partition this performs no actual partitioning, but it
  eliminates the "Sharding" custom calls that were added to the IR.
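
For context, the single-partition path amounts to running a small
sharding-removal pipeline over the module. A minimal hand-written sketch
(illustration only, not part of this change; RemoveShardingCalls is a
hypothetical helper name):

  #include "tensorflow/compiler/xla/service/hlo_dce.h"
  #include "tensorflow/compiler/xla/service/hlo_module.h"
  #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
  #include "tensorflow/compiler/xla/service/sharding_remover.h"

  namespace xla {

  // Hypothetical helper: ShardingRemover rewrites each "Sharding" custom
  // call to forward its operand, and HloDCE then deletes the dead calls.
  Status RemoveShardingCalls(HloModule* hlo_module) {
    HloPassPipeline pipeline("sharding-removal");
    pipeline.AddPass<ShardingRemover>();
    pipeline.AddPass<HloDCE>();
    return pipeline.Run(hlo_module).status();
  }

  }  // namespace xla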

PiperOrigin-RevId: 375162846
Change-Id: Idae3f2f0a71d26c8adc6eac8381176b1af0c9516
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5a984a7..fd94522 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -160,6 +160,9 @@
 Status GpuCompiler::OptimizeHloModule(
     HloModule* hlo_module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
+  const int64 num_partitions = hlo_module->config().num_partitions();
+  const bool use_spmd = hlo_module->config().use_spmd_partitioning();
+
   {
     HloPassPipeline pipeline("optimization");
     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
@@ -290,12 +293,18 @@
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
-  if (hlo_module->config().use_spmd_partitioning()) {
+  if (use_spmd) {
     HloPassPipeline spmd_pipeline("spmd-partitioner");
-    spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
-    spmd_pipeline.AddPass<GpuSpmdPartitioner>(
-        hlo_module->config().num_partitions(),
-        hlo_module->config().replica_count());
+    if (num_partitions > 1) {
+      spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
+      spmd_pipeline.AddPass<GpuSpmdPartitioner>(
+          num_partitions, hlo_module->config().replica_count());
+    } else {
+      // Partitioning with a single partition is a no-op; just remove the
+      // redundant "Sharding" custom calls left in the IR.
+      spmd_pipeline.AddPass<ShardingRemover>();
+      spmd_pipeline.AddPass<HloDCE>();
+    }
     TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
   }