[xla:runtime] NFC: Move types and type_converter library from jitrt to xla

PiperOrigin-RevId: 466429383
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
index c8683f3..13ca2d6 100644
--- a/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
@@ -43,3 +43,19 @@
         "@llvm-project//mlir:Support",
     ],
 )
+
+cc_library(
+    name = "type_converter",
+    srcs = ["type_converter.cc"],
+    hdrs = ["type_converter.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/xla/mlir/ir/runtime:rt_ops",
+        "//tensorflow/compiler/xla/runtime:types",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AsyncDialect",
+        "@llvm-project//mlir:IR",
+        "@tf_runtime//:dtype",
+        "@tf_runtime//:support",
+    ],
+)
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc
new file mode 100644
index 0000000..d65208c
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc
@@ -0,0 +1,138 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
+
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Async/IR/AsyncTypes.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
+#include "tfrt/support/error_util.h"  // from @tf_runtime
+
+namespace xla {
+namespace runtime {
+
+using llvm::Expected;
+
+using tfrt::DType;
+using tfrt::MakeStringError;
+
+// Type conversion for the canonical MLIR types supported by the runtime.
+static std::unique_ptr<Type> ConvertCanonicalType(
+    mlir::Type type, const TypeConverter& convert) {
+  // KernelContextType -> KernelContextOperandType (both in xla::runtime).
+  if (auto ctx = type.dyn_cast<KernelContextType>())
+    return std::make_unique<KernelContextOperandType>();
+
+  // mlir::async::TokenType -> xla::runtime::AsyncTokenType
+  if (type.isa<mlir::async::TokenType>())
+    return std::make_unique<AsyncTokenType>();
+
+  // mlir::async::ValueType -> xla::runtime::AsyncValueType
+  if (auto value = type.dyn_cast<mlir::async::ValueType>()) {
+    if (auto value_type = convert.Convert(value.getValueType()))
+      return std::make_unique<AsyncValueType>(std::move(*value_type));
+  }
+
+  // mlir::RankedTensorType -> xla::runtime::RankedTensorType
+  if (auto tensor = type.dyn_cast<mlir::RankedTensorType>()) {
+    if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()))
+      return std::make_unique<RankedTensorType>(tensor.getShape(), *dtype);
+  }
+
+  // mlir::UnrankedTensorType -> xla::runtime::UnrankedTensorType
+  if (auto tensor = type.dyn_cast<mlir::UnrankedTensorType>()) {
+    if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()))
+      return std::make_unique<UnrankedTensorType>(*dtype);
+  }
+
+  // mlir::MemrefType -> xla::runtime::MemrefType
+  if (auto memref = type.dyn_cast<mlir::MemRefType>()) {
+    if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()))
+      return std::make_unique<MemrefType>(memref.getShape(), *dtype);
+  }
+
+  // mlir::UnrankedMemrefType -> xla::runtime::UnrankedMemrefType
+  if (auto memref = type.dyn_cast<mlir::UnrankedMemRefType>()) {
+    if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()))
+      return std::make_unique<UnrankedMemrefType>(*dtype);
+  }
+
+  // For non-canonical types the user must provide type conversion function.
+  return {};
+}
+
+/*static*/ Expected<DType> TypeConverter::ConvertElementType(mlir::Type type) {
+  if (type.isF32()) return DType::F32;
+  if (type.isF64()) return DType::F64;
+  if (type.isUnsignedInteger(8)) return DType::UI8;
+  if (type.isUnsignedInteger(16)) return DType::UI16;
+  if (type.isUnsignedInteger(32)) return DType::UI32;
+  if (type.isUnsignedInteger(64)) return DType::UI64;
+  if (type.isInteger(1)) return DType::I1;
+  if (type.isInteger(8)) return DType::I8;
+  if (type.isInteger(16)) return DType::I16;
+  if (type.isInteger(32)) return DType::I32;
+  if (type.isInteger(64)) return DType::I64;
+  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
+    auto element_type = complex_type.getElementType();
+    if (element_type.isF32()) return DType::Complex64;
+    if (element_type.isF64()) return DType::Complex128;
+  }
+
+  return MakeStringError("unsupported element type: ", type);
+}
+
+Expected<std::unique_ptr<Type>> TypeConverter::Convert(mlir::Type type) const {
+  if (auto converted = ConvertCanonicalType(type, *this)) return converted;
+
+  for (const ConversionFn& conversion : conversions_)
+    if (auto converted = conversion(type)) return converted;
+
+  return MakeStringError("can't convert type: ", type, " to the run time type");
+}
+
+Expected<FunctionType> TypeConverter::Convert(mlir::FunctionType type) const {
+  assert(type && "function type must be not null");
+
+  llvm::SmallVector<std::unique_ptr<Type>> operands;
+  llvm::SmallVector<std::unique_ptr<Type>> results;
+
+  operands.reserve(type.getNumInputs());
+  results.reserve(type.getNumResults());
+
+  auto error = [](llvm::StringRef kind, unsigned i, mlir::Type type) {
+    return MakeStringError("can't convert ", kind, " #", i, " type ", type,
+                           " to the run time type");
+  };
+
+  for (unsigned i = 0; i < type.getNumInputs(); ++i) {
+    Expected<std::unique_ptr<Type>> converted = Convert(type.getInput(i));
+    if (!converted) return error("input", i, type.getInput(i));
+    operands.push_back(std::move(*converted));
+  }
+
+  for (unsigned i = 0; i < type.getNumResults(); ++i) {
+    Expected<std::unique_ptr<Type>> converted = Convert(type.getResult(i));
+    if (!converted) return error("result", i, type.getResult(i));
+    results.push_back(std::move(*converted));
+  }
+
+  return FunctionType(std::move(operands), std::move(results));
+}
+
+}  // namespace runtime
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h
new file mode 100644
index 0000000..0fe492a
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h
@@ -0,0 +1,81 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
+#define XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
+
+#include <functional>
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Error.h"
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/runtime/types.h"
+
+namespace xla {
+namespace runtime {
+
+//===----------------------------------------------------------------------===//
+// Type conversion from the compile time types to the run-time types.
+//===----------------------------------------------------------------------===//
+
+// Type converter converts MLIR types known at compile time to the corresponding
+// types used at run time. It provides default conversions for the canonical
+// types (memrefs, tensors, etc...) and allows users to register custom
+// conversions for user-defined types.
+class TypeConverter {
+ public:
+  // Conversion function must return run time type corresponding to the compile
+  // time type if the conversion is successful, or `nullptr` if failed.
+  using ConversionFn = std::function<std::unique_ptr<Type>(mlir::Type)>;
+
+  // Adds a type conversion function with a type predicate.
+  //
+  // Example:
+  //
+  //   AddConversion([](mlir::TensorType) -> std::unique_ptr<Type> { ... });
+  //
+  // The conversion function will match only the tensor type, and return empty
+  // result for all other types, and the type converter will try the next
+  // conversion function (see `Convert` implementation).
+  template <typename Fn, typename FnTraits = llvm::function_traits<Fn>>
+  void AddConversion(Fn fn) {
+    using ArgType = typename FnTraits::template arg_t<0>;
+    conversions_.emplace_back(
+        [fn = std::forward<Fn>(fn)](mlir::Type type) -> std::unique_ptr<Type> {
+          if (auto arg = type.dyn_cast<ArgType>()) return fn(arg);
+          return {};
+        });
+  }
+
+  // Converts MLIR element type to the DType.
+  static llvm::Expected<tfrt::DType> ConvertElementType(mlir::Type type);
+
+  // Converts MLIR type to the runtime type. Returns error if conversion was not
+  // successful and the type has no corresponding run time type.
+  llvm::Expected<std::unique_ptr<Type>> Convert(mlir::Type type) const;
+
+  // Converts MLIR function type to the runtime function type. Returns error if
+  // function has unsupported operands or results types.
+  llvm::Expected<FunctionType> Convert(mlir::FunctionType type) const;
+
+ private:
+  llvm::SmallVector<ConversionFn> conversions_;
+};
+
+}  // namespace runtime
+}  // namespace xla
+
+#endif  // XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
diff --git a/tensorflow/compiler/xla/runtime/BUILD b/tensorflow/compiler/xla/runtime/BUILD
new file mode 100644
index 0000000..07ef16a
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/BUILD
@@ -0,0 +1,23 @@
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
+package(
+    default_visibility = [
+        "//tensorflow:internal",
+        "@tf_runtime//:friends",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "types",
+    srcs = ["types.cc"],
+    hdrs = ["types.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/xla/mlir/ir/runtime:rt_ops",
+        "@llvm-project//llvm:Support",
+        "@tf_runtime//:dtype",
+        "@tf_runtime//:support",
+    ],
+)
diff --git a/tensorflow/compiler/xla/runtime/types.cc b/tensorflow/compiler/xla/runtime/types.cc
new file mode 100644
index 0000000..cd023b3
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/types.cc
@@ -0,0 +1,111 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/runtime/types.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace xla {
+namespace runtime {
+
+using llvm::ArrayRef;
+using llvm::raw_ostream;
+
+//===----------------------------------------------------------------------===//
+// Pretty printing for canonical types.
+//===----------------------------------------------------------------------===//
+
+static raw_ostream& operator<<(raw_ostream& os, const ArrayRef<int64_t>& arr) {
+  auto str = llvm::map_range(arr, [](int64_t i) { return std::to_string(i); });
+  return os << llvm::join(str, "x") << (arr.empty() ? "" : "x");
+}
+
+raw_ostream& AsyncTokenType::print(raw_ostream& os) const {
+  return os << "!async.token";
+}
+
+raw_ostream& AsyncValueType::print(raw_ostream& os) const {
+  return os << "!async.value<" << value_type() << ">";
+}
+
+raw_ostream& RankedTensorType::print(raw_ostream& os) const {
+  return os << "tensor<" << sizes() << element_type() << ">";
+}
+
+raw_ostream& UnrankedTensorType::print(raw_ostream& os) const {
+  return os << "tensor<*x" << element_type() << ">";
+}
+
+raw_ostream& MemrefType::print(raw_ostream& os) const {
+  return os << "memref<" << sizes() << element_type() << ">";
+}
+
+raw_ostream& UnrankedMemrefType::print(raw_ostream& os) const {
+  return os << "memref<*x" << element_type() << ">";
+}
+
+raw_ostream& KernelContextOperandType::print(raw_ostream& os) const {
+  return os << "!rt.kernel_context";
+}
+
+//===----------------------------------------------------------------------===//
+// ABI definition for canonical types.
+//===----------------------------------------------------------------------===//
+
+using ArgumentAbi = Type::ArgumentAbi;
+using ResultAbi = Type::ResultAbi;
+
+// Async token returned as a pointer to the runtime async token.
+llvm::ErrorOr<ResultAbi> AsyncTokenType::AsResult() const {
+  return ResultAbi{sizeof(void*)};
+}
+
+// Async value returned as a pointer to the runtime async token.
+llvm::ErrorOr<ResultAbi> AsyncValueType::AsResult() const {
+  return ResultAbi{sizeof(void*)};
+}
+
+// Memref passed as an unrolled strided memref type.
+llvm::ErrorOr<ArgumentAbi> MemrefType::AsArgument() const {
+  return ArgumentAbi{3 + 2 * rank()};
+}
+
+// TODO(ezhulenev): We should query the size of the `StridedMemrefType`
+// directly, however it introduces dependency on the MLIR C runner utils.
+//
+// Memrefs are returned as StridedMemref<T, rank> type:
+//   basePtr, data, offset, sizes[rank], strides[rank]
+llvm::ErrorOr<ResultAbi> MemrefType::AsResult() const {
+  return ResultAbi{
+      sizeof(void*) * 2 +           // pointers
+      sizeof(int64_t) +             // offset
+      sizeof(int64_t) * 2 * rank()  // sizes and strides
+  };
+}
+
+// Kernel context passed as a single opaque pointer.
+llvm::ErrorOr<ArgumentAbi> KernelContextOperandType::AsArgument() const {
+  return ArgumentAbi{1};
+}
+
+}  // namespace runtime
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/runtime/types.h b/tensorflow/compiler/xla/runtime/types.h
new file mode 100644
index 0000000..73c1d59
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/types.h
@@ -0,0 +1,253 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_RUNTIME_TYPES_H_
+#define XLA_RUNTIME_TYPES_H_
+
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/ExtensibleRTTI.h"
+#include "tfrt/dtype/dtype.h"  // from @tf_runtime
+
+namespace xla {
+namespace runtime {
+
+//===----------------------------------------------------------------------===//
+// Canonical XLA runtime types for the executable arguments.
+//===----------------------------------------------------------------------===//
+
+// Types supported by the compiled function signature. We do rely on the LLVM
+// style RTTI (https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html) to avoid
+// dependency on the MLIR types at runtime, because we don't want to depend
+// on any of the compiler implementation details at runtime and we want to
+// support lightweight loading and execution of AOT compiled programs.
+//
+// We rely on the RTTI for the open class hierarchies, because we want to allow
+// users to define their own types for the arguments.
+//
+// If the type can be passed to the compiled function as an argument or returned
+// as a result, it must define its own ABI. The ABI is defined by the MLIR to
+// LLVM lowering pipeline and the runtime integration (see `runtime.h`).
+class Type : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  // Arguments to compiled functions passed as a set of pointers. For example
+  // memref descriptor passed in as a set of pointers to data, sizes and
+  // strides. See `Argument::Pack` implementation for details (in `argument.h`).
+  struct ArgumentAbi {
+    size_t num_ptrs;
+  };
+
+  // Compiled function returns results by writing into the pre-allocated storage
+  // of the given size with the requested alignment. Runtime pre-allocates
+  // memory required for all results in the call frame.
+  struct ResultAbi {
+    size_t size;
+
+    // TODO(ezhulenev): Add alignment to the result ABI. Alignment is an
+    // important part of the result ABI that we ignore today. It all doesn't
+    // crash only because all results happen to have a size that is multiple of
+    // 8 bytes, and because of that all of the results are properly aligned.
+    // Results memory layout in the call frame should take in account base
+    // pointer alignment and alignment requirements of all results.
+  };
+
+  // Returns an Abi if the type can be used as an argument.
+  virtual llvm::ErrorOr<ArgumentAbi> AsArgument() const {
+    return llvm::errc::not_supported;
+  }
+
+  // Returns an Abi if the type can be returned as a result.
+  virtual llvm::ErrorOr<ResultAbi> AsResult() const {
+    return llvm::errc::not_supported;
+  }
+
+  virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
+
+ protected:
+  Type() = default;
+};
+
+inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Type& type) {
+  return type.print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// Async Token type corresponding to the mlir::async::TokenType
+//===----------------------------------------------------------------------===//
+
+class AsyncTokenType : public llvm::RTTIExtends<AsyncTokenType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+};
+
+//===----------------------------------------------------------------------===//
+// Async Value type corresponding to the mlir::async::ValueType.
+//===----------------------------------------------------------------------===//
+
+class AsyncValueType : public llvm::RTTIExtends<AsyncValueType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  explicit AsyncValueType(std::unique_ptr<Type> value_type)
+      : value_type_(std::move(value_type)) {}
+
+  const Type& value_type() const { return *value_type_; }
+
+  llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+  std::unique_ptr<Type> value_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Ranked Tensor type corresponding to the mlir::RankedTensorType.
+//===----------------------------------------------------------------------===//
+
+class RankedTensorType : public llvm::RTTIExtends<RankedTensorType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+  static constexpr int64_t kDynamicSize = -1;
+
+  RankedTensorType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
+      : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
+
+  llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
+  unsigned rank() const { return sizes_.size(); }
+  tfrt::DType element_type() const { return element_type_; }
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+  llvm::SmallVector<int64_t> sizes_;
+  tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Unranked Tensor type corresponding to the mlir::UnrankedTensorType.
+//===----------------------------------------------------------------------===//
+
+class UnrankedTensorType : public llvm::RTTIExtends<UnrankedTensorType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  explicit UnrankedTensorType(tfrt::DType element_type)
+      : element_type_(element_type) {}
+
+  tfrt::DType element_type() const { return element_type_; }
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+  tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Ranked Memref type corresponding to the mlir::MemrefType.
+//===----------------------------------------------------------------------===//
+
+class MemrefType : public llvm::RTTIExtends<MemrefType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+  static constexpr int64_t kDynamicSize = -1;
+
+  MemrefType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
+      : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
+
+  llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
+  unsigned rank() const { return sizes_.size(); }
+  tfrt::DType element_type() const { return element_type_; }
+
+  llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
+  llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+  llvm::SmallVector<int64_t> sizes_;
+  tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Unranked Memref type corresponding to the mlir::UnrankedMemrefType.
+//===----------------------------------------------------------------------===//
+
+class UnrankedMemrefType : public llvm::RTTIExtends<UnrankedMemrefType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  explicit UnrankedMemrefType(tfrt::DType element_type)
+      : element_type_(element_type) {}
+
+  tfrt::DType element_type() const { return element_type_; }
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+  tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Corresponds to the RT dialect's KernelContextType.
+//===----------------------------------------------------------------------===//
+
+class KernelContextOperandType
+    : public llvm::RTTIExtends<KernelContextOperandType, Type> {
+ public:
+  static constexpr char ID = 0;  // NOLINT
+
+  llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
+
+  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+};
+
+//===----------------------------------------------------------------------===//
+// Compiled function signature type corresponding to the mlir::FunctionType.
+//===----------------------------------------------------------------------===//
+
+class FunctionType {
+ public:
+  const Type* operand(unsigned index) const { return operands_[index].get(); }
+  const Type* result(unsigned index) const { return results_[index].get(); }
+
+  unsigned num_operands() const { return operands_.size(); }
+  unsigned num_results() const { return results_.size(); }
+
+  FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,
+               llvm::SmallVector<std::unique_ptr<Type>> results)
+      : operands_(std::move(operands)), results_(std::move(results)) {}
+
+ private:
+  llvm::SmallVector<std::unique_ptr<Type>> operands_;
+  llvm::SmallVector<std::unique_ptr<Type>> results_;
+};
+
+}  // namespace runtime
+}  // namespace xla
+
+#endif  // XLA_RUNTIME_TYPES_H_