Crash fix cudaMallocAsync usage of TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
index acc40fb..6d26a76 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.cc
@@ -100,7 +100,8 @@
 GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
     PlatformDeviceId platform_device_id, size_t pool_size, bool reserve_memory,
     bool compute_stats)
-    : name_(absl::StrCat("gpu_async_", platform_device_id.value())) {
+    : name_(absl::StrCat("gpu_async_", platform_device_id.value())),
+      reserve_memory_(reserve_memory) {
   ++number_instantiated_;
 #if TF_CUDA_MALLOC_ASYNC_SUPPORTED
   stream_exec_ = DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(),
@@ -255,25 +256,6 @@
   all_ids_->push_back(platform_device_id);
 
   VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator PoolSize " << pool_size;
-  int64 prealloc_size = 0;
-  // TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=-1 is a special value that
-  // preallocates the total pool size.
-  TF_CHECK_OK(ReadInt64FromEnvVar("TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC", 0,
-                                  &prealloc_size));
-  if (prealloc_size == -1) {
-    prealloc_size = pool_size;
-  } else if (reserve_memory) {
-    prealloc_size = pool_size;
-  }
-
-  if (prealloc_size != 0) {
-    void* ptr = AllocateRaw(0, prealloc_size);
-    DeallocateRaw(ptr);
-    VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator reserved the pool for "
-            << prealloc_size << " bytes"
-            << ". First ptr: " << ptr;
-    ClearStats();
-  }
 #else   // TF_CUDA_MALLOC_ASYNC_SUPPORTED
   LOG(FATAL) << "GpuCudaMallocAsyncAllocator requires CUDA 11.2+";  // Crash OK.
 #endif  // TF_CUDA_MALLOC_ASYNC_SUPPORTED
@@ -397,4 +379,36 @@
   return true;
 }
 
+void GpuCudaMallocAsyncAllocator::SetStream(void* stream) {
+#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
+  uint64_t pool_size_64 = 0;
+  if (auto status = cuMemPoolGetAttribute(
+          pool_, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64)) {
+    LOG(FATAL) <<  // Crash OK.
+        "Failed to get CUDA pool attribute: " << GetCudaErrorMessage(status);
+
+  }
+  cuda_stream_ = *(reinterpret_cast<CUstream*>(stream));
+  int64 prealloc_size = 0;
+  // TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=-1 is a special value that
+  // preallocates the total pool size.
+  TF_CHECK_OK(ReadInt64FromEnvVar("TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC", 0,
+                                  &prealloc_size));
+  if (prealloc_size == -1) {
+    prealloc_size = pool_size_64;
+  } else if (reserve_memory_) {
+    prealloc_size = pool_size_64;
+  }
+
+  if (prealloc_size != 0) {
+    void* ptr = AllocateRaw(0, prealloc_size);
+    DeallocateRaw(ptr);
+    VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator reserved the pool for "
+            << prealloc_size << " bytes"
+            << ". First ptr: " << ptr;
+    ClearStats();
+  }
+#endif
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
index af61ad5..1d6ef9d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h
@@ -83,11 +83,7 @@
 
   bool ClearStats() override;
 
-  void SetStream(void* stream) override {
-#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
-    cuda_stream_ = *(static_cast<CUstream*>(stream));
-#endif
-  }
+  void SetStream(void* stream) override;
 
   // With the right VLOG set, it prints:
   // - the number of ptr currently allocated per size (histogram).
@@ -120,6 +116,8 @@
 
   string name_;
 
+  bool reserve_memory_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(GpuCudaMallocAsyncAllocator);
 
   // Stats.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 560d211..c2a8318 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -158,6 +159,39 @@
   EXPECT_EQ(status.code(), error::OK);
 }
 
+TEST_F(GPUDeviceTest, CudaMallocAsyncPreallocate) {
+  SessionOptions opts = MakeSessionOptions("0", 0, 1, {}, {},
+                                           /*use_cuda_malloc_async=*/true);
+  setenv("TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC", "2048", 1);
+  std::vector<std::unique_ptr<Device>> devices;
+  Status status;
+
+  int number_instantiated = GpuCudaMallocAsyncAllocator::GetInstantiatedCountTestOnly();
+  { // The new scope is to trigger the destruction of the object.
+    status = DeviceFactory::GetFactory("GPU")->CreateDevices(
+        opts, kDeviceNamePrefix, &devices);
+    EXPECT_EQ(devices.size(), 1);
+    Device* device = devices[0].get();
+    auto* device_info = device->tensorflow_gpu_device_info();
+    CHECK(device_info);
+    DeviceContext* device_context = device_info->default_context;
+
+    AllocatorAttributes allocator_attributes = AllocatorAttributes();
+    allocator_attributes.set_gpu_compatible(true);
+    Allocator* allocator = devices[0]->GetAllocator(allocator_attributes);
+    void* ptr = allocator->AllocateRaw(Allocator::kAllocatorAlignment,
+                                       1024);
+    EXPECT_NE(ptr, nullptr);
+    allocator->DeallocateRaw(ptr);
+  }
+
+  unsetenv("TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC");
+
+  EXPECT_EQ(number_instantiated + 1 , GpuCudaMallocAsyncAllocator::GetInstantiatedCountTestOnly());
+
+  EXPECT_EQ(status.code(), error::OK);
+}
+
 TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
   SessionOptions opts = MakeSessionOptions("0,abc");
   std::vector<std::unique_ptr<Device>> devices;