[xla:jitrt] Encode FFT type (MHLO attr) as an XLA enum when passing it to the custom call.

PiperOrigin-RevId: 462470264
diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
index dd685c8..3e62156 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
+++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/fft_to_jitrt.mlir
@@ -23,7 +23,7 @@
 
   // CHECK: call @[[FFT:.*]](%[[ARG0]], %[[ARG1]])
   // CHECK-SAME: fft_length = dense<[16, 8]> : tensor<2xi64>
-  // CHECK-SAME: fft_type = 3 : i32
+  // CHECK-SAME: fft_type = #mhlo<fft_type IRFFT>
   "lmhlo.fft"(%arg0, %arg1) {
     fft_length = dense<[16, 8]> : tensor<2xi64>,
     fft_type = #mhlo<fft_type IRFFT>
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
index bebcd7a..ff81387 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.cc
@@ -961,8 +961,7 @@
 
     // Copy backend specific attributes.
     call->setAttr(b.getStringAttr("fft_length"), op.getFftLengthAttr());
-    call->setAttr(b.getStringAttr("fft_type"),
-                  b.getI32IntegerAttr(static_cast<int32_t>(op.getFftType())));
+    call->setAttr(b.getStringAttr("fft_type"), op.getFftTypeAttr());
 
     // Erase the original Fft operation.
     rewriter.eraseOp(op);
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index c16c5be..3bf5d87 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -83,6 +83,8 @@
 
 namespace se = ::stream_executor;
 namespace jitrt = ::tfrt::jitrt;
+namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
+namespace mhlo = ::mlir::mhlo;
 namespace runtime = ::tfrt::jitrt::runtime;
 
 // Disable all CustomCall checks in optimized build.
@@ -98,12 +100,29 @@
 
 void PopulateLmhloToXlaAttrEncoding(
     jitrt::CustomCallAttrEncodingSet& encoding) {
-  encoding.Add<jitrt::EnumAttrEncoding<mlir::lmhlo_gpu::ActivationAttr,
-                                       mlir::lmhlo_gpu::Activation,
-                                       se::dnn::ActivationMode>>(
-      [](mlir::lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
+  encoding.Add<
+      jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
+                              se::dnn::ActivationMode>>(
+      [](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
         return ConvertConvActivationMode(value).value();
       });
+
+  encoding.Add<
+      jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
+      [](mhlo::FftType value) -> se::fft::Type {
+        switch (value) {
+          case mhlo::FftType::FFT:
+            return se::fft::Type::kC2CForward;
+          case mhlo::FftType::IFFT:
+            return se::fft::Type::kC2CInverse;
+          case mhlo::FftType::RFFT:
+            return se::fft::Type::kR2C;
+          case mhlo::FftType::IRFFT:
+            return se::fft::Type::kC2R;
+          default:
+            return se::fft::Type::kInvalid;
+        }
+      });
 }
 
 // -------------------------------------------------------------------------- //
@@ -1088,7 +1107,7 @@
                            jitrt::StridedMemrefView input,
                            jitrt::StridedMemrefView output,
                            ArrayRef<int64_t> fft_length,
-                           int32_t fft_type) const;
+                           se::fft::Type fft_type) const;
   static Fft Handler() { return Fft(); }
 };
 }  // namespace
@@ -1097,42 +1116,37 @@
                               jitrt::StridedMemrefView input,
                               jitrt::StridedMemrefView output,
                               ArrayRef<int64_t> fft_length,
-                              int32_t fft_type) const {
+                              se::fft::Type fft_type) const {
   // TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
   FftPlanCache fft_plan_cache;
 
   se::Stream* stream = run_options->stream();
   se::StreamExecutor* executor = stream->parent();
 
-  // TODO(ezhulenev): Compiler pass should pass fft type to the custom call.
-  bool double_precision =
-      input.dtype == tfrt::DType::F64 || input.dtype == tfrt::DType::Complex128;
-
-  // TODO(b/234085769): Lmhlo to JitRt lowering pass should pass Xla Fft type to
-  // the custom call.
-  se::fft::Type fft = [&] {
-    // See mlir::mhlo::FftType enum.
+  if (input.dtype == tfrt::DType::F64 ||
+      input.dtype == tfrt::DType::Complex128) {
+    // Adjust FFT type to reflect double precision.
     switch (fft_type) {
-      case 0:  // FFT
-        return double_precision ? se::fft::Type::kZ2ZForward
-                                : se::fft::Type::kC2CForward;
-      case 1:  // IFFT
-        return double_precision ? se::fft::Type::kZ2ZInverse
-                                : se::fft::Type::kC2CInverse;
-      case 2:  // RFFT
-        return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
-      case 3:  // IRFFT
-        return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
+      case se::fft::Type::kC2CForward:
+        fft_type = se::fft::Type::kZ2ZForward;
+        break;
+      case se::fft::Type::kC2CInverse:
+        fft_type = se::fft::Type::kZ2ZInverse;
+        break;
+      case se::fft::Type::kR2C:
+        fft_type = se::fft::Type::kD2Z;
+        break;
+      case se::fft::Type::kC2R:
+        fft_type = se::fft::Type::kZ2D;
+        break;
       default:
-        return se::fft::Type::kInvalid;
+        return failure();
     }
-  }();
-
-  if (fft == se::fft::Type::kInvalid) return failure();
+  }
 
   auto st =
       RunFft(GetDeviceAddress(input), ToShape(input), GetDeviceAddress(output),
-             ToShape(output), fft, fft_length, executor->device_ordinal(),
+             ToShape(output), fft_type, fft_length, executor->device_ordinal(),
              &fft_plan_cache, stream, run_options->allocator());
   if (!st.ok()) return failure();
 
@@ -1145,7 +1159,7 @@
                              .Arg<jitrt::StridedMemrefView>()  // input
                              .Arg<jitrt::StridedMemrefView>()  // output
                              .Attr<ArrayRef<int64_t>>("fft_length")
-                             .Attr<int32_t>("fft_type")
+                             .Attr<se::fft::Type>("fft_type")
                              .To<RuntimeChecks()>(Fft::Handler())
                              .release();
 
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
index 7abb23a..6af4bc7 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h
@@ -39,6 +39,7 @@
 namespace tfrt {
 namespace jitrt {
 JITRT_REGISTER_ENUM_ATTR_DECODING(stream_executor::dnn::ActivationMode);
+JITRT_REGISTER_ENUM_ATTR_DECODING(stream_executor::fft::Type);
 }  // namespace jitrt
 }  // namespace tfrt