Allow trivial tensor list ops in auto_mixed_precision
- These ops (TensorListLength, TensorListElementShape, TensorListResize) have no
  data type attributes and therefore need no special handling.
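- For illustration only (not part of this patch), a minimal Python sketch of the
  kind of graph this change permits, assuming TF 2.x with the tf.raw_ops
  endpoints and the auto_mixed_precision grappler option (the rewrite itself
  only fires when a CUDA GPU is available); the trivial list ops sit alongside
  allowlisted MatMuls without blocking the conversion:

    import tensorflow as tf

    # Enable the grappler auto_mixed_precision pass (rewrites only on CUDA GPUs).
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    @tf.function
    def list_model(x):
        shape = tf.constant([8, 8], dtype=tf.int32)
        handle = tf.raw_ops.TensorListReserve(
            element_shape=shape, num_elements=2, element_dtype=tf.float32)
        # MatMul is allowlisted and may be converted to float16 by the pass.
        handle = tf.raw_ops.TensorListSetItem(
            input_handle=handle, index=0, item=tf.matmul(x, x))
        # Trivial list ops: no element dtype attribute, so no rewriting needed.
        handle = tf.raw_ops.TensorListResize(input_handle=handle, size=4)
        length = tf.raw_ops.TensorListLength(input_handle=handle)
        elem_shape = tf.raw_ops.TensorListElementShape(
            input_handle=handle, shape_type=tf.int32)
        item = tf.raw_ops.TensorListGetItem(
            input_handle=handle, index=0, element_shape=shape,
            element_dtype=tf.float32)
        return tf.matmul(item, item), length, elem_shape

    y, n, s = list_model(tf.random.normal([8, 8]))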
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
index a49215e..80f35cd 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
@@ -1203,9 +1203,9 @@
VLOG(2) << "Building node type map for graph";
TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
- // Note: If an op is added to this list, it should also be added to the
- // AddDataStructureOpsToMap call below (and to the clearlist if it involves
- // data flow).
+ // Note: If an op is added to this list that has a data type attribute, it
+ // should also be added to the AddDataStructureOpsToMap call below (and to the
+ // clearlist if it involves data flow).
// TODO(benbarsdell): Add support for TensorListPushBackBatch and
// TensorListConcatLists. They require special handling because they connect
// multiple list objects together. Currently if they appear in the graph then
@@ -1227,7 +1227,10 @@
"TensorListConcat",
"TensorListConcatV2",
"TensorListGetItem",
- "TensorListGather"};
+ "TensorListGather",
+ "TensorListLength",
+ "TensorListElementShape",
+ "TensorListResize"};
bool can_change_tensor_list_ops = true;
for (const NodeDef& node : graph_->node()) {
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
index 942abe5..bc51031 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
@@ -628,10 +628,12 @@
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
auto tl1w2 =
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
- Output tl1r1 =
- ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1w2.output_handle, idx2,
- shape, DT_FLOAT)
- .item;
+ // Ensure that TensorListResize doesn't cause any problems.
+ Output tl1rs =
+ ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
+ Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
+ shape, DT_FLOAT)
+ .item;
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
auto tl1w3 =