[xla:runtime] NFC: Move types and type_converter library from jitrt to xla
PiperOrigin-RevId: 466429383
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
index c8683f3..13ca2d6 100644
--- a/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/BUILD
@@ -43,3 +43,19 @@
"@llvm-project//mlir:Support",
],
)
+
+cc_library(
+ name = "type_converter",
+ srcs = ["type_converter.cc"],
+ hdrs = ["type_converter.h"],
+ compatible_with = get_compatible_with_portable(),
+ deps = [
+ "//tensorflow/compiler/xla/mlir/ir/runtime:rt_ops",
+ "//tensorflow/compiler/xla/runtime:types",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AsyncDialect",
+ "@llvm-project//mlir:IR",
+ "@tf_runtime//:dtype",
+ "@tf_runtime//:support",
+ ],
+)
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc
new file mode 100644
index 0000000..d65208c
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.cc
@@ -0,0 +1,138 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h"
+
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/Async/IR/AsyncTypes.h" // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
+#include "tfrt/support/error_util.h" // from @tf_runtime
+
+namespace xla {
+namespace runtime {
+
+using llvm::Expected;
+
+using tfrt::DType;
+using tfrt::MakeStringError;
+
+// Type conversion for the canonical MLIR types supported by the runtime.
+static std::unique_ptr<Type> ConvertCanonicalType(
+ mlir::Type type, const TypeConverter& convert) {
+ // KernelContextType -> KernelContextOperandType (both in xla::runtime).
+ if (auto ctx = type.dyn_cast<KernelContextType>())
+ return std::make_unique<KernelContextOperandType>();
+
+ // mlir::async::TokenType -> xla::runtime::AsyncTokenType
+ if (type.isa<mlir::async::TokenType>())
+ return std::make_unique<AsyncTokenType>();
+
+ // mlir::async::ValueType -> xla::runtime::AsyncValueType
+ if (auto value = type.dyn_cast<mlir::async::ValueType>()) {
+ if (auto value_type = convert.Convert(value.getValueType()))
+ return std::make_unique<AsyncValueType>(std::move(*value_type));
+ }
+
+ // mlir::RankedTensorType -> xla::runtime::RankedTensorType
+ if (auto tensor = type.dyn_cast<mlir::RankedTensorType>()) {
+ if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()))
+ return std::make_unique<RankedTensorType>(tensor.getShape(), *dtype);
+ }
+
+ // mlir::UnrankedTensorType -> xla::runtime::UnrankedTensorType
+ if (auto tensor = type.dyn_cast<mlir::UnrankedTensorType>()) {
+ if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()))
+ return std::make_unique<UnrankedTensorType>(*dtype);
+ }
+
+ // mlir::MemrefType -> xla::runtime::MemrefType
+ if (auto memref = type.dyn_cast<mlir::MemRefType>()) {
+ if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()))
+ return std::make_unique<MemrefType>(memref.getShape(), *dtype);
+ }
+
+ // mlir::UnrankedMemrefType -> xla::runtime::UnrankedMemrefType
+ if (auto memref = type.dyn_cast<mlir::UnrankedMemRefType>()) {
+ if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()))
+ return std::make_unique<UnrankedMemrefType>(*dtype);
+ }
+
+ // For non-canonical types the user must provide type conversion function.
+ return {};
+}
+
+/*static*/ Expected<DType> TypeConverter::ConvertElementType(mlir::Type type) {
+ if (type.isF32()) return DType::F32;
+ if (type.isF64()) return DType::F64;
+ if (type.isUnsignedInteger(8)) return DType::UI8;
+ if (type.isUnsignedInteger(16)) return DType::UI16;
+ if (type.isUnsignedInteger(32)) return DType::UI32;
+ if (type.isUnsignedInteger(64)) return DType::UI64;
+ if (type.isInteger(1)) return DType::I1;
+ if (type.isInteger(8)) return DType::I8;
+ if (type.isInteger(16)) return DType::I16;
+ if (type.isInteger(32)) return DType::I32;
+ if (type.isInteger(64)) return DType::I64;
+ if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
+ auto element_type = complex_type.getElementType();
+ if (element_type.isF32()) return DType::Complex64;
+ if (element_type.isF64()) return DType::Complex128;
+ }
+
+ return MakeStringError("unsupported element type: ", type);
+}
+
+Expected<std::unique_ptr<Type>> TypeConverter::Convert(mlir::Type type) const {
+ if (auto converted = ConvertCanonicalType(type, *this)) return converted;
+
+ for (const ConversionFn& conversion : conversions_)
+ if (auto converted = conversion(type)) return converted;
+
+ return MakeStringError("can't convert type: ", type, " to the run time type");
+}
+
+Expected<FunctionType> TypeConverter::Convert(mlir::FunctionType type) const {
+ assert(type && "function type must be not null");
+
+ llvm::SmallVector<std::unique_ptr<Type>> operands;
+ llvm::SmallVector<std::unique_ptr<Type>> results;
+
+ operands.reserve(type.getNumInputs());
+ results.reserve(type.getNumResults());
+
+ auto error = [](llvm::StringRef kind, unsigned i, mlir::Type type) {
+ return MakeStringError("can't convert ", kind, " #", i, " type ", type,
+ " to the run time type");
+ };
+
+ for (unsigned i = 0; i < type.getNumInputs(); ++i) {
+ Expected<std::unique_ptr<Type>> converted = Convert(type.getInput(i));
+ if (!converted) return error("input", i, type.getInput(i));
+ operands.push_back(std::move(*converted));
+ }
+
+ for (unsigned i = 0; i < type.getNumResults(); ++i) {
+ Expected<std::unique_ptr<Type>> converted = Convert(type.getResult(i));
+ if (!converted) return error("result", i, type.getResult(i));
+ results.push_back(std::move(*converted));
+ }
+
+ return FunctionType(std::move(operands), std::move(results));
+}
+
+} // namespace runtime
+} // namespace xla
diff --git a/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h
new file mode 100644
index 0000000..0fe492a
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/transforms/runtime/type_converter.h
@@ -0,0 +1,81 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
+#define XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
+
+#include <functional>
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Error.h"
+#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
+#include "tensorflow/compiler/xla/runtime/types.h"
+
+namespace xla {
+namespace runtime {
+
+//===----------------------------------------------------------------------===//
+// Type conversion from the compile time types to the run-time types.
+//===----------------------------------------------------------------------===//
+
+// Type converter converts MLIR types known at compile time to the corresponding
+// types used at run time. It provides default conversions for the canonical
+// types (memrefs, tensors, etc...) and allows users to register custom
+// conversions for user-defined types.
+class TypeConverter {
+ public:
+ // Conversion function must return run time type corresponding to the compile
+ // time type if the conversion is successful, or `nullptr` if failed.
+ using ConversionFn = std::function<std::unique_ptr<Type>(mlir::Type)>;
+
+ // Adds a type conversion function with a type predicate.
+ //
+ // Example:
+ //
+ // AddConversion([](mlir::TensorType) -> std::unique_ptr<Type> { ... });
+ //
+ // The conversion function will match only the tensor type, and return empty
+ // result for all other types, and the type converter will try the next
+ // conversion function (see `Convert` implementation).
+ template <typename Fn, typename FnTraits = llvm::function_traits<Fn>>
+ void AddConversion(Fn fn) {
+ using ArgType = typename FnTraits::template arg_t<0>;
+ conversions_.emplace_back(
+ [fn = std::forward<Fn>(fn)](mlir::Type type) -> std::unique_ptr<Type> {
+ if (auto arg = type.dyn_cast<ArgType>()) return fn(arg);
+ return {};
+ });
+ }
+
+ // Converts MLIR element type to the DType.
+ static llvm::Expected<tfrt::DType> ConvertElementType(mlir::Type type);
+
+ // Converts MLIR type to the runtime type. Returns error if conversion was not
+ // successful and the type has no corresponding run time type.
+ llvm::Expected<std::unique_ptr<Type>> Convert(mlir::Type type) const;
+
+ // Converts MLIR function type to the runtime function type. Returns error if
+ // function has unsupported operands or results types.
+ llvm::Expected<FunctionType> Convert(mlir::FunctionType type) const;
+
+ private:
+ llvm::SmallVector<ConversionFn> conversions_;
+};
+
+} // namespace runtime
+} // namespace xla
+
+#endif // XLA_MLIR_RUNTIME_TYPE_CONVERTER_H_
diff --git a/tensorflow/compiler/xla/runtime/BUILD b/tensorflow/compiler/xla/runtime/BUILD
new file mode 100644
index 0000000..07ef16a
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/BUILD
@@ -0,0 +1,23 @@
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ "@tf_runtime//:friends",
+ ],
+ licenses = ["notice"],
+)
+
+cc_library(
+ name = "types",
+ srcs = ["types.cc"],
+ hdrs = ["types.h"],
+ compatible_with = get_compatible_with_portable(),
+ deps = [
+ "//tensorflow/compiler/xla/mlir/ir/runtime:rt_ops",
+ "@llvm-project//llvm:Support",
+ "@tf_runtime//:dtype",
+ "@tf_runtime//:support",
+ ],
+)
diff --git a/tensorflow/compiler/xla/runtime/types.cc b/tensorflow/compiler/xla/runtime/types.cc
new file mode 100644
index 0000000..cd023b3
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/types.cc
@@ -0,0 +1,111 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/runtime/types.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+
+namespace xla {
+namespace runtime {
+
+using llvm::ArrayRef;
+using llvm::raw_ostream;
+
+//===----------------------------------------------------------------------===//
+// Pretty printing for canonical types.
+//===----------------------------------------------------------------------===//
+
+static raw_ostream& operator<<(raw_ostream& os, const ArrayRef<int64_t>& arr) {
+ auto str = llvm::map_range(arr, [](int64_t i) { return std::to_string(i); });
+ return os << llvm::join(str, "x") << (arr.empty() ? "" : "x");
+}
+
+raw_ostream& AsyncTokenType::print(raw_ostream& os) const {
+ return os << "!async.token";
+}
+
+raw_ostream& AsyncValueType::print(raw_ostream& os) const {
+ return os << "!async.value<" << value_type() << ">";
+}
+
+raw_ostream& RankedTensorType::print(raw_ostream& os) const {
+ return os << "tensor<" << sizes() << element_type() << ">";
+}
+
+raw_ostream& UnrankedTensorType::print(raw_ostream& os) const {
+ return os << "tensor<*x" << element_type() << ">";
+}
+
+raw_ostream& MemrefType::print(raw_ostream& os) const {
+ return os << "memref<" << sizes() << element_type() << ">";
+}
+
+raw_ostream& UnrankedMemrefType::print(raw_ostream& os) const {
+ return os << "memref<*x" << element_type() << ">";
+}
+
+raw_ostream& KernelContextOperandType::print(raw_ostream& os) const {
+ return os << "!rt.kernel_context";
+}
+
+//===----------------------------------------------------------------------===//
+// ABI definition for canonical types.
+//===----------------------------------------------------------------------===//
+
+using ArgumentAbi = Type::ArgumentAbi;
+using ResultAbi = Type::ResultAbi;
+
+// Async token returned as a pointer to the runtime async token.
+llvm::ErrorOr<ResultAbi> AsyncTokenType::AsResult() const {
+ return ResultAbi{sizeof(void*)};
+}
+
+// Async value returned as a pointer to the runtime async token.
+llvm::ErrorOr<ResultAbi> AsyncValueType::AsResult() const {
+ return ResultAbi{sizeof(void*)};
+}
+
+// Memref passed as an unrolled strided memref type.
+llvm::ErrorOr<ArgumentAbi> MemrefType::AsArgument() const {
+ return ArgumentAbi{3 + 2 * rank()};
+}
+
+// TODO(ezhulenev): We should query the size of the `StridedMemrefType`
+// directly, however it introduces dependency on the MLIR C runner utils.
+//
+// Memrefs are returned as StridedMemref<T, rank> type:
+// basePtr, data, offset, sizes[rank], strides[rank]
+llvm::ErrorOr<ResultAbi> MemrefType::AsResult() const {
+ return ResultAbi{
+ sizeof(void*) * 2 + // pointers
+ sizeof(int64_t) + // offset
+ sizeof(int64_t) * 2 * rank() // sizes and strides
+ };
+}
+
+// Kernel context passed as a single opaque pointer.
+llvm::ErrorOr<ArgumentAbi> KernelContextOperandType::AsArgument() const {
+ return ArgumentAbi{1};
+}
+
+} // namespace runtime
+} // namespace xla
diff --git a/tensorflow/compiler/xla/runtime/types.h b/tensorflow/compiler/xla/runtime/types.h
new file mode 100644
index 0000000..73c1d59
--- /dev/null
+++ b/tensorflow/compiler/xla/runtime/types.h
@@ -0,0 +1,253 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_RUNTIME_TYPES_H_
+#define XLA_RUNTIME_TYPES_H_
+
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/ExtensibleRTTI.h"
+#include "tfrt/dtype/dtype.h" // from @tf_runtime
+
+namespace xla {
+namespace runtime {
+
+//===----------------------------------------------------------------------===//
+// Canonical XLA runtime types for the executable arguments.
+//===----------------------------------------------------------------------===//
+
+// Types supported by the compiled function signature. We do rely on the LLVM
+// style RTTI (https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html) to avoid
+// dependency on the MLIR types at runtime, because we don't want to depend
+// on any of the compiler implementation details at runtime and we want to
+// support lightweight loading and execution of AOT compiled programs.
+//
+// We rely on the RTTI for the open class hierarchies, because we want to allow
+// users to define their own types for the arguments.
+//
+// If the type can be passed to the compiled function as an argument or returned
+// as a result, it must define its own ABI. The ABI is defined by the MLIR to
+// LLVM lowering pipeline and the runtime integration (see `runtime.h`).
+class Type : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ // Arguments to compiled functions passed as a set of pointers. For example
+ // memref descriptor passed in as a set of pointers to data, sizes and
+ // strides. See `Argument::Pack` implementation for details (in `argument.h`).
+ struct ArgumentAbi {
+ size_t num_ptrs;
+ };
+
+ // Compiled function returns results by writing into the pre-allocated storage
+ // of the given size with the requested alignment. Runtime pre-allocates
+ // memory required for all results in the call frame.
+ struct ResultAbi {
+ size_t size;
+
+ // TODO(ezhulenev): Add alignment to the result ABI. Alignment is an
+ // important part of the result ABI that we ignore today. It all doesn't
+ // crash only because all results happen to have a size that is multiple of
+ // 8 bytes, and because of that all of the results are properly aligned.
+ // Results memory layout in the call frame should take in account base
+ // pointer alignment and alignment requirements of all results.
+ };
+
+ // Returns an Abi if the type can be used as an argument.
+ virtual llvm::ErrorOr<ArgumentAbi> AsArgument() const {
+ return llvm::errc::not_supported;
+ }
+
+ // Returns an Abi if the type can be returned as a result.
+ virtual llvm::ErrorOr<ResultAbi> AsResult() const {
+ return llvm::errc::not_supported;
+ }
+
+ virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
+
+ protected:
+ Type() = default;
+};
+
+inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Type& type) {
+ return type.print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// Async Token type corresponding to the mlir::async::TokenType
+//===----------------------------------------------------------------------===//
+
+class AsyncTokenType : public llvm::RTTIExtends<AsyncTokenType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+};
+
+//===----------------------------------------------------------------------===//
+// Async Value type corresponding to the mlir::async::ValueType.
+//===----------------------------------------------------------------------===//
+
+class AsyncValueType : public llvm::RTTIExtends<AsyncValueType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ explicit AsyncValueType(std::unique_ptr<Type> value_type)
+ : value_type_(std::move(value_type)) {}
+
+ const Type& value_type() const { return *value_type_; }
+
+ llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+ std::unique_ptr<Type> value_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Ranked Tensor type corresponding to the mlir::RankedTensorType.
+//===----------------------------------------------------------------------===//
+
+class RankedTensorType : public llvm::RTTIExtends<RankedTensorType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+ static constexpr int64_t kDynamicSize = -1;
+
+ RankedTensorType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
+ : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
+
+ llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
+ unsigned rank() const { return sizes_.size(); }
+ tfrt::DType element_type() const { return element_type_; }
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+ llvm::SmallVector<int64_t> sizes_;
+ tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Unranked Tensor type corresponding to the mlir::UnrankedTensorType.
+//===----------------------------------------------------------------------===//
+
+class UnrankedTensorType : public llvm::RTTIExtends<UnrankedTensorType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ explicit UnrankedTensorType(tfrt::DType element_type)
+ : element_type_(element_type) {}
+
+ tfrt::DType element_type() const { return element_type_; }
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+ tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Ranked Memref type corresponding to the mlir::MemrefType.
+//===----------------------------------------------------------------------===//
+
+class MemrefType : public llvm::RTTIExtends<MemrefType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+ static constexpr int64_t kDynamicSize = -1;
+
+ MemrefType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
+ : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
+
+ llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
+ unsigned rank() const { return sizes_.size(); }
+ tfrt::DType element_type() const { return element_type_; }
+
+ llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
+ llvm::ErrorOr<ResultAbi> AsResult() const final;
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+ llvm::SmallVector<int64_t> sizes_;
+ tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Unranked Memref type corresponding to the mlir::UnrankedMemrefType.
+//===----------------------------------------------------------------------===//
+
+class UnrankedMemrefType : public llvm::RTTIExtends<UnrankedMemrefType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ explicit UnrankedMemrefType(tfrt::DType element_type)
+ : element_type_(element_type) {}
+
+ tfrt::DType element_type() const { return element_type_; }
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+
+ private:
+ tfrt::DType element_type_;
+};
+
+//===----------------------------------------------------------------------===//
+// Corresponds to the RT dialect's KernelContextType.
+//===----------------------------------------------------------------------===//
+
+class KernelContextOperandType
+ : public llvm::RTTIExtends<KernelContextOperandType, Type> {
+ public:
+ static constexpr char ID = 0; // NOLINT
+
+ llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
+
+ llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
+};
+
+//===----------------------------------------------------------------------===//
+// Compiled function signature type corresponding to the mlir::FunctionType.
+//===----------------------------------------------------------------------===//
+
+class FunctionType {
+ public:
+ const Type* operand(unsigned index) const { return operands_[index].get(); }
+ const Type* result(unsigned index) const { return results_[index].get(); }
+
+ unsigned num_operands() const { return operands_.size(); }
+ unsigned num_results() const { return results_.size(); }
+
+ FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,
+ llvm::SmallVector<std::unique_ptr<Type>> results)
+ : operands_(std::move(operands)), results_(std::move(results)) {}
+
+ private:
+ llvm::SmallVector<std::unique_ptr<Type>> operands_;
+ llvm::SmallVector<std::unique_ptr<Type>> results_;
+};
+
+} // namespace runtime
+} // namespace xla
+
+#endif // XLA_RUNTIME_TYPES_H_