/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/rocm/rocm_fft.h"

#include <complex>

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace stream_executor {
namespace gpu {

// Plugin ID under which the rocFFT support library is registered with the
// PluginRegistry (used by initialize_rocfft() at the bottom of this file).
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);

// Callable shims for the hipFFT routines used in this file. Every shim takes
// the owning GpuExecutor as its first argument, holds a
// gpu::ScopedActivateExecutorContext for that executor while the call runs,
// and forwards the remaining arguments to the hipFFT function of the same
// name, returning its hipfftResult.
namespace wrap {

#ifdef PLATFORM_GOOGLE
// This macro wraps a global identifier, given by __name, in a callable
// structure that calls the global function ::__name directly; in this
// configuration no symbol is looked up at run time.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                      \
  struct WrapperShim__##__name {                                 \
    template <typename... Args>                                  \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent};            \
      return ::__name(args...);                                  \
    }                                                            \
  } __name;

#else

// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the symbol out of the rocFFT DSO handle on first use.
// The resolved pointer is cached in a function-local static inside DynLoad(),
// so the lookup runs once (C++11 guarantees thread-safe initialization of
// such statics). A missing DSO or symbol is fatal (ValueOrDie / CHECK).
// This dynamic loading technique is used to avoid DSO dependencies on vendor
// libraries which may or may not be available in the deployed binary
// environment.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                               \
  struct DynLoadShim__##__name {                                          \
    static const char *kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void *GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocfftDsoHandle();           \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void *f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocfft DSO; dlerror: " << s.error_message();  \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    hipfftResult operator()(GpuExecutor *parent, Args... args) {          \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char *DynLoadShim__##__name::kName = #__name;

#endif

// clang-format off
// X-macro listing every hipFFT routine that gets a shim in wrap::.
#define ROCFFT_ROUTINE_EACH(__macro) \
  __macro(hipfftDestroy)             \
  __macro(hipfftSetStream)           \
  __macro(hipfftPlan1d)              \
  __macro(hipfftPlan2d)              \
  __macro(hipfftPlan3d)              \
  __macro(hipfftPlanMany)            \
  __macro(hipfftCreate)              \
  __macro(hipfftSetAutoAllocation)   \
  __macro(hipfftSetWorkArea)         \
  __macro(hipfftGetSize1d)           \
  __macro(hipfftMakePlan1d)          \
  __macro(hipfftGetSize2d)           \
  __macro(hipfftMakePlan2d)          \
  __macro(hipfftGetSize3d)           \
  __macro(hipfftMakePlan3d)          \
  __macro(hipfftGetSizeMany)         \
  __macro(hipfftMakePlanMany)        \
  __macro(hipfftExecD2Z)             \
  __macro(hipfftExecZ2D)             \
  __macro(hipfftExecC2C)             \
  __macro(hipfftExecC2R)             \
  __macro(hipfftExecZ2Z)             \
  __macro(hipfftExecR2C)

// clang-format on

// Instantiate one shim object per routine listed above (for example
// wrap::hipfftPlan1d).
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)

}  // namespace wrap

namespace {

// Translates a StreamExecutor fft::Type into the matching hipFFT transform
// type. Forward and inverse complex-to-complex transforms map to the same
// hipFFT type; their direction is passed separately when the plan is
// executed. Any other value is fatal.
hipfftType ROCMFftType(fft::Type type) {
  switch (type) {
    case fft::Type::kR2C:
      return HIPFFT_R2C;
    case fft::Type::kC2R:
      return HIPFFT_C2R;
    case fft::Type::kC2CForward:
    case fft::Type::kC2CInverse:
      return HIPFFT_C2C;
    case fft::Type::kD2Z:
      return HIPFFT_D2Z;
    case fft::Type::kZ2D:
      return HIPFFT_Z2D;
    case fft::Type::kZ2ZForward:
    case fft::Type::kZ2ZInverse:
      return HIPFFT_Z2Z;
    default:
      break;
  }
  LOG(FATAL) << "Invalid value of fft::Type.";
}

// Associates `plan` with the GPU stream backing `stream` via hipfftSetStream.
// Returns false (after logging the hipFFT error code) if the call fails.
bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
  const auto status =
      wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
  const bool succeeded = (status == HIPFFT_SUCCESS);
  if (!succeeded) {
    LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << status;
  }
  return succeeded;
}

}  // namespace

// Creates the underlying rocFFT plan and records the executor and transform
// type on this object.
//
// The transform has `rank` dimensions of sizes `elem_count`. Optional
// `input_embed`/`output_embed` storage dimensions apply together with their
// strides and distances, and `batch_count` transforms run per execution. When
// `batch_count` is 1 and neither embedding is given, the basic
// hipfftPlan{1,2,3}d / hipfftMakePlan{1,2,3}d entry points are used.
// Otherwise hipfftPlanMany / hipfftMakePlanMany are used.
//
// With a null `scratch_allocator`, the plan is created without disabling
// automatic work-area allocation. Otherwise automatic allocation is turned off
// and the work area is allocated from `scratch_allocator` and attached to the
// plan.
//
// Only ranks 1, 2 and 3 are supported, because the dimension arrays handed to
// rocFFT hold three entries. Returns INVALID_ARGUMENT for any other rank, the
// allocator's error status (or INTERNAL for a null allocation) if the work
// area cannot be obtained, and INTERNAL if a rocFFT call fails. Initializing
// the same plan twice is fatal.
port::Status ROCMFftPlan::Initialize(
    GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
    uint64 *input_embed, uint64 input_stride, uint64 input_distance,
    uint64 *output_embed, uint64 output_stride, uint64 output_distance,
    fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
  if (IsInitialized()) {
    LOG(FATAL) << "Try to repeatedly initialize.";
  }
  is_initialized_ = true;
  parent_ = parent;
  fft_type_ = type;

  // The dimension arrays below only have room for three entries. Reject any
  // other rank before filling them, so an out-of-range rank (previously
  // possible on the batched path) cannot write past their end.
  if (rank < 1 || rank > 3) {
    LOG(ERROR) << "Invalid rank value for hipfftPlan. "
                  "Requested 1, 2, or 3, given: "
               << rank;
    return port::Status{port::error::INVALID_ARGUMENT,
                        "hipfftPlan only takes rank 1, 2, or 3."};
  }

  // The hipFFT plan functions take int dimensions; narrow the uint64 inputs.
  int elem_count_[3], input_embed_[3], output_embed_[3];
  for (int i = 0; i < rank; ++i) {
    elem_count_[i] = elem_count[i];
    if (input_embed) {
      input_embed_[i] = input_embed[i];
    }
    if (output_embed) {
      output_embed_[i] = output_embed[i];
    }
  }

  // Allocates `size_in_bytes` bytes from `scratch_allocator` (skipped when
  // the size is zero) and attaches the memory to plan_ as its work area.
  // `log_message` and `status_message` are used only if attaching fails, so
  // single and batched plans keep their own error text.
  auto attach_work_area = [&](size_t size_in_bytes, const char *log_message,
                              const char *status_message) -> port::Status {
    if (size_in_bytes != 0) {
      auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
      if (!allocated.ok()) {
        LOG(ERROR) << "failed to allocate work area.";
        return allocated.status();
      }
      scratch_ = allocated.ValueOrDie();
      // An OK status carrying null memory is still a failure. Returning the
      // (OK) status here would report success for a plan with no work area.
      if (scratch_ == nullptr) {
        LOG(ERROR) << "failed to allocate work area.";
        return port::Status{port::error::INTERNAL,
                            "Failed to allocate work area for rocFFT plan."};
      }
    }
    // Connect work area with allocated space.
    auto ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
    if (ret != HIPFFT_SUCCESS) {
      LOG(ERROR) << log_message << ret;
      return port::Status{port::error::INTERNAL, status_message};
    }
    return port::Status::OK();
  };

  if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
    hipfftResult_t ret;
    if (scratch_allocator == nullptr) {
      // rank is known to be 1, 2 or 3 here, so every case returns.
      switch (rank) {
        case 1:
          // hipfftPlan1d
          ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
                                   ROCMFftType(type), 1 /* = batch */);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 1d plan."};
          }
          return port::Status::OK();
        case 2:
          // hipfftPlan2d
          ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
                                   elem_count_[1], ROCMFftType(type));
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 2d plan."};
          }
          return port::Status::OK();
        case 3:
          // hipfftPlan3d
          ret =
              wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
                                 elem_count_[2], ROCMFftType(type));
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 3d plan."};
          }
          return port::Status::OK();
      }
    } else {
      ret = wrap::hipfftCreate(parent, &plan_);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT plan."};
      }
      // The work area comes from scratch_allocator, so stop rocFFT from
      // allocating one itself.
      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to set auto allocation for rocFFT plan."};
      }
      size_t size_in_bytes = 0;
      switch (rank) {
        case 1:
          ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
                                       ROCMFftType(type), /*batch=*/1,
                                       &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 1d plan."};
          }
          break;
        case 2:
          ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
                                       elem_count_[1], ROCMFftType(type),
                                       &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 2d plan."};
          }
          break;
        case 3:
          ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
                                       elem_count_[1], elem_count_[2],
                                       ROCMFftType(type), &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 3d plan."};
          }
          break;
      }
      return attach_work_area(size_in_bytes,
                              "failed to set work area for rocFFT plan:",
                              "Failed to set work area for rocFFT plan.");
    }
  } else {
    // Multiple batches or an embedded (strided) layout: use the
    // hipfftPlanMany() family.
    if (scratch_allocator == nullptr) {
      auto ret = wrap::hipfftPlanMany(
          parent, &plan_, rank, elem_count_,
          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
          output_embed ? output_embed_ : nullptr, output_stride,
          output_distance, ROCMFftType(type), batch_count);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT batched plan."};
      }
    } else {
      auto ret = wrap::hipfftCreate(parent, &plan_);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT batched plan."};
      }
      // The work area comes from scratch_allocator, so stop rocFFT from
      // allocating one itself.
      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
                   << ret;
        return port::Status{
            port::error::INTERNAL,
            "Failed to set auto allocation for rocFFT batched plan."};
      }
      size_t size_in_bytes = 0;
      ret = wrap::hipfftMakePlanMany(
          parent, plan_, rank, elem_count_,
          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
          output_embed ? output_embed_ : nullptr, output_stride,
          output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to make rocFFT batched plan."};
      }
      return attach_work_area(
          size_in_bytes, "failed to set work area for rocFFT batched plan:",
          "Failed to set work area for rocFFT batched plan.");
    }
  }
  return port::Status::OK();
}

// Convenience overload for a single transform of the given rank with no
// input/output embedding. It forwards to the general Initialize() with zero
// strides/distances and batch_count = 1.
port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
                                     int rank, uint64 *elem_count,
                                     fft::Type type,
                                     ScratchAllocator *scratch_allocator) {
  // Forward the caller's executor. The member parent_ is only assigned inside
  // the general overload, so it must not be read here.
  return Initialize(parent, stream, rank, elem_count,
                    /*input_embed=*/nullptr, /*input_stride=*/0,
                    /*input_distance=*/0,
                    /*output_embed=*/nullptr, /*output_stride=*/0,
                    /*output_distance=*/0, type, /*batch_count=*/1,
                    scratch_allocator);
}

// Releases the rocFFT plan owned by this object.
//
// parent_ and plan_ are only assigned by Initialize(). For a plan that was
// never initialized there is nothing to release, and activating a context for
// an unset executor is unsafe, so destruction is skipped. A destructor cannot
// report errors to its caller, so a failing hipfftDestroy is only logged.
ROCMFftPlan::~ROCMFftPlan() {
  if (!IsInitialized()) {
    return;
  }
  auto ret = wrap::hipfftDestroy(parent_, plan_);
  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to destroy rocFFT plan: " << ret;
  }
}

// Returns the hipFFT direction flag (HIPFFT_FORWARD or HIPFFT_BACKWARD) that
// matches the transform type this plan was initialized with. Calling it
// before Initialize(), or with an unrecognized type, is fatal.
int ROCMFftPlan::GetFftDirection() const {
  if (!IsInitialized()) {
    LOG(FATAL) << "Try to get fft direction before initialization.";
  }
  // Complex-to-real transforms are always backward and real-to-complex
  // transforms always forward; complex-to-complex types encode the direction.
  switch (fft_type_) {
    case fft::Type::kC2CInverse:
    case fft::Type::kZ2ZInverse:
    case fft::Type::kC2R:
    case fft::Type::kZ2D:
      return HIPFFT_BACKWARD;
    case fft::Type::kC2CForward:
    case fft::Type::kZ2ZForward:
    case fft::Type::kR2C:
    case fft::Type::kD2Z:
      return HIPFFT_FORWARD;
    default:
      LOG(FATAL) << "Invalid value of fft::Type.";
  }
}

std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
                                                 fft::Type type,
                                                 bool in_place_fft) {
  // Single 1-D plan of num_x elements, created without a scratch allocator.
  uint64 dims[] = {num_x};
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status =
      plan->Initialize(parent_, stream, /*rank=*/1, dims, type,
                       /*scratch_allocator=*/nullptr);
  // TODO(yangzihao): In the future, send error msg back to TensorFlow
  // so it can fail gracefully.
  if (!init_status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 1d plan: "
               << init_status.error_message();
  }
  return std::move(plan);
}

std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
    ScratchAllocator *scratch_allocator) {
  // Single 1-D plan of num_x elements whose work area is obtained from
  // `scratch_allocator`. Initialization failure is fatal.
  uint64 dims[] = {num_x};
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status = plan->Initialize(
      parent_, stream, /*rank=*/1, dims, type, scratch_allocator);
  if (!init_status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 1d plan with customized allocator: "
        << init_status.error_message();
  }
  return std::move(plan);
}

// Creates a single 2-D plan of num_x by num_y elements without a scratch
// allocator. Initialization failure is fatal.
std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
                                                 uint64 num_y, fft::Type type,
                                                 bool in_place_fft) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[2] = {num_x, num_y};
  // The rank must be 2 so both dimensions are used. Rank 1 would silently
  // build a 1-D plan over num_x only.
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, /*rank=*/2, elem_count, type,
      /*scratch_allocator=*/nullptr);
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 2d plan: "
               << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
    bool in_place_fft, ScratchAllocator *scratch_allocator) {
  // Single 2-D plan of num_x by num_y elements whose work area is obtained
  // from `scratch_allocator`. Initialization failure is fatal.
  uint64 dims[] = {num_x, num_y};
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status = plan->Initialize(
      parent_, stream, /*rank=*/2, dims, type, scratch_allocator);
  if (!init_status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 2d plan with customized allocator: "
        << init_status.error_message();
  }
  return std::move(plan);
}

std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
                                                 uint64 num_y, uint64 num_z,
                                                 fft::Type type,
                                                 bool in_place_fft) {
  // Single 3-D plan of num_x by num_y by num_z elements, created without a
  // scratch allocator. Initialization failure is fatal.
  uint64 dims[] = {num_x, num_y, num_z};
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status =
      plan->Initialize(parent_, stream, /*rank=*/3, dims, type,
                       /*scratch_allocator=*/nullptr);
  if (!init_status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 3d plan: "
               << init_status.error_message();
  }
  return std::move(plan);
}

std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
    bool in_place_fft, ScratchAllocator *scratch_allocator) {
  // Single 3-D plan of num_x by num_y by num_z elements whose work area is
  // obtained from `scratch_allocator`. Initialization failure is fatal.
  uint64 dims[] = {num_x, num_y, num_z};
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status = plan->Initialize(
      parent_, stream, /*rank=*/3, dims, type, scratch_allocator);
  if (!init_status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 3d plan with customized allocator: "
        << init_status.error_message();
  }
  return std::move(plan);
}

std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    uint64 output_stride, uint64 output_distance, fft::Type type,
    bool in_place_fft, int batch_count) {
  // Plan for `batch_count` transforms with the given (optionally embedded)
  // layout, created without a scratch allocator. Failure is fatal.
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status =
      plan->Initialize(parent_, stream, rank, elem_count, input_embed,
                       input_stride, input_distance, output_embed,
                       output_stride, output_distance, type, batch_count,
                       /*scratch_allocator=*/nullptr);
  if (!init_status.ok()) {
    LOG(FATAL) << "failed to initialize batched hipfft plan: "
               << init_status.error_message();
  }
  return std::move(plan);
}

std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    uint64 output_stride, uint64 output_distance, fft::Type type,
    bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
  // Plan for `batch_count` transforms with the given (optionally embedded)
  // layout whose work area is obtained from `scratch_allocator`. Failure is
  // fatal.
  std::unique_ptr<ROCMFftPlan> plan{new ROCMFftPlan()};
  const port::Status init_status =
      plan->Initialize(parent_, stream, rank, elem_count, input_embed,
                       input_stride, input_distance, output_embed,
                       output_stride, output_distance, type, batch_count,
                       scratch_allocator);
  if (!init_status.ok()) {
    LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
                  "allocator: "
               << init_status.error_message();
  }
  return std::move(plan);
}

void ROCMFft::UpdatePlanWithScratchAllocator(
    Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
  // Not supported by the rocFFT backend: the plan is left untouched and the
  // request is only reported in the error log.
  (void)stream;
  (void)plan;
  (void)scratch_allocator;
  LOG(ERROR) << "update plan with scratch allocator not implemented";
}

template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
                            const DeviceMemory<InputT> &input,
                            DeviceMemory<OutputT> *output) {
  // Only plans produced by this support library can be executed here.
  auto *rocm_plan = dynamic_cast<ROCMFftPlan *>(plan);
  if (!rocm_plan) {
    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
    return false;
  }

  const auto handle = rocm_plan->GetPlan();
  if (!SetStream(parent_, handle, stream)) {
    return false;
  }

  // The exec routine takes a non-const pointer even for the input buffer.
  auto *in_ptr = const_cast<InputT *>(GpuMemory(input));
  auto *out_ptr = GpuMemoryMutable(output);
  const auto result =
      hipfftExec(parent_, handle, GpuComplex(in_ptr), GpuComplex(out_ptr));
  if (result == HIPFFT_SUCCESS) {
    return true;
  }
  LOG(ERROR) << "failed to run rocFFT routine: " << result;
  return false;
}

template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
                                         FuncT hipfftExec,
                                         const DeviceMemory<InputT> &input,
                                         DeviceMemory<OutputT> *output) {
  // Only plans produced by this support library can be executed here.
  auto *rocm_plan = dynamic_cast<ROCMFftPlan *>(plan);
  if (!rocm_plan) {
    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
    return false;
  }

  const auto handle = rocm_plan->GetPlan();
  if (!SetStream(parent_, handle, stream)) {
    return false;
  }

  // Complex-to-complex routines need the direction stored on the plan. The
  // exec routine takes a non-const pointer even for the input buffer.
  auto *in_ptr = const_cast<InputT *>(GpuMemory(input));
  auto *out_ptr = GpuMemoryMutable(output);
  const int direction = rocm_plan->GetFftDirection();
  const auto result = hipfftExec(parent_, handle, GpuComplex(in_ptr),
                                 GpuComplex(out_ptr), direction);
  if (result == HIPFFT_SUCCESS) {
    return true;
  }
  LOG(ERROR) << "failed to run rocFFT routine: " << result;
  return false;
}

// Defines the three public DoFft overloads for one floating-point precision
// `__type`:
//   complex -> complex: hipfftExec<__fft_type1>, direction taken from the plan;
//   real    -> complex: hipfftExec<__fft_type2>;
//   complex -> real:    hipfftExec<__fft_type3>.
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2,    \
                                        __fft_type3)                         \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftWithDirectionInternal(                                       \
        stream, plan, wrap::hipfftExec##__fft_type1, input, output);         \
  }                                                                          \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<__type> &input,                     \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
                         output);                                            \
  }                                                                          \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<__type> *output) {                        \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
                         output);                                            \
  }

// Single precision uses the C2C/R2C/C2R routines; double precision uses the
// Z2Z/D2Z/Z2D routines.
STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)

#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT

}  // namespace gpu

void initialize_rocfft() {
  auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
      rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);

  if (!rocFftAlreadyRegistered) {
    port::Status status =
        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
            rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
            [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
              gpu::GpuExecutor *rocm_executor =
                  dynamic_cast<gpu::GpuExecutor *>(parent);
              if (rocm_executor == nullptr) {
                LOG(ERROR)
                    << "Attempting to initialize an instance of the rocFFT "
                    << "support library with a non-ROCM StreamExecutor";
                return nullptr;
              }

              return new gpu::ROCMFft(rocm_executor);
            });
    if (!status.ok()) {
      LOG(ERROR) << "Unable to register rocFFT factory: "
                 << status.error_message();
    }

    PluginRegistry::Instance()->SetDefaultFactory(
        rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
  }
}

}  // namespace stream_executor

// Registers initialize_rocfft() as the `register_rocfft` module initializer so
// the rocFFT factory is registered without an explicit call from client code.
REGISTER_MODULE_INITIALIZER(register_rocfft,
                            { stream_executor::initialize_rocfft(); });
