// Copyright 2022 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Pattern to lower the lmhlo.fft op to the tfrt dialect.
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ScopedPrinter.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/pattern_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
#include "tfrt/gpu/wrapper/cufft_wrapper.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime

namespace tensorflow {
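
// Maps the lmhlo FFT type string ("FFT", "IFFT", "RFFT", "IRFFT") to the
// corresponding cuFFT transform type for the target GPU platform. The
// double-precision variants (Z2Z, D2Z, Z2D) are used when `double_precision`
// is set.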
static llvm::Expected<tfrt::gpu::wrapper::FftType> GetFftType(
    llvm::StringRef type, bool double_precision) {
  llvm::Expected<int> value =
      llvm::StringSwitch<llvm::Expected<int>>(type)
          .Case("FFT", double_precision ? CUFFT_Z2Z : CUFFT_C2C)
          .Case("IFFT", double_precision ? CUFFT_Z2Z : CUFFT_C2C)
          .Case("RFFT", double_precision ? CUFFT_D2Z : CUFFT_R2C)
          .Case("IRFFT", double_precision ? CUFFT_Z2D : CUFFT_C2R)
          .Default(tfrt::MakeStringError("Unsupported FFT type: ", type));
  if (!value) return value.takeError();
  return tfrt::gpu::wrapper::FftType(*value, kGpuTargetPlatform);
}
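
// Maps the lmhlo FFT type string to the cuFFT transform direction: forward
// for "FFT" and "RFFT", inverse for "IFFT" and "IRFFT".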
static llvm::Expected<tfrt::gpu::wrapper::FftDirection> GetFftDirection(
    llvm::StringRef type) {
  llvm::Expected<int> value =
      llvm::StringSwitch<llvm::Expected<int>>(type)
          .Case("FFT", CUFFT_FORWARD)
          .Case("IFFT", CUFFT_INVERSE)
          .Case("RFFT", CUFFT_FORWARD)
          .Case("IRFFT", CUFFT_INVERSE)
          .Default(tfrt::MakeStringError("Unsupported FFT type: ", type));
  if (!value) return value.takeError();
  return tfrt::gpu::wrapper::FftDirection(*value, kGpuTargetPlatform);
}

namespace {

struct FftRewritePattern
    : tfrt::gpu::StreamifyOpConversionPattern<lmhlo::FftOp> {
  using tfrt::gpu::StreamifyOpConversionPattern<lmhlo::FftOp>::OpAdaptor;
  using tfrt::gpu::StreamifyOpConversionPattern<
      lmhlo::FftOp>::StreamifyOpConversionPattern;

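  // Lowers a single lmhlo.fft op into tfrt_gpu ops executed on `stream`:
  // plan creation, workspace allocation, FFT execution and, for inverse
  // transforms, normalization of the result. Returns the updated chain.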
  FailureOr<Value> matchAndRewriteOp(
      lmhlo::FftOp op, OpAdaptor adaptor, Value chain, Value stream,
      ConversionPatternRewriter& rewriter) const override {
    xla::Shape input_shape = xla::gpu::GetShape(op.getOperand());
    xla::Shape output_shape = xla::gpu::GetShape(op.getOutput());
    if (input_shape.is_dynamic() || output_shape.is_dynamic())
      return rewriter.notifyMatchFailure(op, "expected static shapes");
    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) ||
        !xla::LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) {
      return rewriter.notifyMatchFailure(op, "expected dense row-major");
    }
    bool double_precision = input_shape.element_type() == xla::F64 ||
                            input_shape.element_type() == xla::C128;
    auto type = GetFftType(mlir::mhlo::stringifyFftType(adaptor.getFftType()),
                           double_precision);
    auto direction =
        GetFftDirection(mlir::mhlo::stringifyFftType(adaptor.getFftType()));
    if (!type || !direction) {
      auto error = joinErrors(type.takeError(), direction.takeError());
      return rewriter.notifyMatchFailure(op, llvm::toString(std::move(error)));
    }
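    // The transform is applied over the trailing `fft_length.size()`
    // dimensions; all leading dimensions of the input are folded into the
    // batch count (e.g. an input of shape [2, 3, 8, 16] with
    // fft_length = [8, 16] gives batch = 6).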
    llvm::SmallVector<int64_t, 3> dimensions;
    llvm::copy(op.getFftLength().getValues<int64_t>(),
               std::back_inserter(dimensions));
    int rank = dimensions.size();
    auto batch_dims = input_shape.dimensions();
    uint64_t batch =
        std::accumulate(batch_dims.begin(), batch_dims.end() - rank, 1,
                        std::multiplies<int64_t>());
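    // Row-major strides for the transformed dimensions; the extra leading
    // entry is the distance between consecutive batches. For example,
    // dims = [8, 16, 32] yields strides = [4096, 512, 32, 1].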
    auto get_strides = [](absl::Span<const int64_t> dims) {
      llvm::SmallVector<int64_t, 4> strides(dims.size() + 1, 1);
      std::partial_sum(dims.rbegin(), dims.rend(), strides.rbegin() + 1,
                       std::multiplies<int64_t>());
      return strides;
    };
    llvm::SmallVector<int64_t, 4> input_strides =
        get_strides(input_shape.dimensions().last(rank));
    llvm::SmallVector<int64_t, 4> output_strides =
        get_strides(output_shape.dimensions().last(rank));
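    // Create the FFT plan on the stream's context, query its workspace size,
    // and allocate the scratch buffer before enqueueing the transform.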
    mlir::Location loc = op->getLoc();
    Value context = rewriter.create<tfrt::gpu::StreamGetContextOp>(loc, stream);
    auto fft_handle = rewriter.create<tfrt::gpu::FftCreateOp>(
        loc, context, *type, batch, rewriter.getI64ArrayAttr(dimensions),
        rewriter.getI64ArrayAttr(input_strides),
        rewriter.getI64ArrayAttr(output_strides));
    // Note: we could determine the workspace size during lowering similar to
    // convolutions because the dimensions are static. But it's unclear if we
    // really want the compiler to depend on cuFFT/hipFFT, and the expensive
    // part is the allocation, which is currently not hoisted.
    mlir::Value workspace_size =
        rewriter.create<tfrt::gpu::FftGetWorkspaceSizeOp>(loc, fft_handle);
    mlir::Value allocator =
        rewriter.create<tfrt::gpu::AllocatorCreateOp>(loc, context);
    mlir::Value workspace = rewriter.create<tfrt::gpu::MemAllocateOp>(
        loc, allocator, stream, workspace_size, chain);
    chain = rewriter.create<tfrt::gpu::FftExecuteOp>(
        loc, stream, fft_handle, adaptor.getOperand(), adaptor.getOutput(),
        workspace, *direction, chain);
    rewriter.eraseOp(op);
    if (*direction ==
        tfrt::gpu::wrapper::FftDirection(CUFFT_FORWARD, kGpuTargetPlatform)) {
      return chain;
    }
    // CUDA/HIP inverse FFT is un-normalized, e.g. see
    // https://docs.nvidia.com/cuda/cufft/index.html#cufft-transform-directions
    // So in the inverse case we must manually normalize by scaling by the
    // inverse of the total number of FFT samples.
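    // For example, an inverse FFT with fft_length = [8, 16, 32] scales every
    // output element by 1 / 4096.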
    int64_t elements_per_batch = std::accumulate(
        dimensions.begin(), dimensions.end(), 1, std::multiplies<int64_t>());
    int64_t total_num_elements = elements_per_batch * batch;
    auto mlir_element_type =
        op.getOutput().getType().cast<mlir::MemRefType>().getElementType();
    // If the FFT output elements are complex numbers, treat the output as
    // an array of twice as many real numbers so we can save compute by
    // scaling in the real domain.
    if (auto complex_type = mlir_element_type.dyn_cast<ComplexType>()) {
      total_num_elements *= 2;
      mlir_element_type = complex_type.getElementType();
    }
    auto n =
        rewriter.create<tfrt::compiler::ConstantI32Op>(loc, total_num_elements);
    auto scaling_factor = MakeScalingFactorConstant(
        rewriter, loc, mlir_element_type,
        /*value_real=*/llvm::APFloat(1.0f / elements_per_batch),
        /*value_imaginary=*/llvm::APFloat(0.0f));
    // This assumes that the stride of the FFT output is always 1.
    auto stride = rewriter.create<tfrt::compiler::ConstantI32Op>(loc, 1);
    auto blas_handle = rewriter.create<tfrt::gpu::BlasCreateOp>(loc, context);
    auto blas_element_type = MlirTypeToBlasDataType(mlir_element_type);
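    // Scale the output buffer in place with a BLAS scal call: each of the `n`
    // real values is multiplied by 1 / elements_per_batch.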
    chain = rewriter.create<tfrt::gpu::BlasScalOp>(
        loc, chain.getType(), blas_handle, stream, n, scaling_factor,
        blas_element_type, adaptor.getOutput(), blas_element_type, stride,
        blas_element_type, chain);
    return chain;
  }
};

}  // namespace

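// Adds the lmhlo.fft lowering pattern to the given pattern set.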
void populateFftConversionPattern(RewritePatternSet& patterns,
                                  TypeConverter& converter) {
  patterns.add<FftRewritePattern>(converter, patterns.getContext());
}

}  // namespace tensorflow