blob: 73c1d596c33724cfcde8c273487016af1fbe5516 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_RUNTIME_TYPES_H_
#define XLA_RUNTIME_TYPES_H_
#include <functional>
#include <memory>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "tfrt/dtype/dtype.h" // from @tf_runtime
namespace xla {
namespace runtime {
//===----------------------------------------------------------------------===//
// Canonical XLA runtime types for the executable arguments.
//===----------------------------------------------------------------------===//
// Types supported by the compiled function signature. We do rely on the LLVM
// style RTTI (https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html) to avoid
// dependency on the MLIR types at runtime, because we don't want to depend
// on any of the compiler implementation details at runtime and we want to
// support lightweight loading and execution of AOT compiled programs.
//
// We rely on the RTTI for the open class hierarchies, because we want to allow
// users to define their own types for the arguments.
//
// If the type can be passed to the compiled function as an argument or returned
// as a result, it must define its own ABI. The ABI is defined by the MLIR to
// LLVM lowering pipeline and the runtime integration (see `runtime.h`).
class Type : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
public:
static constexpr char ID = 0; // NOLINT
// Arguments to compiled functions passed as a set of pointers. For example
// memref descriptor passed in as a set of pointers to data, sizes and
// strides. See `Argument::Pack` implementation for details (in `argument.h`).
struct ArgumentAbi {
size_t num_ptrs;
};
// Compiled function returns results by writing into the pre-allocated storage
// of the given size with the requested alignment. Runtime pre-allocates
// memory required for all results in the call frame.
struct ResultAbi {
size_t size;
// TODO(ezhulenev): Add alignment to the result ABI. Alignment is an
// important part of the result ABI that we ignore today. It all doesn't
// crash only because all results happen to have a size that is multiple of
// 8 bytes, and because of that all of the results are properly aligned.
// Results memory layout in the call frame should take in account base
// pointer alignment and alignment requirements of all results.
};
// Returns an Abi if the type can be used as an argument.
virtual llvm::ErrorOr<ArgumentAbi> AsArgument() const {
return llvm::errc::not_supported;
}
// Returns an Abi if the type can be returned as a result.
virtual llvm::ErrorOr<ResultAbi> AsResult() const {
return llvm::errc::not_supported;
}
virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
protected:
Type() = default;
};
inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Type& type) {
return type.print(os);
}
//===----------------------------------------------------------------------===//
// Async Token type corresponding to the mlir::async::TokenType
//===----------------------------------------------------------------------===//
class AsyncTokenType : public llvm::RTTIExtends<AsyncTokenType, Type> {
public:
static constexpr char ID = 0; // NOLINT
llvm::ErrorOr<ResultAbi> AsResult() const final;
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
};
//===----------------------------------------------------------------------===//
// Async Value type corresponding to the mlir::async::ValueType.
//===----------------------------------------------------------------------===//
class AsyncValueType : public llvm::RTTIExtends<AsyncValueType, Type> {
public:
static constexpr char ID = 0; // NOLINT
explicit AsyncValueType(std::unique_ptr<Type> value_type)
: value_type_(std::move(value_type)) {}
const Type& value_type() const { return *value_type_; }
llvm::ErrorOr<ResultAbi> AsResult() const final;
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
private:
std::unique_ptr<Type> value_type_;
};
//===----------------------------------------------------------------------===//
// Ranked Tensor type corresponding to the mlir::RankedTensorType.
//===----------------------------------------------------------------------===//
class RankedTensorType : public llvm::RTTIExtends<RankedTensorType, Type> {
public:
static constexpr char ID = 0; // NOLINT
static constexpr int64_t kDynamicSize = -1;
RankedTensorType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
: sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
unsigned rank() const { return sizes_.size(); }
tfrt::DType element_type() const { return element_type_; }
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
private:
llvm::SmallVector<int64_t> sizes_;
tfrt::DType element_type_;
};
//===----------------------------------------------------------------------===//
// Unranked Tensor type corresponding to the mlir::UnrankedTensorType.
//===----------------------------------------------------------------------===//
class UnrankedTensorType : public llvm::RTTIExtends<UnrankedTensorType, Type> {
public:
static constexpr char ID = 0; // NOLINT
explicit UnrankedTensorType(tfrt::DType element_type)
: element_type_(element_type) {}
tfrt::DType element_type() const { return element_type_; }
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
private:
tfrt::DType element_type_;
};
//===----------------------------------------------------------------------===//
// Ranked Memref type corresponding to the mlir::MemrefType.
//===----------------------------------------------------------------------===//
class MemrefType : public llvm::RTTIExtends<MemrefType, Type> {
public:
static constexpr char ID = 0; // NOLINT
static constexpr int64_t kDynamicSize = -1;
MemrefType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
: sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
unsigned rank() const { return sizes_.size(); }
tfrt::DType element_type() const { return element_type_; }
llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
llvm::ErrorOr<ResultAbi> AsResult() const final;
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
private:
llvm::SmallVector<int64_t> sizes_;
tfrt::DType element_type_;
};
//===----------------------------------------------------------------------===//
// Unranked Memref type corresponding to the mlir::UnrankedMemrefType.
//===----------------------------------------------------------------------===//
class UnrankedMemrefType : public llvm::RTTIExtends<UnrankedMemrefType, Type> {
public:
static constexpr char ID = 0; // NOLINT
explicit UnrankedMemrefType(tfrt::DType element_type)
: element_type_(element_type) {}
tfrt::DType element_type() const { return element_type_; }
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
private:
tfrt::DType element_type_;
};
//===----------------------------------------------------------------------===//
// Corresponds to the RT dialect's KernelContextType.
//===----------------------------------------------------------------------===//
class KernelContextOperandType
: public llvm::RTTIExtends<KernelContextOperandType, Type> {
public:
static constexpr char ID = 0; // NOLINT
llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
};
//===----------------------------------------------------------------------===//
// Compiled function signature type corresponding to the mlir::FunctionType.
//===----------------------------------------------------------------------===//
class FunctionType {
public:
const Type* operand(unsigned index) const { return operands_[index].get(); }
const Type* result(unsigned index) const { return results_[index].get(); }
unsigned num_operands() const { return operands_.size(); }
unsigned num_results() const { return results_.size(); }
FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,
llvm::SmallVector<std::unique_ptr<Type>> results)
: operands_(std::move(operands)), results_(std::move(results)) {}
private:
llvm::SmallVector<std::unique_ptr<Type>> operands_;
llvm::SmallVector<std::unique_ptr<Type>> results_;
};
} // namespace runtime
} // namespace xla
#endif // XLA_RUNTIME_TYPES_H_