Do not rely on dropout's internal implementation in auto_mixed_precision_test
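
The motivation, briefly: nn.dropout is not a single op but expands into
several internal nodes whose names (such as 'dropout/mul', which the old
assertion hard-coded) are implementation details. A minimal sketch, assuming
TF 1.x graph mode, that lists these implementation-defined names; the exact
names printed may differ between TF versions:

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    with tf.Graph().as_default() as g:
      x = tf.placeholder(tf.float32, shape=[2, 8, 8, 1])
      y = tf.nn.dropout(x, rate=0.5)
      # Prints implementation-defined node names such as 'dropout/mul';
      # a test that hard-codes them breaks whenever dropout's internals
      # are refactored.
      for node in g.as_graph_def().node:
        print(node.name)
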
PiperOrigin-RevId: 280702474
Change-Id: Iadd456f9ab81ab627b6e67cef2dad0d17d55eb6e
diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py
index 4b43660..e47a62e 100644
--- a/tensorflow/python/grappler/auto_mixed_precision_test.py
+++ b/tensorflow/python/grappler/auto_mixed_precision_test.py
@@ -473,6 +473,7 @@
x = _input([2, 8, 8, 1])
y = _conv_bn(x)
y = nn.dropout(y, rate=0.5)
+ y = math_ops.add(y, 1, name='addition')
y = _conv_bn(y)
y = array_ops.identity(y)
optimizer = gradient_descent.GradientDescentOptimizer(
@@ -484,11 +485,13 @@
node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'Conv2D')
self._assert_output_fp16(node_map, 'FusedBatchNormV3')
- self._assert_output_fp16(node_map, 'dropout/mul')
+ # We do not assert dropout's dtype because we do not want to rely on the
+ # node names of dropout's internal implementation.
+ self._assert_output_fp16(node_map, 'addition')
self._assert_output_fp16(node_map, 'Conv2D_1')

output_val_ref, output_val, cost_graph = self._run(output)
- self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+ self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=2e-3)

@test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA')
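
For context on the new assertion, a sketch (not the test's actual helper;
_assert_output_fp16 lives in auto_mixed_precision_test.py and inspects the
optimized cost graph) of why an op named explicitly by the test gives a
stable anchor for dtype checks, which is what the inserted 'addition' node
provides. The loosened tolerance (2e-3 instead of 1e-3) presumably absorbs
the extra float16 rounding introduced by the added op.

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    with tf.Graph().as_default() as g:
      x = tf.placeholder(tf.float16, shape=[2, 2])
      # The test picks this name, so it is stable across TF releases.
      y = tf.math.add(x, tf.constant(1, dtype=tf.float16), name='addition')
      node_map = {n.name: n for n in g.as_graph_def().node}
      # The add node's 'T' attribute records its compute dtype.
      assert node_map['addition'].attr['T'].type == tf.float16.as_datatype_enum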