[xla:runtime] Move runtime transforms from jitrt under xla/mlir/transforms

PiperOrigin-RevId: 465327647
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
new file mode 100644
index 0000000..73bd73f
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
@@ -0,0 +1,45 @@
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+
+package(
+    default_visibility = [
+        "//tensorflow:internal",
+        "@tf_runtime//:friends",
+    ],
+    licenses = ["notice"],
+)
+
+gentbl_cc_library(
+    name = "rt_transforms_passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=RuntimeTransforms",
+            ],
+            "rt_gen_passes.h.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "rt_passes.td",
+    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
+)
+
+cc_library(
+    name = "rt_transforms",
+    srcs = ["rt_convert_to_entrypoint.cc"],
+    hdrs = ["rt_passes.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        ":rt_transforms_passes_inc_gen",
+        "//tensorflow/compiler/xla/mlir/ir/runtime:rt_ops",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/rt_convert_to_entrypoint.cc b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_convert_to_entrypoint.cc
new file mode 100644
index 0000000..0ebdf7a
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_convert_to_entrypoint.cc
@@ -0,0 +1,207 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"  // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
+#include "tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.h"
+
+namespace xla {
+namespace runtime {
+
+using namespace mlir;  // NOLINT
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/xla/mlir/transforms/runtime/rt_gen_passes.h.inc"
+
+class ConvertToEntrypointPass
+    : public ConvertToEntrypointBase<ConvertToEntrypointPass> {
+  void runOnOperation() override;
+};
+
+static void ConvertCustomCallOperations(func::FuncOp func, Value exec_ctx) {
+  MLIRContext* ctx = func->getContext();
+
+  SymbolTable sym_table(func->getParentOfType<ModuleOp>());
+
+  struct CustomCall {
+    func::CallOp call;
+    func::FuncOp callee;
+    llvm::StringRef target;
+    bool direct;
+  };
+
+  // Collect function calls that have to be converted to custom calls.
+  llvm::SmallVector<CustomCall> custom_calls;
+  func.walk([&](func::CallOp op) {
+    auto callee = dyn_cast<func::FuncOp>(sym_table.lookup(op.getCallee()));
+    if (!callee) return;
+
+    // Check if the call is an indirect custom call ...
+    StringAttr target = callee->getAttrOfType<StringAttr>("rt.custom_call");
+    if (target) custom_calls.push_back({op, callee, target.strref(), false});
+
+    // ... or a direct custom call.
+    target = callee->getAttrOfType<StringAttr>("rt.direct_custom_call");
+    if (target) custom_calls.push_back({op, callee, target.strref(), true});
+  });
+
+  // After converting to custom call we need to clean up all declarations.
+  llvm::DenseSet<func::FuncOp> erase_declarations;
+
+  // Rewrite function calls to `rt.custom_call` operations.
+  for (CustomCall custom_call : custom_calls) {
+    ImplicitLocOpBuilder b(custom_call.call.getLoc(), custom_call.call);
+
+    // Custom call intrinsic always returns the status flag.
+    llvm::SmallVector<Type> results = {StatusType::get(ctx)};
+    results.append(custom_call.call->getResultTypes().begin(),
+                   custom_call.call->getResultTypes().end());
+
+    // Rewrite function call with a custom call, and check the return status.
+    auto call = b.create<CustomCallOp>(results, exec_ctx, custom_call.target,
+                                       custom_call.direct,
+                                       custom_call.call.getOperands());
+
+    // Copy optional attributes from the custom call function declaration.
+    llvm::ArrayRef<llvm::StringRef> callee_attrs =
+        custom_call.callee.getAttributeNames();
+    for (auto& attr : custom_call.callee->getAttrs()) {
+      if (isa_and_nonnull<RuntimeDialect>(attr.getNameDialect())) continue;
+      if (llvm::find(callee_attrs, attr.getName()) == callee_attrs.end())
+        call->setAttr(attr.getName(), attr.getValue());
+    }
+
+    // Copy optional attributes from the call operation to the custom call.
+    llvm::ArrayRef<llvm::StringRef> orig_attrs =
+        custom_call.call.getAttributeNames();
+    for (auto& attr : custom_call.call->getAttrs()) {
+      if (llvm::find(orig_attrs, attr.getName()) == orig_attrs.end())
+        call->setAttr(attr.getName(), attr.getValue());
+    }
+
+    b.create<cf::AssertOp>(
+        b.create<IsOkOp>(TypeRange(b.getI1Type()), call.status()),
+        b.getStringAttr("custom call '" + custom_call.target + "' failed"));
+
+    // Forward users of the original results to custom call results.
+    auto rets = llvm::zip(custom_call.call.getResults(),
+                          llvm::drop_begin(call.getResults()));
+    llvm::for_each(rets, [](auto ret) {
+      std::get<0>(ret).replaceAllUsesWith(std::get<1>(ret));
+    });
+
+    // Keep track of custom call declaration to erase.
+    erase_declarations.insert(custom_call.callee);
+
+    // Erase the original function call operation.
+    custom_call.call.erase();
+  }
+
+  // Erase all converted custom calls declarations.
+  for (auto func : erase_declarations) sym_table.erase(func);
+}
+
+static void ConvertReturnOperations(func::FuncOp func, Value exec_ctx) {
+  // Convert all returns to the Runtime API calls.
+  func.walk([&](func::ReturnOp ret) {
+    ImplicitLocOpBuilder b(ret.getLoc(), ret);
+
+    // Return all outputs via the `rt.set_output` operation.
+    for (auto& pair : llvm::enumerate(ret.getOperands())) {
+      b.create<SetOutputOp>(exec_ctx, pair.index(), pair.value());
+    }
+
+    // Replace original return with an empty one.
+    b.create<func::ReturnOp>();
+    ret.erase();
+  });
+
+  // Update function type to the function with empty results.
+  auto type = FunctionType::get(func.getContext(), func.getArgumentTypes(), {});
+  func.setType(type);
+}
+
+static void ConvertAssertOperations(func::FuncOp func, Value exec_ctx) {
+  // Collect all assert operations in the function body.
+  llvm::SmallVector<cf::AssertOp> asserts;
+  func.walk([&](cf::AssertOp op) { asserts.push_back(op); });
+
+  // Rewrite all asserts to the Runtime API calls.
+  for (cf::AssertOp assert : asserts) {
+    ImplicitLocOpBuilder b(assert.getLoc(), assert);
+
+    // Split the block at the assert operation.
+    Block* block = assert->getBlock();
+    Block* ok = block->splitBlock(assert);
+
+    // Set up block for returning error.
+    Block* err = func.addBlock();
+    b.setInsertionPointToStart(err);
+    b.create<SetErrorOp>(exec_ctx, assert.getMsg());
+    b.create<func::ReturnOp>();
+
+    // Branch into the error block if assertion failed.
+    b.setInsertionPointToEnd(block);
+    b.create<cf::CondBranchOp>(assert.getArg(), ok, err);
+
+    // Erase the original assert operation.
+    assert.erase();
+  }
+}
+
+static Value PrependExecutionContextArgument(func::FuncOp func) {
+  Type new_type = KernelContextType::get(func.getContext());
+  DictionaryAttr attr = DictionaryAttr::get(func.getContext());
+  func.insertArguments({0}, {new_type}, {attr}, {func.getLoc()});
+  return func.getArgument(0);
+}
+
+static void ConvertToEntrypoint(func::FuncOp func) {
+  assert(func->hasAttr(kEntrypointAttrName));
+
+  Value exec_ctx = PrependExecutionContextArgument(func);
+  ConvertCustomCallOperations(func, exec_ctx);
+  ConvertReturnOperations(func, exec_ctx);
+  ConvertAssertOperations(func, exec_ctx);
+
+  // After conversion !rt.execution_context is a marker of an entrypoint.
+  func->removeAttr(kEntrypointAttrName);
+}
+
+void ConvertToEntrypointPass::runOnOperation() {
+  llvm::SmallVector<func::FuncOp> entry_points;
+
+  // Collect entrypoint functions.
+  getOperation().walk([&](func::FuncOp op) {
+    if (op->hasAttr(kEntrypointAttrName)) entry_points.push_back(op);
+  });
+
+  llvm::for_each(entry_points, ConvertToEntrypoint);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateConvertToEntrypoint() {
+  return std::make_unique<ConvertToEntrypointPass>();
+}
+
+}  // namespace runtime
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.h b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.h
new file mode 100644
index 0000000..0c5186c
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.h
@@ -0,0 +1,39 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_MLIR_RUNTIME_RT_PASSES_H_
+#define XLA_MLIR_RUNTIME_RT_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
+
+namespace xla {
+namespace runtime {
+
+static constexpr char const* kEntrypointAttrName = "rt.entrypoint";
+
+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateConvertToEntrypoint();
+
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/xla/mlir/transforms/runtime/rt_gen_passes.h.inc"
+
+}  // namespace runtime
+}  // namespace xla
+
+#endif  // XLA_MLIR_RUNTIME_RT_PASSES_H_
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.td b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.td
new file mode 100644
index 0000000..709e95f
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.td
@@ -0,0 +1,85 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RT_PASSES
+#define RT_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def ConvertToEntrypoint : Pass<"rt-convert-to-entrypoint", "mlir::ModuleOp"> {
+
+  let summary = "Converts function(s) to Xla runtime entrypoint(s)";
+
+  let description = [{
+    Converts function with a `xla.entrypoint` unit attribute to an Xla
+    entrypoint, i.e.:
+     - first argument is an `!rt.execution_context`
+     - all results returned via the `rt.set_result` operation
+     - failed asserts set the results error via the `rt.set_error` operation
+     - function calls marked with `rt.custom_call` attribute (on the callee)
+       converted to the `rt.custom_call` operations (or `rt.direct_custom_call`
+       attribute for direct custom calls)
+
+    See the `ir/runtime/rt_ops.td` to find how Xla executable returns results
+    and errors usin the runtime APIs.
+
+    When converting function call to the custom call operation, custom call
+    attributes will be a union of custom call function declaration attributes,
+    and the call operation attributes. Function call attributes will override
+    any attributes defined by the custom call function declaration.
+
+    Example:
+
+      ```mlir
+      func @custom_call() -> memref<?xf32>
+        attributes { rt.custom_call = "custom_call", attr = <value> }
+
+      func @compute(...) -> memref<?xf32> attributes { xla.entrypoint } {
+        %0 = ... : i1
+        assert %0, "Oops"
+        %1 = call @custom_call() { attr = <new_value> }: () -> memref<?xf32>
+        return %1
+      }
+      ```
+
+    converted to:
+
+      ```mlir
+      func @compute(%ctx: !rt.execution_context, ...) {
+        %0 = ... : i1
+        cond_br %0, ^ok0, ^err0
+      ^ok0:
+        %status, %1 = rt.custom_call %ctx, "custom_call"()
+                      { attr = <new value> } : () -> memref<?xf32>
+        %success = rt.is_ok %status : !rt.status
+        cond_br %success, ^ok1, ^err1
+      ^ok1:
+        rt.set_output %ctx, 0, %1 : memref<xf32>
+        return
+      ^err0:
+        rt.set_error %ctx, "Oops"
+        return;
+      ^err1:
+        rt.set_error %ctx, "Custom call failed"
+        return;
+      }
+      ```
+  }];
+
+  let constructor = "xla::runtime::CreateConvertToEntrypoint()";
+  let dependentDialects = ["xla::runtime::RuntimeDialect"];
+}
+
+#endif  // RT_PASSES