Update CompileGraphToXlaHlo to support resource XlaArguments and populate XlaCompilationResult.resource_updates instead of XlaCompilationResult.outputs for resource writes.

This is necessary to enable MLIR support when compiling tf.functions that use resources with XLA.
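
For illustration, a minimal caller-side sketch of the new behavior. This is a
hedged example: it mirrors the test added in compile_mlir_util_test.cc below,
and assumes `graph` and `flib_def` are built as in that test, inside a
function returning tensorflow::Status.

    // arg0 is a plain f32[2] parameter; arg1 is a resource variable that the
    // graph writes to.
    XlaCompiler::Argument arg0;
    arg0.kind = XlaCompiler::Argument::kParameter;
    arg0.shape = TensorShape({2});

    XlaCompiler::Argument arg1;
    arg1.kind = XlaCompiler::Argument::kResource;
    arg1.shape = TensorShape({2});
    arg1.type = DT_FLOAT;

    XlaCompiler::CompilationResult result;
    TF_RETURN_IF_ERROR(
        CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT",
                             /*use_tuple_args=*/false, flib_def,
                             GraphDebugInfo(),
                             /*shape_representation_fn=*/nullptr, &result));

    // The write to arg1 no longer appears in result.outputs; instead
    // result.resource_updates holds one entry with input_index == 1,
    // modified == true, type == DT_FLOAT, and shape [2].

Resource writes are recognized from the `tf.aliasing_output` argument
attribute (attached by the bridge's resource lowering, e.g.
promote-resources-to-args), which also surfaces as the compiled module's HLO
input_output_alias.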

PiperOrigin-RevId: 323692573
Change-Id: If3af722e5beb98dfd8260a4cb2df8db8a2a550ff
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index fe1f47d..8dcaf23 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1459,6 +1459,7 @@
     ":mlir_roundtrip_flags",
     ":tensorflow",
     ":tensorflow_dialect_registration",
     ":tensorflow_passes",
+    ":tensorflow_types",
     ":translate_utils",
     "@com_google_absl//absl/types:optional",
@@ -1520,6 +1521,9 @@
     srcs = ["utils/compile_mlir_util_test.cc"],
     deps = [
         ":compile_mlir_util",
+        "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:resource_variable_ops",
+        "//tensorflow/cc:scope",
         "//tensorflow/compiler/jit",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 9d6cc88..e273020 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -17,11 +17,14 @@
 
 #include "absl/types/optional.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
@@ -36,6 +39,7 @@
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@@ -52,6 +56,7 @@
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
@@ -79,9 +84,15 @@
   return Status::OK();
 }
 
+// Arguments to a computation can be either a tensor or a resource.
+struct TensorOrResourceShape {
+  TensorShape shape;
+  bool is_resource = false;
+};
+
 // Converts arg_shapes to xla::Shape's and stores them into xla_input_shapes.
 Status GetXlaInputShapes(
-    mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
+    mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
     bool use_tuple_args,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     std::vector<xla::Shape>* xla_input_shapes) {
@@ -103,7 +114,7 @@
     DataType dtype;
     TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
     TF_ASSIGN_OR_RETURN(xla_shape,
-                        shape_representation_fn(arg_shapes[i], dtype,
+                        shape_representation_fn(arg_shapes[i].shape, dtype,
                                                 /*use_fast_memory=*/false));
 
     // Rewrite layout with sharding, if sharding is set.
@@ -132,11 +143,13 @@
 }
 
 // Calculates computation output shape and builds OutputDescription for each
-// output based on static shapes in MLIR module
+// output based on static shapes in MLIR module. If an output is a resource
+// write, `resource_updates` is populated instead of `outputs` for that output.
 Status GetOutputInfo(
     mlir::ModuleOp module,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) {
+    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
+    std::vector<XlaResourceUpdate>* resource_updates) {
   auto shape_representation_fn_no_fast_memory =
       [shape_representation_fn](const TensorShape& shape, DataType dtype) {
         return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
@@ -147,17 +160,37 @@
 
   outputs->clear();
   outputs->reserve(func_type.getNumResults());
+  resource_updates->reserve(func_type.getNumResults());
 
   std::vector<xla::Shape> shapes;
   shapes.reserve(func_type.getNumResults());
 
-  for (mlir::Type type : func_type.getResults()) {
+  llvm::SmallDenseMap<unsigned, unsigned> resource_arg_to_write;
+  for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
+    if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
+            i, "tf.aliasing_output"))
+      resource_arg_to_write.insert({aliasing_output.getInt(), i});
+
+  for (auto type_and_idx : llvm::enumerate(func_type.getResults())) {
     TF_ASSIGN_OR_RETURN(
         xla::Shape shape,
-        xla::TypeToShape(type, shape_representation_fn_no_fast_memory));
-    auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
+        xla::TypeToShape(type_and_idx.value(),
+                         shape_representation_fn_no_fast_memory));
+    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
     shapes.push_back(shape);
 
+    auto it = resource_arg_to_write.find(type_and_idx.index());
+    if (it != resource_arg_to_write.end()) {
+      // Add resource write.
+      resource_updates->emplace_back();
+      XlaResourceUpdate& resource_update = resource_updates->back();
+      resource_update.input_index = it->getSecond();
+      resource_update.modified = true;
+      TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type));
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
+      continue;
+    }
+
     // Construct OutputDescription for result.
     outputs->emplace_back();
     XlaOutputDescription& out_desc = outputs->back();
@@ -180,14 +213,6 @@
   return Status::OK();
 }
 
-// Gets information about how computation updates Tensorflow resources.
-// TODO(ycao): Implement logic to compute resource updates when we need to
-// support graphs with resource updates in MLIR-based TF compiler bridge.
-void GetResourceUpdatesForMlir(
-    std::vector<XlaResourceUpdate>* resource_updates) {
-  resource_updates->clear();
-}
-
 // Creates a vector that maps from the parameters of the XLA computation to
 // their original argument positions.
 // MLIR-based TF-Compiler bridge doesn't have constant analysis yet, thus no
@@ -201,7 +226,7 @@
 }
 
 // Refine MLIR types based on new shape information.
-Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
+Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                     mlir::ModuleOp module) {
   auto producer_or = GetTfGraphProducerVersion(module);
   if (!producer_or.ok()) return producer_or.status();
@@ -212,15 +237,20 @@
   {
     // Convert arg_shapes to a mlir friendly format.
     size_t count = 0;
-    for (const TensorShape& shape : arg_shapes) {
-      count += shape.dims();
+    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+      if (tensor_resource_shape.is_resource) continue;
+      count += tensor_resource_shape.shape.dims();
     }
     shape_backing.resize(count);
     arg_shapes_copy.reserve(arg_shapes.size());
     size_t offset = 0;
-    for (const TensorShape& shape : arg_shapes) {
+    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+      if (tensor_resource_shape.is_resource) {
+        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
+        continue;
+      }
       size_t start = offset;
-      for (tensorflow::TensorShapeDim dim : shape) {
+      for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
         shape_backing[offset] = dim.size;
         ++offset;
       }
@@ -338,7 +368,7 @@
 }
 
 static Status CompileMlirToXlaHlo(
-    mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     XlaCompilationResult* compilation_result,
@@ -372,14 +402,10 @@
                                        shape_representation_fn,
                                        &compilation_result->xla_input_shapes));
 
-  // Compute all output descriptions.
-  TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn,
-                                   &compilation_result->xla_output_shape,
-                                   &compilation_result->outputs));
-
-  // Compute what resource variables need to be updated after XlaComputation's
-  // execution.
-  GetResourceUpdatesForMlir(&compilation_result->resource_updates);
+  // Compute all output descriptions and resource writes.
+  TF_RETURN_IF_ERROR(GetOutputInfo(
+      module_op, shape_representation_fn, &compilation_result->xla_output_shape,
+      &compilation_result->outputs, &compilation_result->resource_updates));
 
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
@@ -399,26 +425,51 @@
 
   TF_RETURN_IF_ERROR(
       ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
-  return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type,
-                             use_tuple_args, shape_representation_fn,
-                             compilation_result,
+  llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
+  tensor_or_resource_shapes.reserve(arg_shapes.size());
+  for (const auto& arg_shape : arg_shapes)
+    tensor_or_resource_shapes.push_back({arg_shape});
+  return CompileMlirToXlaHlo(mlir_module.get(), tensor_or_resource_shapes,
+                             device_type, use_tuple_args,
+                             shape_representation_fn, compilation_result,
                              std::move(custom_legalization_passes));
 }
 
 // Rewrites the given module with specified args. For each of the constant args,
 // it gets inlined in the "main" function and the corresponding argument is
-// removed from the signature.
+// removed from the signature. For resource args, their subtypes are populated.
 // Returns the original indices for the other arguments on success.
 static StatusOr<std::vector<int>> RewriteWithArgs(
     mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
   mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
   std::vector<int> params;
 
+  bool has_resource_args = false;
   auto builder = mlir::OpBuilder(main_fn.getBody());
   std::vector<int> args_to_erase;
   for (int idx = 0; idx < args.size(); idx++) {
     const XlaArgument& xla_arg = args[idx];
     mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
+    if (xla_arg.kind == XlaArgument::kResource) {
+      mlir::Type element_type;
+      TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
+      auto resource_shape = absl::get<TensorShape>(xla_arg.shape).dim_sizes();
+      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
+          resource_shape.begin(), resource_shape.end());
+      auto resource_subtype =
+          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
+      auto resource_type =
+          mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());
+
+      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
+      if (tensor_type.hasRank()) {
+        mlir_arg.setType(
+            mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
+      } else {
+        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
+      }
+      has_resource_args = true;
+    }
     if (xla_arg.kind != XlaArgument::kConstant) {
       params.push_back(idx);
       continue;
@@ -433,7 +484,19 @@
     args_to_erase.push_back(idx);
   }
 
+  if (has_resource_args) {
+    llvm::SmallVector<mlir::Type, 4> updated_argument_types;
+    updated_argument_types.reserve(main_fn.getNumArguments());
+    for (mlir::BlockArgument& arg : main_fn.getArguments())
+      updated_argument_types.push_back(arg.getType());
+
+    main_fn.setType(mlir::FunctionType::get(updated_argument_types,
+                                            main_fn.getType().getResults(),
+                                            main_fn.getContext()));
+  }
+
   for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
+
   return params;
 }
 
 
@@ -456,10 +519,13 @@
   mlir::ModuleOp module = module_or.ValueOrDie().get();
   TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
                       RewriteWithArgs(module, {args.data(), args.size()}));
-  llvm::SmallVector<TensorShape, 4> arg_shapes;
-  arg_shapes.reserve(args.size());
-  for (unsigned idx : remaining_params)
-    arg_shapes.push_back(absl::get<TensorShape>(args[idx].shape));
+  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
+  arg_shapes.reserve(remaining_params.size());
+  for (unsigned idx : remaining_params) {
+    const auto& arg = args[idx];
+    arg_shapes.push_back({absl::get<TensorShape>(arg.shape),
+                          /*is_resource=*/arg.kind == XlaArgument::kResource});
+  }
 
   mlir::PassManager pm(&context);
   mlir::TF::StandardPipelineOptions tf_options;
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 719a96f..5c64a65 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -73,6 +73,7 @@
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 // Same as the above but takes input as TensorFlow Graph.
+// TODO(lyandy): Allow populating of targets/control outputs.
 Status CompileGraphToXlaHlo(
     const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index dde2408..6ebf689 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -448,9 +451,6 @@
   FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
   Graph graph(OpRegistry::Global());
 
-  Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
-  test::FillValues<float>(&dummy_tensor, {-1.0});
-
   Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT);
   test::graph::Retval(&graph, 0, arg);
 
@@ -483,5 +483,60 @@
             status_or_hlo_module.ValueOrDie()->ToString());
 }
 
+// Tests a conversion from Graph to MLIR with resource arguments.
+TEST(CompileGraphToXlaHlo, Resources) {
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+  Graph graph(OpRegistry::Global());
+
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0);
+  auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
+  auto assign =
+      ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val);
+  TF_ASSERT_OK(scope.ToGraph(&graph));
+
+  XlaCompiler::CompilationResult result;
+  XlaCompiler::Argument arg0;
+  arg0.kind = XlaCompiler::Argument::kParameter;
+  arg0.shape = TensorShape({2});
+  XlaCompiler::Argument arg1;
+  arg1.kind = XlaCompiler::Argument::kResource;
+  arg1.shape = TensorShape({2});
+  arg1.type = DT_FLOAT;
+
+  TF_ASSERT_OK(
+      CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT",
+                           /*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
+                           /*shape_representation_fn=*/nullptr, &result));
+
+  EXPECT_EQ(result.outputs.size(), 0);
+  ASSERT_EQ(result.resource_updates.size(), 1);
+  const auto& resource_update = result.resource_updates[0];
+  EXPECT_EQ(resource_update.input_index, 1);
+  EXPECT_EQ(resource_update.modified, true);
+  EXPECT_EQ(resource_update.shape, TensorShape({2}));
+  EXPECT_EQ(resource_update.type, DT_FLOAT);
+
+  const xla::HloModuleConfig module_config(
+      result.computation->GetProgramShape().ValueOrDie());
+  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+      result.computation->proto(), module_config);
+  ASSERT_TRUE(status_or_hlo_module.ok());
+
+  constexpr char expected_hlo_module_string[] =
+      R"(HloModule main.4, input_output_alias={ {0}: 1 }
+
+ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
+  %Arg_1.2 = f32[2]{0} parameter(1)
+  %Arg_0.1 = f32[2]{0} parameter(0)
+  ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1)
+}
+
+)";
+
+  EXPECT_EQ(expected_hlo_module_string,
+            status_or_hlo_module.ValueOrDie()->ToString());
+}
+
 }  // namespace
 }  // namespace tensorflow