Only trigger resharding when necessary: remove the per-core sharding key from the TPU compilation/execution key vector, shrinking it from 3 elements to 2 (proto key, rendezvous key base) and dropping the now-unused sharding_key plumbing from the compilation cache interface.
PiperOrigin-RevId: 332060600
Change-Id: If36314756713071679c20cd6645434c5db53500b
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 6ecde66..78be6cc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -537,10 +537,9 @@
// Build a constant default key to specify that the unformatting should
// transform the variables to the original format.
builder.setInsertionPointAfter(while_op);
- tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
+ tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {2});
default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
- default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
auto default_state_key = builder.create<TF::ConstOp>(
while_op.getLoc(),
tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
index f303588..fad8e22 100644
--- a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
+++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc
@@ -504,10 +504,9 @@
"TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
AddNodeAttr("dtype", DT_STRING, &default_sharding);
- Tensor t(DT_STRING, {3});
+ Tensor t(DT_STRING, {2});
t.vec<tstring>()(0) = kDefaultShardingValue;
t.vec<tstring>()(1) = kDefaultShardingValue;
- t.vec<tstring>()(2) = kDefaultShardingValue;
t.AsProtoTensorContent(
(*default_sharding.mutable_attr())["value"].mutable_tensor());
diff --git a/tensorflow/core/tpu/kernels/compiled_subgraph.h b/tensorflow/core/tpu/kernels/compiled_subgraph.h
index 091d6e7..a97c652 100644
--- a/tensorflow/core/tpu/kernels/compiled_subgraph.h
+++ b/tensorflow/core/tpu/kernels/compiled_subgraph.h
@@ -124,9 +124,6 @@
// Compilation cache proto key to identify the cache entry.
std::vector<std::string> proto_key;
- // Fingerprints of sharding programs if there is any.
- std::vector<std::string> sharding_key;
-
// The number of 'external' client-held references to the entry.
int external_references = 0;
diff --git a/tensorflow/core/tpu/kernels/host_compute_ops.cc b/tensorflow/core/tpu/kernels/host_compute_ops.cc
index 5295c1c..77a7d6f 100644
--- a/tensorflow/core/tpu/kernels/host_compute_ops.cc
+++ b/tensorflow/core/tpu/kernels/host_compute_ops.cc
@@ -58,9 +58,9 @@
OP_REQUIRES_ASYNC(
ctx,
TensorShapeUtils::IsVector(input.shape()) &&
- input.shape().dim_size(0) == 3,
+ input.shape().dim_size(0) == 2,
errors::InvalidArgument("Input shape ", input.shape().DebugString(),
- " is not a vector of length 3."),
+ " is not a vector of length 2."),
done);
const string rendezvous_key_base = input.vec<tstring>()(1);
OP_REQUIRES_ASYNC(
@@ -164,10 +164,10 @@
const Tensor& key_input = ctx->input(ctx->num_inputs() - 1);
OP_REQUIRES(ctx,
TensorShapeUtils::IsVector(key_input.shape()) &&
- key_input.shape().dim_size(0) == 3,
+ key_input.shape().dim_size(0) == 2,
errors::InvalidArgument("Key input shape ",
key_input.shape().DebugString(),
- " is not a vector of length 3."));
+ " is not a vector of length 2."));
const string rendezvous_key_base = key_input.vec<tstring>()(1);
OP_REQUIRES(
ctx, ctx->rendezvous() != nullptr,
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
index 0d0a2ae..1928303 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc
@@ -362,15 +362,14 @@
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
- std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
+ std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
std::vector<CompiledSubgraph*> removed_entries;
auto status = CompileIfKeyAbsentHelper(
subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
- sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
- compile_function);
+ may_modify_variables, &removed_entries, hlo_metadatas, compile_function);
for (auto entry : removed_entries) {
UnloadAndDestroy(entry);
}
@@ -400,7 +399,7 @@
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
- std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
+ std::vector<std::string>* proto_key,
std::vector<bool>* may_modify_variables,
std::vector<CompiledSubgraph*>* removed_entries,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
@@ -498,7 +497,6 @@
*uid = entry->uid;
// Let the caller know the keys for each of the cached protos.
*proto_key = entry->proto_key;
- *sharding_key = entry->sharding_key;
*may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
*hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
index 12f116d..e1e7cf2 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -109,7 +109,6 @@
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
- std::vector<std::string>* sharding_key,
std::vector<bool>* may_modify_variables,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
@@ -198,7 +197,6 @@
const SessionMetadata* session_metadata,
CompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<std::string>* proto_key,
- std::vector<std::string>* sharding_key,
std::vector<bool>* may_modify_variables,
std::vector<CompiledSubgraph*>* removed_entries,
absl::Span<const xla::HloProto* const>* hlo_metadatas,
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
index 34ea317..ce18e84 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc
@@ -657,11 +657,10 @@
int64 uid;
std::vector<std::string> proto_key;
- std::vector<std::string> sharding_key;
std::vector<bool> may_modify_variables;
absl::Span<const xla::HloProto* const> hlo_metadatas;
Status status = cache->CompileIfKeyAbsent(
- key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
+ key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
&may_modify_variables, &hlo_metadatas,
[&](TpuProgramGroupInterface* tpu_program_group) {
VLOG(1) << "Cloud TPU: Compiling TPU program";
@@ -779,21 +778,13 @@
if (status.ok()) {
for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
- Tensor output(DT_STRING, TensorShape({3}));
+ Tensor output(DT_STRING, TensorShape({2}));
if (proto_key.size() == 1) {
output.vec<tstring>()(0) = proto_key[0];
} else {
output.vec<tstring>()(0) = proto_key[i];
}
output.vec<tstring>()(1) = rendezvous_key_base;
- if (sharding_key.empty()) {
- output.vec<tstring>()(2) = "";
- } else if (sharding_key.size() == 1) {
- output.vec<tstring>()(2) = sharding_key[0];
- } else {
- TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
- output.vec<tstring>()(2) = sharding_key[i];
- }
ctx->set_output(i + 1, output);
}
if (!use_mlir_) {
@@ -814,10 +805,9 @@
} else {
// Return error in the invalid case.
for (int i = 0; i < num_computations_; ++i) {
- Tensor output(DT_STRING, TensorShape({3}));
+ Tensor output(DT_STRING, TensorShape({2}));
output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
- output.vec<tstring>()(2) = "<<NO SHARDing KEY AS COMPILATION FAILED>>";
ctx->set_output(i + 1, output);
}
if (!use_mlir_) {
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index ce69d97..3522ace 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -72,9 +72,9 @@
TF_RETURN_IF_ERROR(context->input("key", &key));
profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
if (!TensorShapeUtils::IsVector(key->shape()) ||
- key->shape().dim_size(0) != 3) {
+ key->shape().dim_size(0) != 2) {
return errors::InvalidArgument(
- "Key argument to TPUExecute must be a 3-element vector");
+ "Key argument to TPUExecute must be a 2-element vector");
}
ResourceMgr* rmgr = GetTPUConfigResourceMgr();
diff --git a/tensorflow/core/tpu/ops/tpu_compile_op.cc b/tensorflow/core/tpu/ops/tpu_compile_op.cc
index 038005b..bfd29a7 100644
--- a/tensorflow/core/tpu/ops/tpu_compile_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_compile_op.cc
@@ -40,7 +40,7 @@
c->set_output(0, c->Scalar());
// Programs.
for (int i = 0; i < num_computations; ++i) {
- c->set_output(i + 1, c->Vector(3));
+ c->set_output(i + 1, c->Vector(2));
}
return Status::OK();
})
@@ -64,7 +64,7 @@
.SetIsStateful()
.Output("program: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->Vector(3));
+ c->set_output(0, c->Vector(2));
return Status::OK();
})
.SetIsStateful()
@@ -100,7 +100,7 @@
c->set_output(0, c->Scalar());
// Programs.
for (int i = 0; i < num_computations; ++i) {
- c->set_output(i + 1, c->Vector(3));
+ c->set_output(i + 1, c->Vector(2));
}
// May modify variables.
for (int i = 0; i < num_computations; ++i) {
diff --git a/tensorflow/core/tpu/ops/tpu_execute_op.cc b/tensorflow/core/tpu/ops/tpu_execute_op.cc
index 2e43719..68ddc86 100644
--- a/tensorflow/core/tpu/ops/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/ops/tpu_execute_op.cc
@@ -30,7 +30,7 @@
shape_inference::ShapeHandle key;
TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
shape_inference::DimensionHandle unused;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
}
@@ -50,7 +50,7 @@
shape_inference::ShapeHandle key;
TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &key));
shape_inference::DimensionHandle unused;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 3, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), 2, &unused));
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
}