Only trigger resharding when necessary: drop the per-program sharding key.

The TPU program key shrinks from a 3-element to a 2-element string vector
(element 0: compilation-cache proto key; element 1: rendezvous key base).
The sharding-program fingerprint that used to occupy element 2, and the
corresponding sharding_key out-parameter of CompileIfKeyAbsent, are removed.

PiperOrigin-RevId: 332060600
Change-Id: If36314756713071679c20cd6645434c5db53500b
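
Note: the patch changes every producer and consumer of the program key to the
new two-element layout. A minimal sketch of that layout, assuming TensorFlow's
C++ Tensor API (the helper name is illustrative, not part of the patch):

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/types.pb.h"

    // Builds a program key in the new two-element layout:
    //   element 0: compilation-cache proto key
    //   element 1: rendezvous key base
    // The former element 2 (sharding key) no longer exists.
    tensorflow::Tensor MakeProgramKey(const std::string& proto_key,
                                      const std::string& rendezvous_key_base) {
      tensorflow::Tensor key(tensorflow::DT_STRING, {2});
      key.vec<tensorflow::tstring>()(0) = proto_key;
      key.vec<tensorflow::tstring>()(1) = rendezvous_key_base;
      return key;
    }
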
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 6ecde66..78be6cc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -537,10 +537,9 @@
   // Build a constant default key to specify that the unformatting should
   // transform the variables to the original format.
   builder.setInsertionPointAfter(while_op);
-  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
+  tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2});
   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
-  default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
   auto default_state_key = builder.create<TF::ConstOp>(
       while_op.getLoc(),
       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
index f303588..fad8e22 100644
--- a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
+++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
@@ -504,10 +504,9 @@
       "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
   AddNodeAttr("dtype", DT_STRING, &default_sharding);
 
-  Tensor t(DT_STRING, {3});
+  Tensor t(DT_STRING, {2});
   t.vec<tstring>()(0) = kDefaultShardingValue;
   t.vec<tstring>()(1) = kDefaultShardingValue;
-  t.vec<tstring>()(2) = kDefaultShardingValue;
   t.AsProtoTensorContent(
       (*default_sharding.mutable_attr())["value"].mutable_tensor());
 
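The graph-rewrite path serializes the same two-element default-sharding tensor
into the Const node's "value" attribute. A hedged round-trip sketch of that
serialization, written as standalone test-style code (the function is
illustrative, not part of the patch):

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor.pb.h"

    // Round-trips the default-sharding value through a TensorProto, the same
    // serialization used for the Const node's "value" attr above.
    bool DefaultShardingRoundTrips(const tensorflow::tstring& default_value) {
      tensorflow::Tensor t(tensorflow::DT_STRING, {2});
      t.vec<tensorflow::tstring>()(0) = default_value;
      t.vec<tensorflow::tstring>()(1) = default_value;
      tensorflow::TensorProto proto;
      t.AsProtoTensorContent(&proto);
      tensorflow::Tensor restored;
      return restored.FromProto(proto) &&
             restored.vec<tensorflow::tstring>()(0) == default_value &&
             restored.vec<tensorflow::tstring>()(1) == default_value;
    }
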
diff --git a/tensorflow/core/tpu/kernels/compiled_subgraph.h b/tensorflow/core/tpu/kernels/compiled_subgraph.h
index 091d6e7..a97c652 100644
--- a/tensorflow/core/tpu/kernels/compiled_subgraph.h
+++ b/tensorflow/core/tpu/kernels/compiled_subgraph.h
@@ -124,9 +124,6 @@
   // Compilation cache proto key to identify the cache entry.
   std::vector<std::string> proto_key;
 
-  // Fingerprints of sharding programs if there is any.
-  std::vector<std::string> sharding_key;
-
   // The number of 'external' client-held references to the entry.
   int external_references = 0;
 
diff --git a/tensorflow/core/tpu/kernels/host_compute_ops.cc b/tensorflow/core/tpu/kernels/host_compute_ops.cc
index 5295c1c..77a7d6f 100644
--- a/tensorflow/core/tpu/kernels/host_compute_ops.cc
+++ b/tensorflow/core/tpu/kernels/host_compute_ops.cc
@@ -58,9 +58,9 @@
     OP_REQUIRES_ASYNC(
         ctx,
         TensorShapeUtils::IsVector(input.shape()) &&
-            input.shape().dim_size(0) == 3,
+            input.shape().dim_size(0) == 2,
         errors::InvalidArgument("Input shape ", input.shape().DebugString(),
-                                " is not a vector of length 3."),
+                                " is not a vector of length 2."),
         done);
     const string rendezvous_key_base = input.vec<tstring>()(1);
     OP_REQUIRES_ASYNC(
@@ -164,10 +164,10 @@
     const Tensor& key_input = ctx->input(ctx->num_inputs() - 1);
     OP_REQUIRES(ctx,
                 TensorShapeUtils::IsVector(key_input.shape()) &&
-                    key_input.shape().dim_size(0) == 3,
+                    key_input.shape().dim_size(0) == 2,
                 errors::InvalidArgument("Key input shape ",
                                         key_input.shape().DebugString(),
-                                        " is not a vector of length 3."));
+                                        " is not a vector of length 2."));
     const string rendezvous_key_base = key_input.vec<tstring>()(1);
     OP_REQUIRES(
         ctx, ctx->rendezvous() != nullptr,
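
Both host-compute kernels now accept only a length-2 key and still read the
rendezvous key base from element 1. A sketch of the shared check, factored
into a standalone helper (the kernels inline this logic; the helper name is
illustrative):

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor_shape.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/status.h"

    // Validates the new key shape and extracts the rendezvous key base
    // (element 1); element 0 remains the compilation-cache proto key.
    tensorflow::Status ExtractRendezvousKeyBase(
        const tensorflow::Tensor& key, std::string* rendezvous_key_base) {
      if (!tensorflow::TensorShapeUtils::IsVector(key.shape()) ||
          key.shape().dim_size(0) != 2) {
        return tensorflow::errors::InvalidArgument(
            "Key input shape ", key.shape().DebugString(),
            " is not a vector of length 2.");
      }
      *rendezvous_key_base = key.vec<tensorflow::tstring>()(1);
      return tensorflow::Status::OK();
    }
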
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
index 0d0a2ae..1928303 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
@@ -362,15 +362,14 @@
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
+    std::vector<std::string>* proto_key,
     std::vector<bool>* may_modify_variables,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
     const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
   std::vector<CompiledSubgraph*> removed_entries;
   auto status = CompileIfKeyAbsentHelper(
       subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
-      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
-      compile_function);
+      may_modify_variables, &removed_entries, hlo_metadatas, compile_function);
   for (auto entry : removed_entries) {
     UnloadAndDestroy(entry);
   }
@@ -400,7 +399,7 @@
     const TpuCompilationCacheKey& subgraph_key,
     const SessionMetadata* session_metadata,
     CompilationRefHolder* per_step_ref_holder, int64* uid,
-    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
+    std::vector<std::string>* proto_key,
     std::vector<bool>* may_modify_variables,
     std::vector<CompiledSubgraph*>* removed_entries,
     absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -498,7 +497,6 @@
   *uid = entry->uid;
   // Let the caller know the keys for each of the cached protos.
   *proto_key = entry->proto_key;
-  *sharding_key = entry->sharding_key;
   *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
   *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
 
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
index 12f116d..e1e7cf2 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -109,7 +109,6 @@
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
-      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
       const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@@ -198,7 +197,6 @@
       const SessionMetadata* session_metadata,
       CompilationRefHolder* per_step_ref_holder, int64* uid,
       std::vector<std::string>* proto_key,
-      std::vector<std::string>* sharding_key,
       std::vector<bool>* may_modify_variables,
       std::vector<CompiledSubgraph*>* removed_entries,
       absl::Span<const xla::HloProto* const>* hlo_metadatas,
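
For callers the migration is mechanical: delete the sharding_key
out-parameter; every other argument keeps its position. A hedged sketch under
that assumption (wrapper and variable names are illustrative):

    // Migration sketch: the only change for callers of CompileIfKeyAbsent is
    // dropping the sharding_key out-parameter.
    Status CompileExample(
        TpuCompilationCacheInterface* cache, const TpuCompilationCacheKey& key,
        const SessionMetadata* session_metadata, CompilationRefHolder* holder,
        const std::function<Status(TpuProgramGroupInterface*)>& fn) {
      int64 uid;
      std::vector<std::string> proto_key;        // still returned
      // std::vector<std::string> sharding_key;  // removed by this patch
      std::vector<bool> may_modify_variables;
      absl::Span<const xla::HloProto* const> hlo_metadatas;
      return cache->CompileIfKeyAbsent(key, session_metadata, holder, &uid,
                                       &proto_key, &may_modify_variables,
                                       &hlo_metadatas, fn);
    }
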
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
index 34ea317..ce18e84 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
@@ -657,11 +657,10 @@
 
   int64 uid;
   std::vector<std::string> proto_key;
-  std::vector<std::string> sharding_key;
   std::vector<bool> may_modify_variables;
   absl::Span<const xla::HloProto* const> hlo_metadatas;
   Status status = cache->CompileIfKeyAbsent(
-      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
+      key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
       &may_modify_variables, &hlo_metadatas,
       [&](TpuProgramGroupInterface* tpu_program_group) {
         VLOG(1) << "Cloud TPU: Compiling TPU program";
@@ -779,21 +778,13 @@
 
   if (status.ok()) {
     for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
-      Tensor output(DT_STRING, TensorShape({3}));
+      Tensor output(DT_STRING, TensorShape({2}));
       if (proto_key.size() == 1) {
         output.vec<tstring>()(0) = proto_key[0];
       } else {
         output.vec<tstring>()(0) = proto_key[i];
       }
       output.vec<tstring>()(1) = rendezvous_key_base;
-      if (sharding_key.empty()) {
-        output.vec<tstring>()(2) = "";
-      } else if (sharding_key.size() == 1) {
-        output.vec<tstring>()(2) = sharding_key[0];
-      } else {
-        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
-        output.vec<tstring>()(2) = sharding_key[i];
-      }
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
@@ -814,10 +805,9 @@
   } else {
     // Return error in the invalid case.
     for (int i = 0; i < num_computations_; ++i) {
-      Tensor output(DT_STRING, TensorShape({3}));
+      Tensor output(DT_STRING, TensorShape({2}));
       output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
       output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
-      output.vec<tstring>()(2) = "<<NO SHARDing KEY AS COMPILATION FAILED>>";
       ctx->set_output(i + 1, output);
     }
     if (!use_mlir_) {
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index ce69d97..3522ace 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -72,9 +72,9 @@
   TF_RETURN_IF_ERROR(context->input("key", &key));
   profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
   if (!TensorShapeUtils::IsVector(key->shape()) ||
-      key->shape().dim_size(0) != 3) {
+      key->shape().dim_size(0) != 2) {
     return errors::InvalidArgument(
-        "Key argument to TPUExecute must be a 3-element vector");
+        "Key argument to TPUExecute must be a 2-element vector");
   }
 
   ResourceMgr* rmgr = GetTPUConfigResourceMgr();
diff --git a/tensorflow/core/tpu/ops/tpu_compile_op.cc b/tensorflow/core/tpu/ops/tpu_compile_op.cc
index 038005b..bfd29a7 100644
--- a/tensorflow/core/tpu/ops/tpu_compile_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_compile_op.cc
@@ -40,7 +40,7 @@
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(3));
+        c->set_output(i + 1, c->Vector(2));
       }
       return Status::OK();
     })
@@ -64,7 +64,7 @@
     .SetIsStateful()
     .Output("program: string")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Vector(3));
+      c->set_output(0, c->Vector(2));
       return Status::OK();
     })
     .SetIsStateful()
@@ -100,7 +100,7 @@
       c->set_output(0, c->Scalar());
       // Programs.
       for (int i = 0; i < num_computations; ++i) {
-        c->set_output(i + 1, c->Vector(3));
+        c->set_output(i + 1, c->Vector(2));
       }
       // May modify variables.
       for (int i = 0; i < num_computations; ++i) {
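
On the op-registration side, each program output's shape function now pins the
key to a length-2 vector. A minimal standalone sketch of the producer-side
pattern, using a hypothetical op name:

    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"

    // Hypothetical op illustrating the producer-side contract: the emitted
    // program key is a string vector of exactly two elements.
    REGISTER_OP("ExampleTpuCompile")
        .Output("program: string")
        .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
          c->set_output(0, c->Vector(2));  // [proto_key, rendezvous_key_base]
          return tensorflow::Status::OK();
        });
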
diff --git a/tensorflow/core/tpu/ops/tpu_execute_op.cc b/tensorflow/core/tpu/ops/tpu_execute_op.cc
index 2e43719..68ddc86 100644
--- a/tensorflow/core/tpu/ops/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_execute_op.cc
@@ -30,7 +30,7 @@
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }
@@ -50,7 +50,7 @@
       shape_inference::ShapeHandle key;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
       shape_inference::DimensionHandle unused;
-      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
+      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
       for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->UnknownShape());
       }
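
The execute-side shape functions enforce the matching consumer contract: the
trailing key input must be rank 1 with dim 0 equal to 2. A standalone sketch
mirroring the checks above, again with a hypothetical op name:

    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/shape_inference.h"

    // Hypothetical consumer op: rejects any key that is not a 2-element vector.
    REGISTER_OP("ExampleTpuExecute")
        .Input("key: string")
        .Output("output: variant")
        .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
          tensorflow::shape_inference::ShapeHandle key;
          TF_RETURN_IF_ERROR(
              c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
          tensorflow::shape_inference::DimensionHandle unused;
          TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
          c->set_output(0, c->UnknownShape());
          return tensorflow::Status::OK();
        });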