blob: cd023b39863606b01c2d024a8ea98f63a8d46e34 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/runtime/types.h"
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
namespace xla {
namespace runtime {
using llvm::ArrayRef;
using llvm::raw_ostream;
//===----------------------------------------------------------------------===//
// Pretty printing for canonical types.
//===----------------------------------------------------------------------===//
// Writes a shape as an MLIR-style dimension prefix. Every dimension is
// written in decimal and followed by an "x", e.g. {2, 3} -> "2x3x", so the
// element type can be streamed immediately afterwards. An empty shape
// writes nothing.
// NOTE(review): values are printed verbatim, so if a shape can contain a
// sentinel for dynamic dimensions it appears as a raw integer, not "?".
static raw_ostream& operator<<(raw_ostream& os, const ArrayRef<int64_t>& arr) {
  for (const int64_t dim : arr) os << std::to_string(dim) << "x";
  return os;
}
// Writes the MLIR spelling of the async token type: `!async.token`.
raw_ostream& AsyncTokenType::print(raw_ostream& os) const {
  os << "!async.token";
  return os;
}
// Writes `!async.value<T>`, where `T` is the printed form of the wrapped
// value type.
raw_ostream& AsyncValueType::print(raw_ostream& os) const {
  raw_ostream& opened = os << "!async.value<";
  return opened << value_type() << ">";
}
// Writes `tensor<D0xD1x...xE>`, where the D's are the dimension sizes and E
// is the element type. An empty shape yields `tensor<E>`.
raw_ostream& RankedTensorType::print(raw_ostream& os) const {
  raw_ostream& with_shape = os << "tensor<" << sizes();
  return with_shape << element_type() << ">";
}
// Writes `tensor<*xE>`, where `*` stands for the unknown rank and E is the
// element type.
raw_ostream& UnrankedTensorType::print(raw_ostream& os) const {
  raw_ostream& prefixed = os << "tensor<*x";
  return prefixed << element_type() << ">";
}
// Writes `memref<D0xD1x...xE>` (or `memref<E>` for an empty shape), where
// the D's are the dimension sizes and E is the element type.
raw_ostream& MemrefType::print(raw_ostream& os) const {
  os << "memref<";
  return os << sizes() << element_type() << ">";
}
// Writes `memref<*xE>`, where `*` stands for the unknown rank and E is the
// element type.
raw_ostream& UnrankedMemrefType::print(raw_ostream& os) const {
  os << "memref<*x";
  return os << element_type() << ">";
}
// Writes the spelling of the opaque kernel context operand type:
// `!rt.kernel_context`.
raw_ostream& KernelContextOperandType::print(raw_ostream& os) const {
  os << "!rt.kernel_context";
  return os;
}
//===----------------------------------------------------------------------===//
// ABI definition for canonical types.
//===----------------------------------------------------------------------===//
using ArgumentAbi = Type::ArgumentAbi;
using ResultAbi = Type::ResultAbi;
// An async token is returned as a single pointer to the runtime async token,
// so its result slot is `sizeof(void*)` bytes wide.
llvm::ErrorOr<ResultAbi> AsyncTokenType::AsResult() const {
  constexpr auto kTokenPointerSize = sizeof(void*);
  return ResultAbi{kTokenPointerSize};
}
// An async value is returned as a single opaque pointer (presumably to the
// runtime async value object, not an async token), so its result slot is
// `sizeof(void*)` bytes wide.
llvm::ErrorOr<ResultAbi> AsyncValueType::AsResult() const {
  constexpr auto kValuePointerSize = sizeof(void*);
  return ResultAbi{kValuePointerSize};
}
// A memref argument is unrolled into the fields of a strided memref: the
// base pointer, data pointer and offset, followed by one size and one stride
// per dimension, giving `3 + 2 * rank()` ABI arguments in total.
llvm::ErrorOr<ArgumentAbi> MemrefType::AsArgument() const {
  const auto num_unrolled_args = 3 + 2 * rank();
  return ArgumentAbi{num_unrolled_args};
}
// Memrefs are returned laid out as MLIR's `StridedMemRefType<T, rank>`:
//
//   basePtr, data, offset, sizes[rank], strides[rank]
//
// so the result occupies two pointers plus (1 + 2 * rank) int64_t values.
//
// TODO(ezhulenev): We should query the size of the `StridedMemRefType`
// directly, however it introduces dependency on the MLIR C runner utils.
llvm::ErrorOr<ResultAbi> MemrefType::AsResult() const {
  const auto pointers_size = sizeof(void*) * 2;         // basePtr + data
  const auto offset_size = sizeof(int64_t);             // offset
  const auto dims_size = sizeof(int64_t) * 2 * rank();  // sizes + strides
  return ResultAbi{pointers_size + offset_size + dims_size};
}
// The kernel context is handed to the compiled function as one opaque
// pointer, so it accounts for exactly one ABI argument.
llvm::ErrorOr<ArgumentAbi> KernelContextOperandType::AsArgument() const {
  constexpr int kOpaquePointerArgs = 1;
  return ArgumentAbi{kOpaquePointerArgs};
}
} // namespace runtime
} // namespace xla