Allow enable_auto_xla_input_sharding=true and enable_variable_deduplication=true for single-core models.
If num_cores_per_replica==1, XLA won't split the model across multiple cores, making enable_auto_xla_input_sharding effectively a no-op.
PiperOrigin-RevId: 378654009
Change-Id: Icb3bba8e099ec2a9d8081427480966fffb520967
diff --git a/tensorflow/core/tpu/kernels/tpu_functional_ops.cc b/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
index 1c562db..e9e49b8 100644
--- a/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_functional_ops.cc
@@ -1584,13 +1584,17 @@
// partitioning. It is possible that it could be supported in the future.
const bool enable_variable_deduplication =
runtime_params_.enable_variable_deduplication;
- if (enable_spmd_xla_partitioning && enable_variable_deduplication) {
+ if (enable_spmd_xla_partitioning && num_cores_per_replica > 1 &&
+ enable_variable_deduplication) {
// If enable_spmd_xla_partitioning is true, the user set the
// enable_auto_xla_input_sharding flag. Warn them that only one of the flags
- // can be set safely.
+ // can be set safely when num_cores_per_replica > 1. If
+ // num_cores_per_replica==1, enable_spmd_xla_partitioning is effectively a
+ // no-op so we can skip this check.
return errors::InvalidArgument(
- "The following flags are incompatible: enable_auto_xla_input_sharding "
- "and enable_variable_deduplication. Only enable one of the flags.");
+ "The following flags are incompatible when num_cores_per_replica > 1: "
+ "enable_auto_xla_input_sharding and enable_variable_deduplication. "
+ "Only enable one of the flags.");
}
std::vector<Node*> tpu_resource_args;
std::vector<int> arg_indices;