[MLIR][KernelGen] Compile for multiple NVIDIA GPU architectures simultaneously

For every architecture, compile the kernel module to PTX and assemble it
into a cubin. The resulting cubins are then combined into one fatbin using
the fatbinary tool.
This change only affects the `tf_to_kernel` tool.
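
An example invocation with multiple target architectures, based on the new
lit test (file names here are illustrative):

  tf_to_kernel --input=tanh.mlir --output=tanh_kernel.bin --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75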

PiperOrigin-RevId: 335025889
Change-Id: Ic8f7325081aad497ba7f5bea709acc802beb8b86
diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py
index e403a75..17410b4 100644
--- a/tensorflow/compiler/mlir/runlit.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.cfg.py
@@ -74,8 +74,8 @@
     'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
     'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
     'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
-    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
-    'tfjs-opt'
+    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
+    'xla-thunks-opt', 'tfjs-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index 619a56c..834115a 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -105,7 +105,10 @@
 tf_cc_binary(
     name = "tf_to_kernel",
     srcs = ["tf_to_kernel.cc"],
-    visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"],
+    visibility = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
+        "//tensorflow/core/kernels/mlir_generated:__pkg__",
+    ],
     deps = [
         ":kernel_creator",
         "//tensorflow/compiler/mlir:init_mlir",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
index 48696f6..c3b1672 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc
@@ -174,7 +174,8 @@
 Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only,
                       llvm::ArrayRef<uint32_t> same_shape,
                       llvm::StringRef gpu_binary_attr_name,
-                      int32_t architecture) {
+                      llvm::ArrayRef<uint32_t> architectures,
+                      bool generate_fatbin) {
   mlir::PassManager pm(module.getContext());
   applyTensorflowAndCLOptions(pm);
 
@@ -187,7 +188,7 @@
   }
   kernel_pm.addPass(mlir::createStripDebugInfoPass());
   kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass(
-      gpu_binary_attr_name, architecture));
+      gpu_binary_attr_name, architectures, generate_fatbin));
 
   if (!gpu_binary_only) {
     pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass());
@@ -202,9 +203,9 @@
 
 StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
-    int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
+    llvm::ArrayRef<uint32_t> architectures, llvm::ArrayRef<uint32_t> tile_sizes,
     llvm::ArrayRef<uint32_t> same_shape,
-    llvm::ArrayRef<uint32_t> unroll_factors) {
+    llvm::ArrayRef<uint32_t> unroll_factors, bool generate_fatbin) {
   mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
   TF_RETURN_IF_ERROR(
@@ -221,7 +222,8 @@
   TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
 #endif
   TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape,
-                                    kGpuBinaryAttrName, architecture));
+                                    kGpuBinaryAttrName, architectures,
+                                    generate_fatbin));
   return module;
 }
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
index b168ec8..0a74a8a3 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h
@@ -38,9 +38,10 @@
 // false, lowers the host side to LLVM Dialect.
 xla::StatusOr<mlir::OwningModuleRef> GenerateKernelForTfCode(
     mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only,
-    int32_t architecture = 75, llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
+    llvm::ArrayRef<uint32_t> architectures = {75},
+    llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
     llvm::ArrayRef<uint32_t> same_shape = {},
-    llvm::ArrayRef<uint32_t> unroll_factors = {});
+    llvm::ArrayRef<uint32_t> unroll_factors = {}, bool generate_fatbin = true);
 
 // Extracts gpu_binary from the converted module.
 xla::StatusOr<std::string> ExtractGpuBinary(mlir::ModuleOp module);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir
index e596c33..de9f4ae 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir
@@ -1,6 +1,5 @@
 // RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70
 func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = "tf.Tanh"(%arg0) { }
-    : (tensor<?xf32>) -> tensor<?xf32>
+  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD
new file mode 100644
index 0000000..24e288c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD
@@ -0,0 +1,17 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+package(licenses = ["notice"])
+
+glob_lit_tests(
+    data = [
+        "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel",
+        "@llvm-project//mlir:run_lit.sh",
+    ],
+    default_tags = [
+        # We need access to the CUDA SDK.
+        "gpu",
+        "no_rocm",
+    ],
+    driver = "//tensorflow/compiler/mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir
new file mode 100644
index 0000000..d5d1b87
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir
@@ -0,0 +1,6 @@
+// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70,75
+
+func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc
index c7cb924..cbd97e2 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc
@@ -48,7 +48,7 @@
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true,
                               architecture, tile_sizes, same_shape,
-                              unroll_factors));
+                              unroll_factors, /*generate_fatbin=*/false));
   // Extract gpu_binary.
   TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module));
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
index e62fa47..d2d71a2 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc
@@ -95,7 +95,8 @@
 }
 
 xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file,
-                int32_t architecture, llvm::ArrayRef<uint32_t> tile_sizes,
+                llvm::ArrayRef<uint32_t> architectures,
+                llvm::ArrayRef<uint32_t> tile_sizes,
                 llvm::ArrayRef<uint32_t> same_shape,
                 llvm::ArrayRef<uint32_t> unroll_factors) {
   // Read TF code.
@@ -107,7 +108,7 @@
   TF_ASSIGN_OR_RETURN(
       mlir::OwningModuleRef module,
       GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false,
-                              architecture, tile_sizes, same_shape,
+                              architectures, tile_sizes, same_shape,
                               unroll_factors));
   // Get binary.
   TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module));
@@ -129,8 +130,8 @@
   llvm::cl::opt<std::string> output_file(
       "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"),
       llvm::cl::init("foo.bin"));
-  llvm::cl::list<int32_t> architecture(
-      "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"),
+  llvm::cl::list<uint32_t> architectures(
+      "arch", llvm::cl::desc("target architectures (e.g. 50 for sm_50)"),
       llvm::cl::OneOrMore, llvm::cl::CommaSeparated);
   llvm::cl::list<uint32_t> tile_sizes(
       "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore,
@@ -151,7 +152,7 @@
   llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n");
 
   auto status =
-      tensorflow::kernel_gen::Run(input_file, output_file, architecture.front(),
+      tensorflow::kernel_gen::Run(input_file, output_file, architectures,
                                   tile_sizes, same_shape, unroll_factors);
   if (!status.ok()) {
     LOG(ERROR) << status;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
index d4110b4..caa665b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
@@ -117,6 +117,7 @@
         "@llvm-project//mlir:AllPassesAndDialects",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
+        "@llvm-project//llvm:TransformUtils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
index dda0e24..f995c22 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/Target/NVVMIR.h"  // from @llvm-project
 #include "mlir/Target/ROCDLIR.h"  // from @llvm-project
@@ -49,9 +50,12 @@
 class GpuKernelToBlobPass
     : public GpuKernelToBlobPassBase<GpuKernelToBlobPass> {
  public:
-  GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) {
+  GpuKernelToBlobPass(mlir::StringRef blob_annotation,
+                      llvm::ArrayRef<uint32_t> architectures,
+                      bool generate_fatbin) {
     blob_annotation_ = blob_annotation.str();
-    arch_ = arch;
+    architectures_ = architectures;
+    generate_fatbin_ = generate_fatbin;
   }
 
   void runOnOperation() override {
@@ -69,7 +73,17 @@
 
   xla::StatusOr<std::vector<uint8_t>> GetGpuBinaryBlob(
       mlir::gpu::GPUModuleOp gpu_module) {
+    if (architectures_.empty()) {
+      return InternalError("Expected at least one GPU architecture.");
+    }
+    if (!generate_fatbin_ && architectures_.size() > 1) {
+      return InternalError(
+          "Can only generate machine code for more than one architecture as a "
+          "fatbin.");
+    }
+
     llvm::LLVMContext llvmContext;
+
 #if TENSORFLOW_USE_ROCM
     auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext);
     if (!llvmModule) {
@@ -81,9 +95,14 @@
     xla::HloModuleConfig config;
     config.set_debug_options(xla::GetDebugOptionsFromFlags());
 
-    std::string libdevice_dir = tensorflow::RocdlRoot();
+    // TODO(b/169066682): Support fatbin on ROCm.
+    if (generate_fatbin_) {
+      return InternalError("Fatbins are not yet supported for ROCm.");
+    }
 
-    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config,
+    uint32_t arch = architectures_.front();
+    std::string libdevice_dir = tensorflow::RocdlRoot();
+    return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config,
                                             libdevice_dir);
 
 #elif GOOGLE_CUDA
@@ -102,19 +121,42 @@
       target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast;
     };
 
-    int32_t cc_major = arch_ / 10;
-    int32_t cc_minor = arch_ % 10;
+    // Compile and collect requested cubin and PTX images.
+    std::vector<tensorflow::se::CubinOrPTXImage> images;
     TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config));
-    TF_ASSIGN_OR_RETURN(
-        std::string ptx,
-        xla::gpu::nvptx::CompileToPtx(llvmModule.get(),
-                                      std::make_pair(cc_major, cc_minor),
-                                      config, libdevice_dir, enable_fusion));
-    VLOG(1) << ptx;
+    auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config);
+    for (uint32_t arch : architectures_) {
+      int32_t cc_major = arch / 10;
+      int32_t cc_minor = arch % 10;
+      // Module may be changed by CompileToPtx.
+      auto llvmModuleCopy = llvm::CloneModule(*llvmModule);
+      TF_ASSIGN_OR_RETURN(
+          std::string ptx,
+          xla::gpu::nvptx::CompileToPtx(llvmModuleCopy.get(),
+                                        std::make_pair(cc_major, cc_minor),
+                                        config, libdevice_dir, enable_fusion));
+      // TODO(b/169066682): If compute_XX profile, collect PTX image here.
+      VLOG(1) << ptx;
+      TF_ASSIGN_OR_RETURN(std::vector<uint8_t> gpu_asm,
+                          tensorflow::se::CompileGpuAsm(
+                              cc_major, cc_minor, ptx.c_str(), gpu_asm_opts));
 
-    return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(),
-                                         xla::gpu::PtxOptsFromConfig(config));
+      if (!generate_fatbin_) {
+        // Skip fatbin generation and return the first and only GPU machine
+        // code.
+        return gpu_asm;
+      }
+
+      // Collect cubin image.
+      images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)});
+    }
+
+    // TODO(b/169870789): Revisit the use of fatbins.
+    // Bundle cubin and PTX images into a single fatbin.
+    return tensorflow::se::BundleGpuAsm(images,
+                                        gpu_asm_opts.preferred_cuda_dir);
 #endif
+
     return InternalError(
         "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined."
         " Did you specify either --config=rocm or --config=cuda ?");
@@ -141,8 +183,10 @@
 }  // namespace
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
-    mlir::StringRef blob_annotation, int32_t architecture) {
-  return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architecture);
+    mlir::StringRef blob_annotation, ArrayRef<uint32_t> architectures,
+    bool generate_fatbin) {
+  return std::make_unique<GpuKernelToBlobPass>(blob_annotation, architectures,
+                                               generate_fatbin);
 }
 
 }  // namespace transforms
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index 2ef863a..43e4646 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -61,7 +61,8 @@
 
 // Pass to annotate GPU Module with its PTX.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> CreateGpuKernelToBlobPass(
-    mlir::StringRef blob_annotation = "", int32_t architecture = 0);
+    mlir::StringRef blob_annotation = "", ArrayRef<uint32_t> architectures = {},
+    bool generate_fatbin = true);
 
 // Pass to unfuse batch norm.
 std::unique_ptr<FunctionPass> CreateUnfuseBatchNormPass();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 5bdd466..e84971b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -53,7 +53,10 @@
   let options = [
     Option<"blob_annotation_", "blob-annotation", "std::string",
            /*default=*/"", "Blob attribute name">,
-    Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">,
+    ListOption<"architectures_", "arch", "uint32_t", "GPU architectures">,
+    Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true",
+           "Bundle machine code for the different architectures in one "
+           "fatbin.">,
   ];
   let constructor = "transforms::CreateGpuKernelToBlobPass()";
 }
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
index 93e2e55..79944cf 100644
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -296,9 +296,6 @@
         archs_trimmed.append(arch[3:])
     arch_flag = ",".join(archs_trimmed)
 
-    # TODO(b/169066682): Generate Fatbin when lowering GPU module.
-    arch_flag = "75"
-
     filename = "%s.a" % (name)
     gpu_bin = ctx.outputs.output
     ctx.actions.run(
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index ea65d7a..56d24bf 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -104,7 +104,7 @@
 
 # Buildozer can not remove dependencies inside select guards, so we have to use
 # an intermediate target.
-cc_library(name = "ptxas_wrapper")
+cc_library(name = "cuda_root_wrapper")
 
 cc_library(
     name = "cuda_driver",
diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD
index 6328fa4..7fbb40e 100644
--- a/tensorflow/stream_executor/gpu/BUILD
+++ b/tensorflow/stream_executor/gpu/BUILD
@@ -250,7 +250,7 @@
         "@com_google_absl//absl/container:flat_hash_map",
     ]) + if_cuda_is_configured([
         "//tensorflow/stream_executor/cuda:cuda_driver",
-        "//tensorflow/stream_executor/cuda:ptxas_wrapper",
+        "//tensorflow/stream_executor/cuda:cuda_root_wrapper",
     ]),
 )
 
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc
index 0f6fd4d..53f7650 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.cc
+++ b/tensorflow/stream_executor/gpu/asm_compiler.cc
@@ -140,34 +140,44 @@
   return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options);
 }
 
-port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
-                                                 const char* ptx_contents,
-                                                 GpuAsmOpts options) {
-  std::string ptxas_path;
-  auto env = tensorflow::Env::Default();
-  std::string ptxas_binary_name = "ptxas";
+static std::string findCudaExecutable(const std::string binary_name,
+                                      const std::string preferred_cuda_dir) {
 #if defined(PLATFORM_WINDOWS)
-  ptxas_binary_name += ".exe";
+  const std::string binary_filename = binary_name + ".exe";
+#else
+  const std::string& binary_filename = binary_name;
 #endif
 
+  // Search in cuda root candidates.
+  auto env = tensorflow::Env::Default();
+  std::string binary_path;
   for (const std::string& cuda_root :
-       tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) {
-    ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", ptxas_binary_name);
-    VLOG(2) << "Looking for ptxas at " << ptxas_path;
-    if (env->FileExists(ptxas_path).ok()) {
+       tensorflow::CandidateCudaRoots(preferred_cuda_dir)) {
+    binary_path = tensorflow::io::JoinPath(cuda_root, "bin", binary_filename);
+    VLOG(2) << "Looking for " << binary_filename << " at " << binary_path;
+    if (env->FileExists(binary_path).ok()) {
       break;
     }
   }
-  if (!env->FileExists(ptxas_path).ok()) {
+  if (!env->FileExists(binary_path).ok()) {
     // Rely on subprocess invocation to find the correct binary.
-    ptxas_path = ptxas_binary_name;
+    binary_path = binary_filename;
   }
-  VLOG(2) << "Using ptxas at " << ptxas_path;
+  VLOG(2) << "Using " << binary_filename << " at " << binary_path;
+  return binary_path;
+}
+
+port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
+                                                 const char* ptx_contents,
+                                                 GpuAsmOpts options) {
+  std::string ptxas_path =
+      findCudaExecutable("ptxas", options.preferred_cuda_dir);
 
   WarnIfBadPtxasVersion(ptxas_path);
 
   // Write ptx into a temporary file.
   std::string ptx_path;
+  auto env = tensorflow::Env::Default();
   if (!env->LocalTempFilename(&ptx_path)) {
     return port::InternalError("couldn't get temp PTX file name");
   }
@@ -232,4 +242,78 @@
   return cubin_vector;
 }
 
+port::StatusOr<std::vector<uint8>> BundleGpuAsm(
+    std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir) {
+  std::string fatbinary_path =
+      findCudaExecutable("fatbinary", preferred_cuda_dir);
+
+  // Write images to temporary files.
+  std::vector<std::string> image_paths;
+  auto env = tensorflow::Env::Default();
+  for (const CubinOrPTXImage& img : images) {
+    std::string img_path;
+    if (!env->LocalTempFilename(&img_path)) {
+      return port::InternalError(
+          "Could not get temporary filenames for images.");
+    }
+    TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
+        env, img_path, std::string(img.bytes.begin(), img.bytes.end())));
+    VLOG(2) << "image written to " << img_path;
+    image_paths.push_back(std::move(img_path));
+  }
+  auto image_files_cleaner = tensorflow::gtl::MakeCleanup([&image_paths] {
+    for (const auto& path : image_paths) {
+      TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(path));
+    }
+  });
+
+  // Prepare temporary result file.
+  std::string result_path;
+  if (!env->LocalTempFilename(&result_path)) {
+    return port::InternalError(
+        "Could not get temporary filename for fatbin result.");
+  }
+  auto result_file_cleaner = tensorflow::gtl::MakeCleanup([&result_path] {
+    // This file may never be created, so the failure to delete it should not
+    // propagate to TF.
+    tensorflow::Env::Default()->DeleteFile(result_path).IgnoreError();
+  });
+
+  // Invoke fatbinary and collect its output.
+  tensorflow::SubProcess fatbinary;
+  std::vector<std::string> fatbinary_args = {
+      fatbinary_path, "--64",           "--cmdline=--compile-only",
+      "--link",       "--compress-all", absl::StrCat("--create=", result_path)};
+  assert(images.size() == image_paths.size());
+  for (int i = 0; i < images.size(); i++) {
+    fatbinary_args.push_back(absl::StrFormat(
+        "--image=profile=%s,file=%s", images[i].profile, image_paths[i]));
+  }
+  if (VLOG_IS_ON(3)) {
+    VLOG(3) << absl::StrJoin(fatbinary_args, " ");
+  }
+  fatbinary.SetProgram(fatbinary_path, fatbinary_args);
+  fatbinary.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
+  if (!fatbinary.Start()) {
+    return port::InternalError("Failed to launch fatbinary.");
+  }
+  std::string stderr_output;
+  int exit_status = fatbinary.Communicate(
+      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
+  if (exit_status != 0) {
+    return port::InternalError(absl::StrFormat(
+        "fatbinary exited with non-zero error code %d, output: %s", exit_status,
+        stderr_output));
+  }
+  if (!stderr_output.empty()) {
+    VLOG(2) << stderr_output;
+  }
+
+  // Read in the result and return it as a byte vector.
+  std::string result_blob;
+  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
+                                                  result_path, &result_blob));
+  return std::vector<uint8>(result_blob.begin(), result_blob.end());
+}
+
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h
index e5f67a7..513ac6c 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.h
+++ b/tensorflow/stream_executor/gpu/asm_compiler.h
@@ -52,6 +52,16 @@
 port::StatusOr<absl::Span<const uint8>> CompileGpuAsmOrGetCached(
     int device_ordinal, const char* ptx, GpuAsmOpts compilation_options);
 
+struct CubinOrPTXImage {
+  std::string profile;
+  std::vector<uint8> bytes;
+};
+
+// Bundles the GPU machine code (cubins) and PTX if requested and returns the
+// resulting binary (i.e. a fatbin) as a byte array.
+port::StatusOr<std::vector<uint8>> BundleGpuAsm(
+    std::vector<CubinOrPTXImage> images, const std::string preferred_cuda_dir);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_