Unify the cuDNN frontend execution-plan building logic — previously duplicated between the unfused and fused convolution paths in cuda_dnn.cc — into a single shared CreateOpRunners<Sig> helper.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index e405887..df42bf1 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -4559,6 +4559,117 @@
float as_float_;
};
+namespace {
+
+// Builds the list of executable `OpRunner`s for the convolution described by
+// `op_graph`; shared by the unfused and fused convolution paths.
+//
+// Candidate engine configs come either from cuDNN frontend heuristics or,
+// when `use_fallback` is set, from the frontend's fallback list. Both are
+// passed through GenericEngineFilter and the static/runtime errata JSON
+// filters before an ExecutionPlan is built for each surviving config.
+//
+// `input_uids` are the tensor UIDs wired into each runner. They are taken as
+// a span so the element count travels with the pointer — the call sites pass
+// UID arrays of different lengths (3 for plain conv, 5 for fused conv), which
+// a bare `const int64_t*` would silently drop.
+//
+// On success, `*out_runners` holds the surviving plans; only the first
+// working plan is kept when cuDNN determinism is required.
+template <typename Sig>
+port::Status CreateOpRunners(
+    Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor,
+    CudnnAccess* cudnn_access,
+    std::unique_ptr<cudnn_frontend::OperationGraph> op_graph,
+    dnn::ConvolutionKind kind, dnn::DataType input_type,
+    absl::Span<const int64_t> input_uids, bool use_fallback,
+    std::vector<std::unique_ptr<const dnn::OpRunner<Sig>>>* out_runners) {
+  cudnn_frontend::EngineConfigList filtered_configs;
+  auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
+    return GenericEngineFilter(
+        engine_config,
+        /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
+        /*disable_nondeterminism*/ RequireCudnnDeterminism(),
+        /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
+  };
+
+  if (!use_fallback) {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(*op_graph)
+                          .setHeurMode(GetCudnnFrontendHeurMode())
+                          .build();
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+
+    // cuDNN frontend sneakily puts error messages on the object and returns
+    // partially-initialized results when there's an error; make sure to check
+    // them.
+    int64_t engine_count = heuristics.getEngineConfigCount();
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+    auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
+    RETURN_MSG_IF_CUDNN_ERROR(heuristics);
+    VLOG(4) << "\nHeuristics engine configs size: "
+            << heuristics_configs.size();
+
+    cudnn_frontend::filter(heuristics_configs, filtered_configs,
+                           generic_filter_fn);
+  } else {
+    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                        .setOperationGraph(*op_graph)
+                        .setOperation(GetCudnnConvolutionType(kind))
+                        .build();
+    RETURN_MSG_IF_CUDNN_ERROR(fallback);
+
+    auto& fallback_configs = fallback.getFallbackList();
+    VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
+
+    cudnn_frontend::filter(fallback_configs, filtered_configs,
+                           generic_filter_fn);
+  }
+  VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
+
+  auto fn = []() { return true; };
+  auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
+  auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
+
+  out_runners->clear();
+  for (size_t i = 0; i < filtered_configs.size(); i++) {
+    auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                    .setHandle(cudnn.handle())
+                    .setEngineConfig(filtered_configs[i], op_graph->getTag())
+                    .build();
+    if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
+      continue;
+    }
+
+    if (maybe_json_handle_static &&
+        cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
+                                     cudnn.handle(), fn)) {
+      VLOG(4) << "Exclude engine (static): " << plan.getTag();
+      continue;
+    }
+    if (maybe_json_handle_runtime &&
+        cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
+                                     cudnn.handle(), fn)) {
+      VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
+      continue;
+    }
+
+    auto runner_or = CudnnExecutionPlanRunner<Sig>::Create(
+        gpu_executor, cudnn_access, std::move(plan), input_uids);
+    if (!runner_or.ok()) {
+      // Note this can happen if cuDNN Frontend gives us partially-initialized
+      // ExecutionPlans because its error handling is broken in non-exception
+      // builds; those were meant to be filtered out earlier inside cuDNN
+      // Frontend, but instead they get filtered out here.
+      VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
+                 "getting its workspace size): "
+              << runner_or.status().ToString();
+      continue;
+    }
+
+    out_runners->push_back(std::make_unique<CudnnExecutionPlanRunner<Sig>>(
+        runner_or.ConsumeValueOrDie()));
+
+    // We will use the first working plan when determinism is required.
+    if (RequireCudnnDeterminism()) {
+      break;
+    }
+  }
+
+  VLOG(4) << "\nReturned execution plans size: " << out_runners->size();
+
+  return port::Status::OK();
+}
+
+}  // namespace
+
port::Status CudnnSupport::GetConvolveRunners(
bool use_cudnn_frontend, dnn::ConvolutionKind kind,
dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
@@ -4654,102 +4765,9 @@
filter_descriptor, output_descriptor,
convolution_descriptor, cudnn));
- cudnn_frontend::EngineConfigList filtered_configs;
- auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
- return GenericEngineFilter(
- engine_config,
- /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
- /*disable_nondeterminism*/ RequireCudnnDeterminism(),
- /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
- };
-
- if (!use_fallback) {
- auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
- .setOperationGraph(*op_graph)
- .setHeurMode(GetCudnnFrontendHeurMode())
- .build();
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-
- // cuDNN frontend sneakily puts error messages on the object and returns
- // partially-initialized results when there's an error; make sure to check
- // them.
- int64_t engine_count = heuristics.getEngineConfigCount();
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
- auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
- VLOG(4) << "\nHeuristics engine configs size: "
- << heuristics_configs.size();
-
- cudnn_frontend::filter(heuristics_configs, filtered_configs,
- generic_filter_fn);
- } else {
- auto fallback = cudnn_frontend::EngineFallbackListBuilder()
- .setOperationGraph(*op_graph)
- .setOperation(GetCudnnConvolutionType(kind))
- .build();
- RETURN_MSG_IF_CUDNN_ERROR(fallback);
-
- auto& fallback_configs = fallback.getFallbackList();
- VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
-
- cudnn_frontend::filter(fallback_configs, filtered_configs,
- generic_filter_fn);
- }
- VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
-
- auto fn = []() { return true; };
- auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
- auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
-
- out_exec_plans->clear();
- for (int i = 0; i < filtered_configs.size(); i++) {
- auto plan = cudnn_frontend::ExecutionPlanBuilder()
- .setHandle(cudnn.handle())
- .setEngineConfig(filtered_configs[i], op_graph->getTag())
- .build();
- if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
- continue;
- }
-
- if (maybe_json_handle_static &&
- cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
- cudnn.handle(), fn)) {
- VLOG(4) << "Exclude engine (static): " << plan.getTag();
- continue;
- }
- if (maybe_json_handle_runtime &&
- cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
- cudnn.handle(), fn)) {
- VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
- continue;
- }
-
- auto runner_or = CudnnExecutionPlanRunner<dnn::ConvSignature>::Create(
- parent_, cudnn_.get(), std::move(plan), {'x', 'w', 'y'});
- if (!runner_or.ok()) {
- // Note this can happen if cuDNN Frontend gives us partially-initialized
- // ExecutionPlans because its error handling is broken in non-exception
- // builds; those were meant to be filtered out earlier inside cuDNN
- // Frontend, but instead they get filtered out here.
- VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
- "getting its workspace size): "
- << runner_or.status().ToString();
- continue;
- }
-
- out_exec_plans->push_back(
- std::make_unique<CudnnExecutionPlanRunner<dnn::ConvSignature>>(
- runner_or.ConsumeValueOrDie()));
-
- // We will use the first working plan when determinism is required.
- if (RequireCudnnDeterminism()) {
- break;
- }
- }
-
- VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
-
- return port::Status::OK();
+ return CreateOpRunners<dnn::ConvSignature>(
+ stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
+ input_type, kUnfusedConvUids, use_fallback, out_exec_plans);
#else
return port::UnimplementedError(
"Cudnn execution plans are only supported with Cudnn >= 8.1.");
@@ -5184,114 +5202,17 @@
#if CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto op_graph_status = GetCudnnFusedOperationGraph(
- kind, input_type, bias_type, output_type, conv_scale, side_input_scale,
- input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
- convolution_descriptor, activation_mode, cudnn);
- if (!op_graph_status.status().ok()) {
- return port::Status(port::error::INTERNAL,
- absl::StrCat("Cudnn graph failed to build: ",
- op_graph_status.status().ToString()));
- }
- auto op_graph = op_graph_status.ConsumeValueOrDie();
+ SE_ASSIGN_OR_RETURN(
+ auto op_graph,
+ GetCudnnFusedOperationGraph(
+ kind, input_type, bias_type, output_type, conv_scale,
+ side_input_scale, input_descriptor, filter_descriptor,
+ bias_descriptor, output_descriptor, convolution_descriptor,
+ activation_mode, cudnn));
- cudnn_frontend::EngineConfigList filtered_configs;
- auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
- return GenericEngineFilter(
- engine_config,
- /*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
- /*disable_nondeterminism*/ RequireCudnnDeterminism(),
- /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
- };
-
- if (!use_fallback) {
- auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
- .setOperationGraph(*op_graph)
- .setHeurMode(GetCudnnFrontendHeurMode())
- .build();
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
-
- // cuDNN frontend sneakily puts error messages on the object and returns
- // partially-initialized results when there's an error; make sure to check
- // them.
- int64_t engine_count = heuristics.getEngineConfigCount();
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
- auto& heuristics_configs = heuristics.getEngineConfig(engine_count);
- RETURN_MSG_IF_CUDNN_ERROR(heuristics);
- VLOG(4) << "\nHeuristics engine configs size: "
- << heuristics_configs.size();
-
- cudnn_frontend::filter(heuristics_configs, filtered_configs,
- generic_filter_fn);
- } else {
- auto fallback = cudnn_frontend::EngineFallbackListBuilder()
- .setOperationGraph(*op_graph)
- .setOperation(GetCudnnConvolutionType(
- dnn::ConvolutionKind::FORWARD))
- .build();
- RETURN_MSG_IF_CUDNN_ERROR(fallback);
-
- auto& fallback_configs = fallback.getFallbackList();
- VLOG(4) << "\nFallback engine configs size: " << fallback_configs.size();
-
- cudnn_frontend::filter(fallback_configs, filtered_configs,
- generic_filter_fn);
- }
- VLOG(4) << "\nFiltered engine configs size: " << filtered_configs.size();
-
- auto fn = []() { return true; };
- auto maybe_json_handle_static = CudnnExecutionPlanEngineFilterStatic();
- auto maybe_json_handle_runtime = CudnnExecutionPlanEngineFilterRuntime();
-
- out_exec_plans->clear();
- for (int i = 0; i < filtered_configs.size(); i++) {
- auto plan = cudnn_frontend::ExecutionPlanBuilder()
- .setHandle(cudnn.handle())
- .setEngineConfig(filtered_configs[i], op_graph->getTag())
- .build();
- if (plan.get_status() != CUDNN_STATUS_SUCCESS) {
- continue;
- }
-
- if (maybe_json_handle_static &&
- cudnn_frontend::check_errata(*maybe_json_handle_static, plan.getTag(),
- cudnn.handle(), fn)) {
- VLOG(4) << "Exclude engine (static): " << plan.getTag();
- continue;
- }
- if (maybe_json_handle_runtime &&
- cudnn_frontend::check_errata(*maybe_json_handle_runtime, plan.getTag(),
- cudnn.handle(), fn)) {
- VLOG(4) << "Exclude engine (runtime): " << plan.getTag();
- continue;
- }
-
- auto runner_or = CudnnExecutionPlanRunner<dnn::FusedConvSignature>::Create(
- parent_, cudnn_.get(), std::move(plan), {'x', 'w', 'z', 'b', 'y'});
- if (!runner_or.ok()) {
- // Note this can happen if cuDNN Frontend gives us partially-initialized
- // ExecutionPlans because its error handling is broken in non-exception
- // builds; those were meant to be filtered out earlier inside cuDNN
- // Frontend, but instead they get filtered out here.
- VLOG(4) << "Failed building runner from ExecutionPlan (i.e. failed "
- "getting its workspace size): "
- << runner_or.status().ToString();
- continue;
- }
-
- out_exec_plans->push_back(
- std::make_unique<CudnnExecutionPlanRunner<dnn::FusedConvSignature>>(
- runner_or.ConsumeValueOrDie()));
-
- // We will use the first working plan when determinism is required.
- if (RequireCudnnDeterminism()) {
- break;
- }
- }
-
- VLOG(4) << "\nReturned execution plans size: " << out_exec_plans->size();
-
- return port::Status::OK();
+ return CreateOpRunners<dnn::FusedConvSignature>(
+ stream, cudnn, parent_, cudnn_.get(), std::move(op_graph), kind,
+ input_type, kFusedConvUids, use_fallback, out_exec_plans);
#else
return port::UnimplementedError(
"Cudnn execution plans are only supported with Cudnn >= 8.1.");