Replace len(shape) with shape.rank since length of an unknown rank
TensorShape is undefined.
PiperOrigin-RevId: 297646567
Change-Id: I3d830ab61e9872e40f195d166554fc3e74662f5d
diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py
index 85ea00d..ed96cfd 100644
--- a/tensorflow/python/keras/engine/compile_utils.py
+++ b/tensorflow/python/keras/engine/compile_utils.py
@@ -562,14 +562,10 @@
def match_dtype_and_rank(y_t, y_p, sw):
"""Match dtype and rank of predictions."""
- # Rank.
- y_t_rank = len(y_t.shape)
- y_p_rank = len(y_p.shape)
- if y_t_rank == 1 and y_p_rank == 2:
+ if y_t.shape.rank == 1 and y_p.shape.rank == 2:
y_t = array_ops.expand_dims_v2(y_t, axis=-1)
if sw is not None:
- sw_rank = len(sw.shape)
- if sw_rank == 1 and y_p_rank == 2:
+ if sw.shape.rank == 1 and y_p.shape.rank == 2:
sw = array_ops.expand_dims_v2(sw, axis=-1)
# Dtype.
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 546e201..e4cc415 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -1177,8 +1177,8 @@
self.built = True
def call(self, inputs):
- rank = len(inputs.shape)
- if rank > 2:
+ rank = inputs.shape.rank
+ if rank is not None and rank > 2:
# Broadcasting is required for the inputs.
outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
# Reshape the output back to the original ndim of the input.