Pfor: Omit If ops in Gather's converter if the conditions are known statically
A downstream set_shape in pfor (setting vectorized_shape[1:] to original_shape) conflicted with the output shape of the cond branch statically known not to be taken, which broke gradient computation through the resulting If op.
There are some alternatives: we could remove the set_shape even though it is correct, or tweak If's gradient to discard the extra shape information. Getting rid of the mismatch by simply not building the branch known not to be taken (the one carrying the conflicting shape information) seems simplest.
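
For illustration, a rough standalone sketch of the mechanism (not the converter code itself; param and the Python-int axis are stand-ins for the statically-known case):

    from tensorflow.python.framework import ops
    from tensorflow.python.framework import smart_cond
    from tensorflow.python.ops import array_ops

    param = array_ops.zeros([1, 2])
    axis = -1  # statically known, so `axis < 0` is a plain Python bool
    # control_flow_ops.cond would trace both lambdas into an If op whose
    # gradient also sees shape information from the untaken branch;
    # smart_cond detects the constant predicate and calls the taken
    # lambda directly, so no If op is built at all.
    axis = smart_cond.smart_cond(axis < 0,
                                 lambda: axis + array_ops.rank(param),
                                 lambda: ops.convert_to_tensor(axis))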
PiperOrigin-RevId: 391368381
Change-Id: I36affb98cf917a4236865d885b8fc9f8cd660e92
diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py
index cd3b69d..80baefc 100644
--- a/tensorflow/python/ops/parallel_for/array_test.py
+++ b/tensorflow/python/ops/parallel_for/array_test.py
@@ -82,6 +82,15 @@
self._test_loop_fn(loop_fn, 3)
+ @test_util.run_v2_only
+ def test_gather_pfor_grad(self):
+ x = array_ops.zeros([1, 2])
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ r = pfor_control_flow_ops.vectorized_map(
+ lambda t: array_ops.gather(x, t, axis=-1), math_ops.range(2))
+ self.assertAllClose([[1., 1.]], tape.gradient(r, x))
+
def test_shape(self):
x = random_ops.random_uniform([3, 2, 3])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 60a5185..4cbe11c 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -2490,15 +2490,15 @@
param, indices, validate_indices=validate_indices, axis=axis,
batch_dims=batch_dims)
if axis != 0:
- axis = control_flow_ops.cond(axis < 0,
+ axis = smart_cond.smart_cond(axis < 0,
lambda: axis + array_ops.rank(param),
- lambda: axis)
+ lambda: ops.convert_to_tensor(axis))
order = array_ops.concat(
[[axis],
math_ops.range(axis),
math_ops.range(axis + 1, array_ops.rank(output))],
axis=0)
- output = control_flow_ops.cond(
+ output = smart_cond.smart_cond(
math_ops.equal(axis, 0), lambda: output,
lambda: array_ops.transpose(output, order))
return wrap(output, True)
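
For a quick check of the effect (a hedged sketch using only public TF 2.x APIs; vectorized_gather is a made-up name mirroring the new test):

    import tensorflow as tf

    @tf.function
    def vectorized_gather(x):
      return tf.vectorized_map(lambda t: tf.gather(x, t, axis=-1),
                               tf.range(2))

    graph = vectorized_gather.get_concrete_function(
        tf.TensorSpec([1, 2], tf.float32)).graph
    # With axis statically known, both smart_conds above fold at trace
    # time, so the converter should contribute no If/StatelessIf ops.
    print([op.type for op in graph.get_operations()
           if op.type in ("If", "StatelessIf")])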