Added a "note" in tf.where documentation suggesting a workaround for issue #38349
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e9f32de..18cc7d3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -4487,6 +4487,21 @@
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100],
dtype=int32)>
+ Note that if the gradient of either branch of the tf.where generates
+ a NaN, then the gradient of the entire tf.where will be NaN.
+ A workaround is to use an inner tf.where to ensure the function has
+ no asymptote, and to avoid computing a value whose gradient is NaN by
+ replacing dangerous inputs with safe inputs.
+
+ Instead of this
+
+ >>> y = -1
+ >>> tf.where(y > 0, tf.sqrt(y), y)
+
+ Use this
+
+ >>> tf.where(y > 0, tf.sqrt(tf.where(y > 0, y, 1)), y)
+
Args:
condition: A `tf.Tensor` of type `bool`
x: If provided, a Tensor which is of the same type as `y`, and has a shape