Get XlaCompile attr errors out of all of our gradient stack traces
Python helpfully appends the error from the "try" portion to new errors from calling grad_fn in "except", but it's 100% irrelevant in this case.
PiperOrigin-RevId: 335689319
Change-Id: Ibd12257e585d7d28eb720e217aeff7a4b5915664
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 9784650..4d4df0f 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -333,7 +333,7 @@
"_XlaSeparateCompiledGradients")
xla_scope = op.get_attr("_XlaScope").decode()
except ValueError:
- return grad_fn() # Exit early
+ xla_compile = False
if not xla_compile:
return grad_fn() # Exit early