tf.signal: If frame_length is statically known, make frame_step statically known.
This enables tf.signal.mdct support on TPU, since otherwise XLA cannot statically determine the output shape due to frame_step not being known.
PiperOrigin-RevId: 289789563
Change-Id: I1b5c56d5cd8fe11d8b972069a9e760891e84cd77
diff --git a/tensorflow/python/ops/signal/spectral_ops.py b/tensorflow/python/ops/signal/spectral_ops.py
index 9963882..8fd3ca4 100644
--- a/tensorflow/python/ops/signal/spectral_ops.py
+++ b/tensorflow/python/ops/signal/spectral_ops.py
@@ -329,9 +329,13 @@
frame_length.shape.assert_has_rank(0)
# Assert that frame_length is divisible by 4.
frame_length_static = tensor_util.constant_value(frame_length)
- if frame_length_static is not None and frame_length_static % 4 != 0:
- raise ValueError('The frame length must be a multiple of 4.')
- frame_step = frame_length // 2
+ if frame_length_static is not None:
+ if frame_length_static % 4 != 0:
+ raise ValueError('The frame length must be a multiple of 4.')
+ frame_step = ops.convert_to_tensor(frame_length_static // 2,
+ dtype=frame_length.dtype)
+ else:
+ frame_step = frame_length // 2
framed_signals = shape_ops.frame(
signals, frame_length, frame_step, pad_end=pad_end)