[TF:TRT] Implement cast from fp16 to fp32 with IIdentityLayer.

This is the first CL to implement the request in b/150285802.

Add Cast op test to convert_nodes_test.

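For context, the core TensorRT mechanism, as a minimal sketch against the
raw builder API (illustrative only, not part of this change; assumes
`network` is an nvinfer1::INetworkDefinition under construction, with all
error handling omitted):

    // The network input carries fp16 data.
    nvinfer1::ITensor* x = network->addInput(
        "x", nvinfer1::DataType::kHALF, nvinfer1::Dims3(1, 2, 3));

    // There is no dedicated cast layer here; an identity layer whose
    // precision is pinned to fp32 reformats the tensor on the way through.
    nvinfer1::IIdentityLayer* identity = network->addIdentity(*x);
    identity->setPrecision(nvinfer1::DataType::kFLOAT);

    // The layer output now reports fp32; the converter double-checks this
    // because setPrecision is a request rather than a guarantee.
    nvinfer1::ITensor* y = identity->getOutput(0);
    network->markOutput(*y);
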
PiperOrigin-RevId: 312093049
Change-Id: I77215cf6da104f51acc93de1b03e9a179db54f0a
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index a43b16e..e791ff9 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -29,6 +29,7 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
@@ -795,6 +796,21 @@
   }
 }
 
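+// Returns the TF data type of this object: mapped back from the TensorRT
+// type for tensors, read from the wrapped TF tensor for weights.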
+Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const {
+  if (is_tensor()) {
+    nvinfer1::DataType trt_type = tensor()->getType();
+    return TrtTypeToTfType(trt_type, tf_type);
+  }
+
+  if (is_weights()) {
+    *tf_type = weights().GetTensor().dtype();
+    return Status::OK();
+  }
+  return errors::Internal("The object is probably not initialized");
+}
+
 string TRT_TensorOrWeights::DebugString() const {
   string output = "TRT_TensorOrWeights(type=";
   if (is_tensor()) {
@@ -1900,27 +1914,51 @@
   return Status::OK();
 }
 
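+// Reads the TF DataType stored in attribute type_attr_name of node_def.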
-Status AllowDataTypes(const OpConverterParams& params,
-                      const std::set<DataType>& allowed_dtypes,
-                      const char* dtype_attr_name = "T") {
-  const auto& node_def = params.node_def;
+Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
+                        const char* type_attr_name) {
   TFAttrs attrs(node_def);
-  if (!attrs.count(dtype_attr_name)) {
-    return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
+  if (!attrs.count(type_attr_name)) {
+    return errors::InvalidArgument("Attribute with name ", type_attr_name,
                                    " not found.");
   }
-  const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
-  if (!allowed_dtypes.count(op_dtype)) {
-    // Build string list of allowed types.
-    std::ostringstream ss;
-    for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
-      if (it != allowed_dtypes.begin()) ss << ", ";
-      ss << DataTypeString(*it);
-    }
-    return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
+  *tf_type = attrs.get<DataType>(type_attr_name);
+  return Status::OK();
+}
+
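+// Returns the TF type of the converter input at position pos.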
+Status GetInputTfType(const OpConverterParams& params, DataType* tf_type,
+                      int pos) {
+  const std::vector<TRT_TensorOrWeights>& inputs = params.inputs;
+  if (pos < 0 || inputs.size() <= static_cast<size_t>(pos)) {
+    return errors::Internal("Invalid input position ", pos);
+  }
+
+  return inputs[pos].GetTfType(tf_type);
+}
+
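+// Most TF ops record their data type in an attribute named "T".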
+constexpr const char kOutputTypeAttrName[] = "T";
+
+Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) {
+  return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName);
+}
+
+Status AllowDataTypes(const OpConverterParams& params,
+                      const std::set<DataType>& allowed_types,
+                      const char* type_attr_name = kOutputTypeAttrName) {
+  const auto& node_def = params.node_def;
+  DataType tf_type;
+  TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name));
+  if (!allowed_types.count(tf_type)) {
+    string allowed_types_string = absl::StrJoin(
+        allowed_types, ", ", [](string* out, const DataType& type) {
+          absl::StrAppendFormat(out, "%s", DataTypeString(type));
+        });
+    return errors::Unimplemented("Data type ", DataTypeString(tf_type),
                                  " is not supported for ", node_def.op(),
-                                 ", must be one of [", ss.str(), "], at ",
-                                 node_def.name());
+                                 ", must be one of [", allowed_types_string,
+                                 "], at ", node_def.name());
   }
   return Status::OK();
 }
@@ -4598,6 +4633,48 @@
   return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true);
 }
 
+// Supports casting fp16 to fp32 via IIdentityLayer.
+Status ConvertCast(OpConverterParams* params) {
+  const NodeDef& node_def = params->node_def;
+  TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
+  auto unsupport_cast_error = [&]() {
+    return errors::Unimplemented("Cast op: ", node_def.op(),
+                                 " not supported at: ", node_def.name());
+  };
+
+  DataType input_type;
+  TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0));
+  if (input_type != DataType::DT_HALF) {
+    return unsupport_cast_error();
+  }
+
+  DataType output_type;
+  // Cast records its result type in the "DstT" attribute; it has no "T"
+  // attribute, so the generic GetOutputTfType() lookup would fail here.
+  TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &output_type, "DstT"));
+  if (output_type != DataType::DT_FLOAT) {
+    return unsupport_cast_error();
+  }
+
+  if (params->validation_only) return Status::OK();
+
+  nvinfer1::ITensor* input = params->inputs.at(0).tensor();
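+  // An identity layer whose precision is pinned to fp32 converts the fp16
+  // input as it passes through; this is what implements the cast.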
+  nvinfer1::IIdentityLayer* layer =
+      params->converter->network()->addIdentity(*input);
+  TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+  layer->setPrecision(nvinfer1::DataType::kFLOAT);
+
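+  // Defensive check: confirm the identity layer's output type is now fp32.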
+  if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) {
+    return errors::Internal("IIdentityLayer doesn't work as expected");
+  }
+
+  params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+  return Status::OK();
+}
+
 Status ConvertConcat(OpConverterParams* params) {
   const auto& inputs = params->inputs;
   const auto& node_def = params->node_def;
@@ -5675,6 +5746,7 @@
   (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS;
 #endif
   (*registration)["AddN"] = ConvertAddN;
+  (*registration)["Cast"] = ConvertCast;
   (*registration)["ConcatV2"] = ConvertConcat;
   (*registration)["Const"] = ConvertConst;
   (*registration)["Conv2D"] = ConvertConv2D;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 2092aec..2fe8eec 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -294,6 +294,9 @@
 
   nvinfer1::Dims GetTrtDims() const;
 
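+  // Stores the TF data type of this tensor or weights into *tf_type.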
+  Status GetTfType(DataType* tf_type) const;
+
   int batch_size() const { return batch_size_; }
 
   string DebugString() const;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 964370a..1efc31f 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -5147,6 +5147,16 @@
   return T(s.WithOpName("my_unary"), input).operation.node()->def();
 }
 
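+// Builds a DT_HALF -> DT_FLOAT Cast node whose op is named "my_unary" so it
+// can be driven by the unary op test harness below.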
+NodeDef CreateCastOp() {
+  Scope s = Scope::NewRootScope();
+  auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
+  return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT)
+      .operation.node()
+      ->def();
+}
+
 TEST_P(ParameterizedOpConverterTest, ConvertUnary) {
   const auto& spec = GetParam();
   const TrtTestMode trt_mode = std::get<0>(spec);
@@ -5174,6 +5182,8 @@
   ADD_OP("Asinh", ops::Asinh, std::asinh);
   ADD_OP("Atan", ops::Atan, std::atan);
   ADD_OP("Atanh", ops::Atanh, std::atanh);
+  op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; });
   ADD_OP("Ceil", ops::Ceil, std::ceil);
   ADD_OP("Cos", ops::Cos, std::cos);
   ADD_OP("Cosh", ops::Cosh, std::cosh);
@@ -5212,7 +5221,13 @@
     }
     NodeDef node_def = op_map[op_name].first();
 
-    AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode);
+    // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
+    // now. Need to find a better way to express input and output types.
+    DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype;
+    DataType output_tf_dtype = tf_dtype;
+
+    AddTestTensor("input", p.input_dims, TfDataTypeToTrt(input_tf_dtype),
+                  trt_mode);
     RunValidationAndConversion(node_def, Status::OK(), "my_unary",
                                p.expected_output_dims);
 
@@ -5220,8 +5235,8 @@
     std::vector<float> output;
     std::transform(input_values.begin(), input_values.end(),
                    std::back_inserter(output), op_map[op_name].second);
-    InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values,
-                           ArrayFloatNear(output, 0.0001, true));
+    InstantiateBuildAndRun(input_tf_dtype, output_tf_dtype, "my_unary", this, p,
+                           input_values, ArrayFloatNear(output, 0.0001, true));
   }
 }