Do not return an error on unknown input shape for AllToAll op
Use all information available to determine the output shape for the operation, but do not require fixed shape always. Instead, set the output shape to unknown if the input shape in unknown as dynamic is not an error, merely partial result. This also allows for refinement of shapes later post optimizations.
PiperOrigin-RevId: 366848572
Change-Id: I1001e5336e9bf22509dc49aa4ad75ec265c7da40
diff --git a/tensorflow/core/ops/tpu_cross_replica_ops.cc b/tensorflow/core/ops/tpu_cross_replica_ops.cc
index 1f10fe3..8e00c4a 100644
--- a/tensorflow/core/ops/tpu_cross_replica_ops.cc
+++ b/tensorflow/core/ops/tpu_cross_replica_ops.cc
@@ -32,16 +32,15 @@
.Attr("split_count: int")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
- int64 rank;
- if (c->RankKnown(input)) {
- rank = c->Rank(input);
- } else {
- return errors::InvalidArgument("input's rank is unknown.");
+ if (!c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
}
+
+ int64 rank = c->Rank(input);
int concat_dimension;
int split_dimension;
int split_count;
-
TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));