[XLA:GPU] Enable SPMD passes even when num_partitions == 1.
- This won't do any partitioning as such, but will eliminate the "Sharding" calls that were
added in the IR.
PiperOrigin-RevId: 375162846
Change-Id: Idae3f2f0a71d26c8adc6eac8381176b1af0c9516
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5a984a7..fd94522 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -160,6 +160,10 @@
Status GpuCompiler::OptimizeHloModule(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
+ const int64 num_partitions = hlo_module->config().num_partitions();
+ const bool use_spmd =
+ hlo_module->config().use_spmd_partitioning() && num_partitions > 1;
+
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
@@ -290,12 +294,11 @@
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
- if (hlo_module->config().use_spmd_partitioning()) {
+ if (use_spmd) {
HloPassPipeline spmd_pipeline("spmd-partitioner");
spmd_pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
spmd_pipeline.AddPass<GpuSpmdPartitioner>(
- hlo_module->config().num_partitions(),
- hlo_module->config().replica_count());
+ num_partitions, hlo_module->config().replica_count());
TF_RETURN_IF_ERROR(spmd_pipeline.Run(hlo_module).status());
}