Displays an error when we perfom an unvoluntary rematerialization.
PiperOrigin-RevId: 457684888
diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD
index bd1a102..f095445 100644
--- a/tensorflow/compiler/xla/service/spmd/BUILD
+++ b/tensorflow/compiler/xla/service/spmd/BUILD
@@ -66,7 +66,6 @@
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_map",
- "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 8f540c3..fb33a48 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -55,6 +55,7 @@
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace spmd {
@@ -478,6 +479,13 @@
// If not replicated yet, first replicate and then reshard to use one of the
// two implementations below.
if (!sharding().IsReplicated()) {
+ LOG(ERROR) << "[spmd] Involuntary full rematerialization. The compiled was "
+ "not able to go from sharding "
+ << sharding().ToString(/*include_metadata=*/true) << " to "
+ << target.ToString(/*include_metadata=*/true)
+ << " without doing a full rematerialization of the tensor. You "
+ "probably want to enrich the sharding annotations to prevent "
+ "this from happening.";
return Replicate().Reshard(target);
}