[xla:jitrt] Encode FFT type (MHLO attr) as an XLA enum passed to the custom call.
PiperOrigin-RevId: 462470264
diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
index dd685c8..3e62156 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
@@ -23,7 +23,7 @@
// CHECK: call @[[FFT:.*]](%[[ARG0]], %[[ARG1]])
// CHECK-SAME: fft_length = dense<[16, 8]> : tensor<2xi64>
- // CHECK-SAME: fft_type = 3 : i32
+ // CHECK-SAME: fft_type = #mhlo<fft_type IRFFT>
"lmhlo.fft"(%arg0, %arg1) {
fft_length = dense<[16, 8]> : tensor<2xi64>,
fft_type = #mhlo<fft_type IRFFT>
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
index bebcd7a..ff81387 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
@@ -961,8 +961,7 @@
// Copy backend specific attributes.
call->setAttr(b.getStringAttr("fft_length"), op.getFftLengthAttr());
- call->setAttr(b.getStringAttr("fft_type"),
- b.getI32IntegerAttr(static_cast<int32_t>(op.getFftType())));
+ call->setAttr(b.getStringAttr("fft_type"), op.getFftTypeAttr());
// Erase the original Fft operation.
rewriter.eraseOp(op);
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index c16c5be..3bf5d87 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -83,6 +83,8 @@
namespace se = ::stream_executor;
namespace jitrt = ::tfrt::jitrt;
+namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
+namespace mhlo = ::mlir::mhlo;
namespace runtime = ::tfrt::jitrt::runtime;
// Disable all CustomCall checks in optimized build.
@@ -98,12 +100,29 @@
void PopulateLmhloToXlaAttrEncoding(
jitrt::CustomCallAttrEncodingSet& encoding) {
- encoding.Add<jitrt::EnumAttrEncoding<mlir::lmhlo_gpu::ActivationAttr,
- mlir::lmhlo_gpu::Activation,
- se::dnn::ActivationMode>>(
- [](mlir::lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
+ encoding.Add<
+ jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
+ se::dnn::ActivationMode>>(
+ [](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
return ConvertConvActivationMode(value).value();
});
+
+ encoding.Add<
+ jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
+ [](mhlo::FftType value) -> se::fft::Type {
+ switch (value) {
+ case mhlo::FftType::FFT:
+ return se::fft::Type::kC2CForward;
+ case mhlo::FftType::IFFT:
+ return se::fft::Type::kC2CInverse;
+ case mhlo::FftType::RFFT:
+ return se::fft::Type::kR2C;
+ case mhlo::FftType::IRFFT:
+ return se::fft::Type::kC2R;
+ default:
+ return se::fft::Type::kInvalid;
+ }
+ });
}
// -------------------------------------------------------------------------- //
@@ -1088,7 +1107,7 @@
jitrt::StridedMemrefView input,
jitrt::StridedMemrefView output,
ArrayRef<int64_t> fft_length,
- int32_t fft_type) const;
+ se::fft::Type fft_type) const;
static Fft Handler() { return Fft(); }
};
} // namespace
@@ -1097,42 +1116,37 @@
jitrt::StridedMemrefView input,
jitrt::StridedMemrefView output,
ArrayRef<int64_t> fft_length,
- int32_t fft_type) const {
+ se::fft::Type fft_type) const {
// TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
FftPlanCache fft_plan_cache;
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
- // TODO(ezhulenev): Compiler pass should pass fft type to the custom call.
- bool double_precision =
- input.dtype == tfrt::DType::F64 || input.dtype == tfrt::DType::Complex128;
-
- // TODO(b/234085769): Lmhlo to JitRt lowering pass should pass Xla Fft type to
- // the custom call.
- se::fft::Type fft = [&] {
- // See mlir::mhlo::FftType enum.
+ if (input.dtype == tfrt::DType::F64 ||
+ input.dtype == tfrt::DType::Complex128) {
+ // Adjust FFT type to reflect double precision.
switch (fft_type) {
- case 0: // FFT
- return double_precision ? se::fft::Type::kZ2ZForward
- : se::fft::Type::kC2CForward;
- case 1: // IFFT
- return double_precision ? se::fft::Type::kZ2ZInverse
- : se::fft::Type::kC2CInverse;
- case 2: // RFFT
- return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
- case 3: // IRFFT
- return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
+ case se::fft::Type::kC2CForward:
+ fft_type = se::fft::Type::kZ2ZForward;
+ break;
+ case se::fft::Type::kC2CInverse:
+ fft_type = se::fft::Type::kZ2ZInverse;
+ break;
+ case se::fft::Type::kR2C:
+ fft_type = se::fft::Type::kD2Z;
+ break;
+ case se::fft::Type::kC2R:
+ fft_type = se::fft::Type::kZ2D;
+ break;
default:
- return se::fft::Type::kInvalid;
+ return failure();
}
- }();
-
- if (fft == se::fft::Type::kInvalid) return failure();
+ }
auto st =
RunFft(GetDeviceAddress(input), ToShape(input), GetDeviceAddress(output),
- ToShape(output), fft, fft_length, executor->device_ordinal(),
+ ToShape(output), fft_type, fft_length, executor->device_ordinal(),
&fft_plan_cache, stream, run_options->allocator());
if (!st.ok()) return failure();
@@ -1145,7 +1159,7 @@
.Arg<jitrt::StridedMemrefView>() // input
.Arg<jitrt::StridedMemrefView>() // output
.Attr<ArrayRef<int64_t>>("fft_length")
- .Attr<int32_t>("fft_type")
+ .Attr<se::fft::Type>("fft_type")
.To<RuntimeChecks()>(Fft::Handler())
.release();
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
index 7abb23a..6af4bc7 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
@@ -39,6 +39,7 @@
namespace tfrt {
namespace jitrt {
JITRT_REGISTER_ENUM_ATTR_DECODING(stream_executor::dnn::ActivationMode);
+JITRT_REGISTER_ENUM_ATTR_DECODING(stream_executor::fft::Type);
} // namespace jitrt
} // namespace tfrt