Unify cuDNN execution plan building into a shared CreateOpRunners helper
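
Deduplicate the near-identical engine-selection and plan-building loops in
GetConvolveRunners and the fused-convolution runner path into a single
CreateOpRunners<Sig> helper: it applies the heuristics or fallback engine
list, the generic engine filter, and the static and runtime errata filters,
then wraps each surviving ExecutionPlan in a CudnnExecutionPlanRunner<Sig>.
The per-call-site differences (operation graph, convolution kind, runner
signature, and input tensor UIDs) become parameters. Also replace the
hand-rolled status plumbing around GetCudnnFusedOperationGraph with
SE_ASSIGN_OR_RETURN.
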
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index e405887..df42bf1 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -4559,6 +4559,126 @@
   float as_float_;
 };
 
+namespace {
+
+// Tensor UIDs for the inputs and outputs of the unfused ('x', 'w', 'y') and
+// fused ('x', 'w', 'z', 'b', 'y') convolution graphs; these must match the
+// UIDs assigned when the corresponding operation graphs are built.
+constexpr int64_t kUnfusedConvUids[] = {'x', 'w', 'y'};
+constexpr int64_t kFusedConvUids[] = {'x', 'w', 'z', 'b', 'y'};
+
+template <typename Sig>
+port::Status CreateOpRunners(
+    Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor,
+    CudnnAccess* cudnn_access,
+    std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
+    dnn::ConvolutionKind kind, dnn::DataType input_type,
+    absl::Span<const int64_t> input_uids, bool use_fallback,
+    std::vector<std::unique_ptr<const dnn::OpRunner<Sig>>>* out_runners) {
+  cudnn_frontend::EngineConfigList filtered_configs;
+  auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
+    return GenericEngineFilter(
+        engine_config,
+        /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
+        /*disable_nondeterminism*/ RequireCudnnDeterminism(),
+        /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
+  };
+
+  if (!use_fallback) {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(*op_graph)
+                          .setHeurMode(GetCudnnFrontendHeurMode())
+                          .build();
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+
+    // cuDNN frontend sneakily puts error messages on the object and returns
+    // partially-initialized results when there's an error; make sure to check
+    // them.
+    int64_t engine_count = heuristics.getEngineConfigCount();
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+    auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+    VLOG(4) << "\nHeuristics engine configs size: "
+            << heuristics_configs.size();
+
+    cudnn_frontend::filter(heuristics_configs, filtered_configs,
+                           generic_filter_fn);
+  } else {
+    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                        .setOperationGraph(*op_graph)
+                        .setOperation(GetCudnnConvolutionType(kind))
+                        .build();
+    RETURN_MSG_IF_CUDNN_ERROR(fallback);
+
+    auto& fallback_configs = fallback.getFallbackList();
+    VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
+
+    cudnn_frontend::filter(fallback_configs, filtered_configs,
+                           generic_filter_fn);
+  }
+  VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
+
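+  // Always-true predicate passed to cudnn_frontend::check_errata below as
+  // its extra-condition callback; the errata JSON alone decides exclusion.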
+  auto fn = []() { return true; };
+  auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
+  auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
+
+  out_runners->clear();
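+  // Build an ExecutionPlan from each filtered engine config, skipping
+  // configs whose plans fail to build or are excluded by the errata filters.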
+  for (int i = 0; i < filtered_configs.size(); i++) {
+    auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                    .setHandle(cudnn.handle())
+                    .setEngineConfig(filtered_configs[i], op_graph->getTag())
+                    .build();
+    if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
+      continue;
+    }
+
+    if (maybe_json_handle_static &&
+        cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
+                                     cudnn.handle(), fn)) {
+      VLOG(4) << "Exclude engine (static): " << plan.getTag();
+      continue;
+    }
+    if (maybe_json_handle_runtime &&
+        cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
+                                     cudnn.handle(), fn)) {
+      VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
+      continue;
+    }
+
+    auto runner_or = CudnnExecutionPlanRunner<Sig>::Create(
+        gpu_executor, cudnn_access, std::move(plan), input_uids);
+    if (!runner_or.ok()) {
+      // Note this can happen if cuDNN Frontend gives us partially-initialized
+      // ExecutionPlans because its error handling is broken in non-exception
+      // builds; those were meant to be filtered out earlier inside cuDNN
+      // Frontend, but instead they get filtered out here.
+      VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
+                 "getting its workspace size): "
+              << runner_or.status().ToString();
+      continue;
+    }
+
+    out_runners->push_back(
+        std::make_unique<CudnnExecutionPlanRunner<Sig>>(
+            runner_or.ConsumeValueOrDie()));
+
+    // We will use the first working plan when determinism is required.
+    if (RequireCudnnDeterminism()) {
+      break;
+    }
+  }
+
+  VLOG(4) << "\nReturned execution plans size: " << out_runners->size();
+
+  return port::Status::OK();
+}
+
+}  // namespace
+
 port::Status CudnnSupport::GetConvolveRunners(
     bool use_cudnn_frontend, dnn::ConvolutionKind kind,
     dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
@@ -4654,102 +4774,9 @@
                              filter_descriptor, output_descriptor,
                              convolution_descriptor, cudnn));
 
-  cudnn_frontend::EngineConfigList filtered_configs;
-  auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
-    return GenericEngineFilter(
-        engine_config,
-        /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
-        /*disable_nondeterminism*/ RequireCudnnDeterminism(),
-        /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
-  };
-
-  if (!use_fallback) {
-    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                          .setOperationGraph(*op_graph)
-                          .setHeurMode(GetCudnnFrontendHeurMode())
-                          .build();
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-
-    // cuDNN frontend sneakily puts error messages on the object and returns
-    // partially-initialized results when there's an error; make sure to check
-    // them.
-    int64_t engine_count = heuristics.getEngineConfigCount();
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-    auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-    VLOG(4) << "\nHeuristics engine configs size: "
-            << heuristics_configs.size();
-
-    cudnn_frontend::filter(heuristics_configs, filtered_configs,
-                           generic_filter_fn);
-  } else {
-    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                        .setOperationGraph(*op_graph)
-                        .setOperation(GetCudnnConvolutionType(kind))
-                        .build();
-    RETURN_MSG_IF_CUDNN_ERROR(fallback);
-
-    auto& fallback_configs = fallback.getFallbackList();
-    VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
-
-    cudnn_frontend::filter(fallback_configs, filtered_configs,
-                           generic_filter_fn);
-  }
-  VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
-
-  auto fn = []() { return true; };
-  auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
-  auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
-
-  out_exec_plans->clear();
-  for (int i = 0; i < filtered_configs.size(); i++) {
-    auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                    .setHandle(cudnn.handle())
-                    .setEngineConfig(filtered_configs[i], op_graph->getTag())
-                    .build();
-    if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
-      continue;
-    }
-
-    if (maybe_json_handle_static &&
-        cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
-                                     cudnn.handle(), fn)) {
-      VLOG(4) << "Exclude engine (static): " << plan.getTag();
-      continue;
-    }
-    if (maybe_json_handle_runtime &&
-        cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
-                                     cudnn.handle(), fn)) {
-      VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
-      continue;
-    }
-
-    auto runner_or = CudnnExecutionPlanRunner<dnn::ConvSignature>::Create(
-        parent_, cudnn_.get(), std::move(plan), {'x', 'w', 'y'});
-    if (!runner_or.ok()) {
-      // Note this can happen if cuDNN Frontend gives us partially-initialized
-      // ExecutionPlans because its error handling is broken in non-exception
-      // builds; those were meant to be filtered out earlier inside cuDNN
-      // Frontend, but instead they get filtered out here.
-      VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
-                 "getting its workspace size): "
-              << runner_or.status().ToString();
-      continue;
-    }
-
-    out_exec_plans->push_back(
-        std::make_unique<CudnnExecutionPlanRunner<dnn::ConvSignature>>(
-            runner_or.ConsumeValueOrDie()));
-
-    // We will use the first working plan when determinism is required.
-    if (RequireCudnnDeterminism()) {
-      break;
-    }
-  }
-
-  VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
-
-  return port::Status::OK();
+  return CreateOpRunners<dnn::ConvSignature>(
+      stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
+      input_type, kUnfusedConvUids, use_fallback, out_exec_plans);
 #else
   return port::UnimplementedError(
       "Cudnn execution plans are only supported with Cudnn >= 8.1.");
@@ -5184,114 +5211,17 @@
 
 #if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
   auto cudnn = cudnn_->GetHandle(parent_, stream);
-  auto op_graph_status = GetCudnnFusedOperationGraph(
-      kind, input_type, bias_type, output_type, conv_scale, side_input_scale,
-      input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
-      convolution_descriptor, activation_mode, cudnn);
-  if (!op_graph_status.status().ok()) {
-    return port::Status(port::error::INTERNAL,
-                        absl::StrCat("Cudnn graph failed to build: ",
-                                     op_graph_status.status().ToString()));
-  }
-  auto op_graph = op_graph_status.ConsumeValueOrDie();
+  SE_ASSIGN_OR_RETURN(
+      auto op_graph,
+      GetCudnnFusedOperationGraph(
+          kind, input_type, bias_type, output_type, conv_scale,
+          side_input_scale, input_descriptor, filter_descriptor,
+          bias_descriptor, output_descriptor, convolution_descriptor,
+          activation_mode, cudnn));
 
-  cudnn_frontend::EngineConfigList filtered_configs;
-  auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
-    return GenericEngineFilter(
-        engine_config,
-        /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
-        /*disable_nondeterminism*/ RequireCudnnDeterminism(),
-        /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
-  };
-
-  if (!use_fallback) {
-    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                          .setOperationGraph(*op_graph)
-                          .setHeurMode(GetCudnnFrontendHeurMode())
-                          .build();
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-
-    // cuDNN frontend sneakily puts error messages on the object and returns
-    // partially-initialized results when there's an error; make sure to check
-    // them.
-    int64_t engine_count = heuristics.getEngineConfigCount();
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-    auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
-    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-    VLOG(4) << "\nHeuristics engine configs size: "
-            << heuristics_configs.size();
-
-    cudnn_frontend::filter(heuristics_configs, filtered_configs,
-                           generic_filter_fn);
-  } else {
-    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                        .setOperationGraph(*op_graph)
-                        .setOperation(GetCudnnConvolutionType(
-                            dnn::ConvolutionKind::FORWARD))
-                        .build();
-    RETURN_MSG_IF_CUDNN_ERROR(fallback);
-
-    auto& fallback_configs = fallback.getFallbackList();
-    VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
-
-    cudnn_frontend::filter(fallback_configs, filtered_configs,
-                           generic_filter_fn);
-  }
-  VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
-
-  auto fn = []() { return true; };
-  auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
-  auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
-
-  out_exec_plans->clear();
-  for (int i = 0; i < filtered_configs.size(); i++) {
-    auto plan = cudnn_frontend::ExecutionPlanBuilder()
-                    .setHandle(cudnn.handle())
-                    .setEngineConfig(filtered_configs[i], op_graph->getTag())
-                    .build();
-    if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
-      continue;
-    }
-
-    if (maybe_json_handle_static &&
-        cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
-                                     cudnn.handle(), fn)) {
-      VLOG(4) << "Exclude engine (static): " << plan.getTag();
-      continue;
-    }
-    if (maybe_json_handle_runtime &&
-        cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
-                                     cudnn.handle(), fn)) {
-      VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
-      continue;
-    }
-
-    auto runner_or = CudnnExecutionPlanRunner<dnn::FusedConvSignature>::Create(
-        parent_, cudnn_.get(), std::move(plan), {'x', 'w', 'z', 'b', 'y'});
-    if (!runner_or.ok()) {
-      // Note this can happen if cuDNN Frontend gives us partially-initialized
-      // ExecutionPlans because its error handling is broken in non-exception
-      // builds; those were meant to be filtered out earlier inside cuDNN
-      // Frontend, but instead they get filtered out here.
-      VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
-                 "getting its workspace size): "
-              << runner_or.status().ToString();
-      continue;
-    }
-
-    out_exec_plans->push_back(
-        std::make_unique<CudnnExecutionPlanRunner<dnn::FusedConvSignature>>(
-            runner_or.ConsumeValueOrDie()));
-
-    // We will use the first working plan when determinism is required.
-    if (RequireCudnnDeterminism()) {
-      break;
-    }
-  }
-
-  VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
-
-  return port::Status::OK();
+  return CreateOpRunners<dnn::FusedConvSignature>(
+      stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
+      input_type, kFusedConvUids, use_fallback, out_exec_plans);
 #else
   return port::UnimplementedError(
       "Cudnn execution plans are only supported with Cudnn >= 8.1.");