Add support for input list in MLIR tracing API

PiperOrigin-RevId: 327081439
Change-Id: I354e1daae7de9c49ea9280446a3cee96ab544c25
diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc
index cb1f997..80b1f15 100644
--- a/tensorflow/c/eager/gradients_test.cc
+++ b/tensorflow/c/eager/gradients_test.cc
@@ -507,20 +507,18 @@
   result_tensor = nullptr;
 }
 
-// TODO(b/160888630): Enable this test with mlir after AddInputList is
-// supported. It is needed for IdentityN.
 // TODO(b/164171226): Enable this test with tfrt after AddInputList is
 // supported. It is needed for IdentityN.
 #ifdef PLATFORM_GOOGLE
 INSTANTIATE_TEST_SUITE_P(
     UnifiedCAPI, CppGradients,
-    ::testing::Combine(::testing::Values("graphdef"),
+    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                        /*tfrt*/ ::testing::Values(false),
                        /*executing_eagerly*/ ::testing::Values(true, false)));
 #else
 INSTANTIATE_TEST_SUITE_P(
     UnifiedCAPI, CppGradients,
-    ::testing::Combine(::testing::Values("graphdef"),
+    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                        /*tfrt*/ ::testing::Values(false),
                        /*executing_eagerly*/ ::testing::Values(true, false)));
 #endif
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index 343cbce..c62d62a 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -299,7 +299,7 @@
         "op_type must be specified before specifying attrs.");
   Type mlir_type;
   Builder builder(context_);
-  TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type));
+  TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
   attrs_[attr_name] = TypeAttr::get(mlir_type);
   return Status::OK();
 }
@@ -367,12 +367,12 @@
                                  "' required for output '", output_arg.name(),
                                  "' isn't a type attribute");
         for (int i = 0; i < repeats; ++i)
-          state_->types.push_back(type_attr.getType());
+          state_->types.push_back(UnrankedTensorType::get(type_attr.getType()));
       } else if (output_arg.type() != tensorflow::DT_INVALID) {
         for (int i = 0; i < repeats; ++i) {
           Type type;
           TF_RETURN_IF_ERROR(
-              ConvertDataTypeToTensor(output_arg.type(), builder, &type));
+              ConvertDataType(output_arg.type(), builder, &type));
           state_->types.push_back(type);
         }
       } else {
@@ -390,7 +390,7 @@
         return InvalidArgument("Attribute '", output_arg.type_attr(),
                                "' required for output '", output_arg.name(),
                                "' isn't a type attribute");
-      state_->types.push_back(type_attr.getValue());
+      state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
     } else if (!output_arg.type_list_attr().empty()) {
       // This is pointing to an attribute which is an array of types.
       Attribute attr = attrs_[output_arg.type_list_attr()];
@@ -410,13 +410,12 @@
                                  output_arg.type_list_attr(),
                                  "' required for output '", output_arg.name(),
                                  "' has a non-Type element");
-        state_->types.push_back(type_attr.getValue());
+        state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
       }
     } else if (output_arg.type() != tensorflow::DT_INVALID) {
       Type type;
       Builder builder(context_);
-      TF_RETURN_IF_ERROR(
-          ConvertDataTypeToTensor(output_arg.type(), builder, &type));
+      TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
       state_->types.push_back(type);
     } else {
       return InvalidArgument("No type fields in ",
@@ -446,12 +445,6 @@
   return Status::OK();
 }
 
-Status MlirAbstractOp::AddInputList(
-    absl::Span<AbstractTensorHandle* const> inputs) {
-  return tensorflow::errors::Unimplemented(
-      "AddInputList has not been implemented yet.");
-}
-
 Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
                                      size_t length) {
   return Unimplemented("SetAttrString has not been implemented yet.");
@@ -589,13 +582,65 @@
       expected_type = output_type;
     }
   } else {
-    expected_type = operands_.back().getType();
+    expected_type = cast<MlirTensor>(input)->getElementType();
   }
   if (!arg_def.type_attr().empty())
     attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
 
   return Status::OK();
 }
+
+Status MlirAbstractOp::AddInputList(
+    absl::Span<AbstractTensorHandle* const> inputs) {
+  if (current_ods_input_ >= op_def_->input_arg_size())
+    return InvalidArgument(
+        absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
+                     op_def_->input_arg_size(), " allowed input_args"));
+
+  for (AbstractTensorHandle* input : inputs) {
+    auto* operand = dyn_cast<MlirTensor>(input);
+    if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
+    operands_.push_back(operand->getValue());
+  }
+
+  // Get the next ArgDef and use it to infer the derived attributes associated
+  // to this input.
+  const tensorflow::OpDef::ArgDef& arg_def =
+      op_def_->input_arg(current_ods_input_++);
+  if (!arg_def.number_attr().empty()) {
+    Builder builder(context_);
+    attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
+    // TODO(aminim): handle ref variable.
+    if (arg_def.type() != tensorflow::DT_INVALID) {
+      // TODO(aminim): check type wrt input
+      Type arg_def_type;
+      TF_RETURN_IF_ERROR(
+          ConvertDataType(arg_def.type(), builder, &arg_def_type));
+      // Ensure each of the type in the list matches the op def type.
+      // TODO(aminim): can we improve the error message with the actual types?
+      for (AbstractTensorHandle* input : inputs)
+        if (arg_def_type != cast<MlirTensor>(input)->getElementType())
+          return InvalidArgument(
+              "Invalid input list: type mismatch the op def expectation");
+    } else if (!inputs.empty()) {
+      if (arg_def.type_attr().empty())
+        return FailedPrecondition(
+            "Invalid opdef type constraint: either type or type_attr required");
+
+      attrs_[arg_def.type_attr()] =
+          TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
+    }
+  } else if (!arg_def.type_list_attr().empty()) {
+    // TODO(aminim): handle ref variable.
+    SmallVector<Attribute, 8> types;
+    types.reserve(inputs.size());
+    for (AbstractTensorHandle* input : inputs)
+      types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
+    attrs_[arg_def.type_list_attr()] = ArrayAttr::get(types, GetContext());
+  }
+  return Status::OK();
+}
+
 Status MlirFunctionContext::Finalize(OutputList* outputs,
                                      AbstractFunction** f) {
   Block& body = func_.getBody().front();