Add new Tensor List tests for auto_mixed_precision
- Updates TensorListPushBackBatchAndConcatLists to check that all
nodes are now converted to fp16.
- Adds TensorListThroughFunction test to check that Tensor Lists that
pass through sub-graphs are safely handled.
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
index 0f48ae9..2d1d440 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
@@ -28,6 +28,7 @@
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
@@ -905,11 +906,96 @@
VLOG(1) << output.DebugString();
GraphView output_view(&output);
- // TODO(benbarsdell): Add checks for data type conversion here once support
- // for TensorListPushBackBatch and TensorListConcatLists is added in the
- // auto_mixed_precision pass.
+ EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
+ const char* type_key = "element_dtype";
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(tensors.size(), tensors_expected.size());
+ EXPECT_EQ(tensors.size(), item.fetch.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
+ }
+}
+
+TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
+ // This test passes a tensor list handle through a function with its own
+ // Tensor List ops inside to test that the types are not changed to a
+ // conflicting state.
+ // A separate Tensor List cluster is added to test that it is still changed to
+ // DT_HALF.
+ FunctionDefLibrary function_lib;
+ const Tensor kShape = test::AsTensor<int32>({32, 32});
+ FunctionDef func1 = FunctionDefHelper::Define(
+ "Func1", {"ihandle: variant", "x: float"},
+ {"ohandle: variant", "y: float"}, {},
+ {
+ {{"tl1w1_handle"},
+ "TensorListPushBack",
+ {"ihandle", "x"},
+ {{"element_dtype", DT_FLOAT}}},
+ {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
+ {{"tl1r1_handle", "tl1r1_data"},
+ "TensorListPopBack",
+ {"tl1w1_handle", "shape"},
+ {{"element_dtype", DT_FLOAT}}},
+ {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
+ {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
+ });
+ function_lib.add_function()->Swap(&func1);
+
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
+ tensorflow::Input shape = {32, 32};
+ Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+ Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
+ Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1);
+ auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
+ auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
+ auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
+ auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
+ auto builder =
+ tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
+ tensorflow::Node* func1_op;
+ TF_CHECK_OK(
+ builder.Input(_tl1w1_handle).Input(_gry1).Finalize(s.graph(), &func1_op));
+ Output func1_handle(func1_op, 0);
+ Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
+ shape, DT_FLOAT)
+ .tensor;
+ auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
+ auto tl2w1 = ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, gry1);
+ Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
+ tl2w1.output_handle, shape, DT_FLOAT)
+ .tensor;
+ Output wht2 = ops::MatMul(s.WithOpName("wht2"), tl1r1, tl2r1);
+ Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
+
+ GrapplerItem item;
+ item.fetch = {"fetch1"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ AutoMixedPrecision optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+ VLOG(1) << output.DebugString();
+
+ GraphView output_view(&output);
+ const char* type_key = "element_dtype";
+ EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
+ EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());