Add function to materialize COW storages (#113396)

Part of #109833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113396
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp
index 459960b..e2518dc 100644
--- a/aten/src/ATen/EmptyTensor.cpp
+++ b/aten/src/ATen/EmptyTensor.cpp
@@ -311,6 +311,7 @@
   DeleterFnPtr raw_deleter() const override {
     return deleter;
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {}
 };
 
 static MetaAllocator g_meta_alloc;
diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp
index 22dbb66..3c63ed7 100644
--- a/aten/src/ATen/cuda/CachingHostAllocator.cpp
+++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp
@@ -301,6 +301,10 @@
     }
   }
 
+  void copy_data(void* dest, const void* src, std::size_t count) const {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for CUDAHostAllocator");
+  }
+
  private:
   void process_events() {
     while (true) {
@@ -496,6 +500,10 @@
         &CUDAHostAllocatorDeleter,
         at::DeviceType::CPU};
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    getCUDAHostAllocator().copy_data(dest, src, count);
+  }
 };
 
 static CUDAHostAllocatorWrapper cuda_host_allocator;
diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
index c5a607c..e4c0cec 100644
--- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
+++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
@@ -23,6 +23,9 @@
   DeleterFnPtr raw_deleter() const override {
     return allocator_->raw_deleter();
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    allocator_->copy_data(dest, src, count);
+  }
 };
 
 }} // namespace c10::hip
diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm
index fba4c82..32a9b31 100644
--- a/aten/src/ATen/mps/MPSAllocator.mm
+++ b/aten/src/ATen/mps/MPSAllocator.mm
@@ -819,6 +819,10 @@
     return _getAllocImpl().format_size(size);
   }
 
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
+
  private:
   bool m_has_unified_memory;
   uint32_t m_usage;
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index 0369e52..eeef680 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -94,13 +94,14 @@
   if (size_bytes != 0) {
     new_data = storage->allocator()->allocate(size_bytes);
   }
-  at::DataPtr old_data = storage->set_data_ptr(std::move(new_data));
+  const at::DataPtr& old_data = storage->data_ptr();
   const auto old_capacity = storage->nbytes();
-  storage->set_nbytes(size_bytes);
   const auto copy_capacity = std::min(size_bytes, old_capacity);
   if (old_data != nullptr && copy_capacity > 0) {
-    memcpy(storage->mutable_data(), old_data.get(), copy_capacity);
+    memcpy(new_data.get(), old_data.get(), copy_capacity);
   }
+  storage->set_data_ptr_noswap(std::move(new_data));
+  storage->set_nbytes(size_bytes);
 }
 
 // Call the sparse implementation in SparseTensor.cpp directly.
diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h
index fbb3a11..84de45a 100644
--- a/aten/src/ATen/native/TensorFactories.h
+++ b/aten/src/ATen/native/TensorFactories.h
@@ -129,6 +129,7 @@
   DeleterFnPtr raw_deleter() const override {
     return deleter;
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {}
   at::Device device_;
 };
 
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index aa97677..d81ba78 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -13,6 +13,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_allocator_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cpu_rng_test.cpp
@@ -54,6 +55,7 @@
   )
 
 list(APPEND ATen_CUDA_TEST_SRCS
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_allocator_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_atomic_ops_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_caching_host_allocator_test.cpp
diff --git a/aten/src/ATen/test/allocator_clone_test.h b/aten/src/ATen/test/allocator_clone_test.h
new file mode 100644
index 0000000..79a1f5f
--- /dev/null
+++ b/aten/src/ATen/test/allocator_clone_test.h
@@ -0,0 +1,35 @@
+#pragma once
+#include <gtest/gtest.h>
+#include <ATen/ATen.h>
+
+void test_allocator_clone(c10::Allocator* allocator) {
+  ASSERT_TRUE(allocator != nullptr);
+
+  c10::Storage a_storage(c10::make_intrusive<c10::StorageImpl>(
+    c10::StorageImpl::use_byte_size_t(),
+    0,
+    allocator,
+    /*resizable=*/true));
+
+  c10::Storage b_storage(c10::make_intrusive<c10::StorageImpl>(
+    c10::StorageImpl::use_byte_size_t(),
+    0,
+    allocator,
+    /*resizable=*/true));
+
+  at::Tensor a = at::empty({0}, at::TensorOptions().device(a_storage.device())).set_(a_storage);
+  at::Tensor b = at::empty({0}, at::TensorOptions().device(b_storage.device())).set_(b_storage);
+
+  std::vector<int64_t> sizes({13, 4, 5});
+
+  at::rand_out(a, sizes);
+  at::rand_out(b, sizes);
+
+  ASSERT_TRUE(a_storage.nbytes() == static_cast<size_t>(a.numel() * a.element_size()));
+  ASSERT_TRUE(a_storage.nbytes() == b_storage.nbytes());
+
+  void* a_data_ptr = a_storage.mutable_data();
+  b_storage.set_data_ptr(allocator->clone(a_data_ptr, a_storage.nbytes()));
+
+  ASSERT_TRUE((a == b).all().item<bool>());
+}
diff --git a/aten/src/ATen/test/cpu_allocator_test.cpp b/aten/src/ATen/test/cpu_allocator_test.cpp
new file mode 100644
index 0000000..db98522
--- /dev/null
+++ b/aten/src/ATen/test/cpu_allocator_test.cpp
@@ -0,0 +1,10 @@
+#include <gtest/gtest.h>
+
+#include <c10/core/CPUAllocator.h>
+#include <ATen/ATen.h>
+
+#include <ATen/test/allocator_clone_test.h>
+
+TEST(AllocatorTestCPU, test_clone) {
+  test_allocator_clone(c10::GetDefaultCPUAllocator());
+}
diff --git a/aten/src/ATen/test/cuda_allocator_test.cpp b/aten/src/ATen/test/cuda_allocator_test.cpp
new file mode 100644
index 0000000..27a352e
--- /dev/null
+++ b/aten/src/ATen/test/cuda_allocator_test.cpp
@@ -0,0 +1,10 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+
+#include <ATen/test/allocator_clone_test.h>
+
+TEST(AllocatorTestCUDA, test_clone) {
+  test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
+}
diff --git a/aten/src/ATen/test/xla_tensor_test.cpp b/aten/src/ATen/test/xla_tensor_test.cpp
index 21f911e..1c9e392 100644
--- a/aten/src/ATen/test/xla_tensor_test.cpp
+++ b/aten/src/ATen/test/xla_tensor_test.cpp
@@ -2,6 +2,8 @@
 
 #include <ATen/ATen.h>
 
+#include <ATen/test/allocator_clone_test.h>
+
 using namespace at;
 
 void XLAFree(void *ptr) {
@@ -22,6 +24,9 @@
   at::DeleterFnPtr raw_deleter() const override {
     return &XLAFree;
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
 };
 
 TEST(XlaTensorTest, TestNoStorage) {
@@ -33,3 +38,11 @@
   at::Tensor t(std::move(tensor_impl));
   ASSERT_TRUE(t.device() == at::Device(DeviceType::XLA, 0));
 }
+
+TEST(XlaTensorTest, test_allocator_clone) {
+  if (!at::hasXLA()) {
+    return;
+  }
+  XLAAllocator allocator;
+  test_allocator_clone(&allocator);
+}
diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp
index dada5bb..c95fdec 100644
--- a/c10/core/Allocator.cpp
+++ b/c10/core/Allocator.cpp
@@ -4,6 +4,23 @@
 
 namespace c10 {
 
+DataPtr Allocator::clone(const void* data, std::size_t n) const {
+  DataPtr new_data = allocate(n);
+  copy_data(new_data.mutable_get(), data, n);
+  return new_data;
+}
+
+void Allocator::default_copy_data(
+    void* dest,
+    const void* src,
+    std::size_t count) const {
+  std::memcpy(dest, src, count);
+}
+
+bool Allocator::is_simple_data_ptr(const DataPtr& data_ptr) const {
+  return data_ptr.get() == data_ptr.get_context();
+}
+
 static void deleteInefficientStdFunctionContext(void* ptr) {
   delete static_cast<InefficientStdFunctionContext*>(ptr);
 }
diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h
index c7c4c3a..2407174 100644
--- a/c10/core/Allocator.h
+++ b/c10/core/Allocator.h
@@ -156,6 +156,21 @@
 
   virtual DataPtr allocate(size_t n) const = 0;
 
+  // Clones an allocation that came from this allocator.
+  //
+  // To perform the copy, this function calls `copy_data`, which
+  // must be implemented by derived classes.
+  //
+  // Note that this explicitly ignores any context that may have been
+  // attached to the input data.
+  //
+  // Requires: input data was allocated by the same allocator.
+  DataPtr clone(const void* data, std::size_t n) const;
+
+  // Checks if DataPtr has a simple context, not wrapped with any out of the
+  // ordinary contexts.
+  virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const;
+
   // If this returns a non nullptr, it means that allocate()
   // is guaranteed to return a unique_ptr with this deleter attached;
   // it means the rawAllocate and rawDeallocate APIs are safe to use.
@@ -173,6 +188,22 @@
     AT_ASSERT(d);
     d(ptr);
   }
+
+  // Copies data from one allocation to another.
+  // Pure virtual, so derived classes must define behavior.
+  // Derived class implementation can simply call `default_copy_data`
+  // to use `std::memcpy`.
+  //
+  // Requires: src and dest were allocated by this allocator
+  // Requires: src and dest both have length >= count
+  virtual void copy_data(void* dest, const void* src, std::size_t count)
+      const = 0;
+
+ protected:
+  // Uses `std::memcpy` to copy data.
+  // Child classes can use this as `copy_data` when an alternative copy
+  // API is not needed.
+  void default_copy_data(void* dest, const void* src, std::size_t count) const;
 };
 
 // This context is used to generate DataPtr which have arbitrary
diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp
index c103c42..0475904 100644
--- a/c10/core/CPUAllocator.cpp
+++ b/c10/core/CPUAllocator.cpp
@@ -40,6 +40,10 @@
   at::DeleterFnPtr raw_deleter() const override {
     return &ReportAndDelete;
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
 };
 
 ProfiledCPUMemoryReporter& profiledCPUMemoryReporter() {
@@ -142,6 +146,16 @@
   DeleterFnPtr raw_deleter() const override {
     return deleter;
   }
+
+  bool is_simple_data_ptr(const c10::DataPtr& data_ptr) const final {
+    return reinterpret_cast<const uint8_t*>(data_ptr.get()) ==
+        reinterpret_cast<const uint8_t*>(data_ptr.get_context()) +
+        PreGuardBytes;
+  }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
 };
 
 void NoDelete(void*) {}
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index 082b5df..739e3f3 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -3,6 +3,8 @@
 #include <c10/core/Allocator.h>
 #include <c10/core/SymInt.h>
 #include <c10/core/impl/PyObjectSlot.h>
+#include <c10/core/impl/cow/COW.h>
+#include <c10/core/impl/cow/COWDeleter.h>
 
 #include <c10/util/intrusive_ptr.h>
 
@@ -111,6 +113,7 @@
   }
 
   at::DataPtr& mutable_data_ptr() {
+    maybe_materialize_cow();
     return data_ptr_;
   }
 
@@ -120,9 +123,10 @@
 
   // Returns the previous data_ptr
   at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
-    at::DataPtr old_data_ptr(std::move(data_ptr_));
-    data_ptr_ = std::move(data_ptr);
-    return old_data_ptr;
+    // We need to materialize the old COW DataPtr because it is
+    // being returned as mutable.
+    maybe_materialize_cow();
+    return set_data_ptr_no_materialize_cow(std::move(data_ptr));
   }
 
   void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
@@ -134,6 +138,7 @@
   }
 
   void* mutable_data() {
+    maybe_materialize_cow();
     return data_ptr_.mutable_get();
   }
 
@@ -211,7 +216,26 @@
     return &pyobj_slot_;
   }
 
+ protected:
+  // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
+  friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
+
+  // Returns the previous data_ptr. If the old data_ptr was COW,
+  // this avoids materializing it
+  at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
+    at::DataPtr old_data_ptr(std::move(data_ptr_));
+    data_ptr_ = std::move(data_ptr);
+    return old_data_ptr;
+  }
+
  private:
+  // Triggers a copy if this is a copy-on-write tensor.
+  void maybe_materialize_cow() {
+    if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
+      impl::cow::materialize_cow_storage(*this);
+    }
+  }
+
   DataPtr data_ptr_;
   SymInt size_bytes_;
   bool size_bytes_is_heap_allocated_;
diff --git a/c10/core/build.bzl b/c10/core/build.bzl
index 781c740..f395b9b 100644
--- a/c10/core/build.bzl
+++ b/c10/core/build.bzl
@@ -58,22 +58,22 @@
             [
                 "*.cpp",
                 "impl/*.cpp",
+                "impl/cow/*.cpp",
             ],
             exclude = [
                 "CPUAllocator.cpp",
                 "impl/alloc_cpu.cpp",
-                "impl/cow/*.cpp",
             ],
         ),
         hdrs = rules.glob(
             [
                 "*.h",
                 "impl/*.h",
+                "impl/cow/*.h",
             ],
             exclude = [
                 "CPUAllocator.h",
                 "impl/alloc_cpu.h",
-                "impl/cow/*.h",
             ],
         ),
         linkstatic = True,
@@ -92,22 +92,6 @@
         alwayslink = True,
     )
 
-    rules.cc_library(
-        name = "impl_cow",
-        srcs = rules.glob([
-            "impl/cow/*.cpp",
-        ]),
-        hdrs = rules.glob([
-            "impl/cow/*.h",
-        ]),
-        deps = [
-            ":base",
-            ":CPUAllocator",
-        ],
-        visibility = ["//c10/test:__pkg__"],
-
-    )
-
     rules.filegroup(
         name = "headers",
         srcs = rules.glob(
diff --git a/c10/core/impl/cow/COW.cpp b/c10/core/impl/cow/COW.cpp
index f32e0ea..e2ed465 100644
--- a/c10/core/impl/cow/COW.cpp
+++ b/c10/core/impl/cow/COW.cpp
@@ -1,7 +1,6 @@
 #include <c10/core/impl/cow/COW.h>
 
 #include <c10/core/Allocator.h>
-#include <c10/core/CPUAllocator.h>
 #include <c10/core/StorageImpl.h>
 #include <c10/core/alignment.h>
 #include <c10/core/impl/cow/COWDeleter.h>
@@ -30,24 +29,18 @@
   return make_data_ptr(data_ptr, *ctx);
 }
 
-bool is_simple_context(
-    const void* context,
-    const void* data,
-    const at::Allocator* allocator) {
-  if (allocator == c10::GetDefaultMobileCPUAllocator()) {
-    return reinterpret_cast<size_t>(data) ==
-        reinterpret_cast<size_t>(context) + c10::gAlignment;
-  } else {
-    return data == context;
-  }
-}
-
 } // namespace
 
 bool has_simple_data_ptr(const c10::StorageImpl& storage) {
   const c10::DataPtr& data_ptr = storage.data_ptr();
-  return is_simple_context(
-      data_ptr.get_context(), data_ptr.get(), storage.allocator());
+  const void* ctx = data_ptr.get_context();
+  const void* data = data_ptr.get();
+  const c10::Allocator* allocator = storage.allocator();
+  if (allocator != nullptr) {
+    return allocator->is_simple_data_ptr(data_ptr);
+  } else {
+    return ctx == data;
+  }
 }
 
 bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
@@ -88,8 +81,6 @@
     // Case 1) We have a simple data pointer: wrap it.
     std::unique_ptr<void, DeleterFnPtr> original_ctx =
         storage.mutable_data_ptr().move_context();
-    TORCH_INTERNAL_ASSERT(is_simple_context(
-        original_ctx.get(), data_ptr.get(), storage.allocator()));
 
     // Save this for the result.
     new_data_ptr = make_data_ptr(
@@ -117,4 +108,40 @@
       storage.resizable());
 }
 
+C10_API void materialize_cow_storage(StorageImpl& storage) {
+  const at::DataPtr& data_ptr = storage.data_ptr();
+
+  auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
+  TORCH_INTERNAL_ASSERT(ctx != nullptr);
+
+  auto result = ctx->decrement_refcount();
+
+  // This must be set by each branch below.
+  std::optional<DataPtr> new_data_ptr;
+
+  if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
+    // This is the only reference to the data. If there were any racing writes,
+    // the context ensured they finished before giving us the result.
+    std::unique_ptr<void, DeleterFnPtr> data =
+        std::get<cow::COWDeleterContext::LastReference>(std::move(result));
+    TORCH_INTERNAL_ASSERT(data.get() == data_ptr.get());
+    new_data_ptr = DataPtr(
+        data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
+            result));
+    // We don't need to consume the result, it's just a shared lock ensuring
+    // that the data will remain while we copy it.
+    new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
+  }
+
+  TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
+  DataPtr old_data_ptr =
+      storage.set_data_ptr_no_materialize_cow(*std::move(new_data_ptr));
+  // The refcount of the context was already decremented above. Release the
+  // reference to the context so the refcount doesn't get decremented again
+  old_data_ptr.release_context();
+}
+
 } // namespace c10::impl::cow
diff --git a/c10/core/impl/cow/COW.h b/c10/core/impl/cow/COW.h
index 07ef5e4..1cf81ed 100644
--- a/c10/core/impl/cow/COW.h
+++ b/c10/core/impl/cow/COW.h
@@ -26,4 +26,7 @@
 // Check if a DataPtr is COW
 C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr);
 
+// Eagerly copies a COW storage's data, turning it into a non-COW storage.
+C10_API void materialize_cow_storage(StorageImpl& storage);
+
 } // namespace c10::impl::cow
diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 5f30127..eab2fb4 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -3235,6 +3235,10 @@
   std::string name() override {
     return "native";
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    C10_CUDA_CHECK(
+        cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+  }
 };
 
 NativeCachingAllocator allocator;
diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp
index 222ec45..4982fc6 100644
--- a/c10/cuda/CUDAMallocAsyncAllocator.cpp
+++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp
@@ -875,6 +875,10 @@
   std::string name() override {
     return "cudaMallocAsync";
   }
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    C10_CUDA_CHECK(
+        cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+  }
 };
 
 CudaMallocAsyncAllocator device_allocator;
diff --git a/c10/test/build.bzl b/c10/test/build.bzl
index d9b5897..2f54c8a 100644
--- a/c10/test/build.bzl
+++ b/c10/test/build.bzl
@@ -21,7 +21,6 @@
             "//c10/core:base",
             "//c10/util:base",
             "//c10/core:CPUAllocator",
-            "//c10/core:impl_cow",
             "@com_google_googletest//:gtest_main",
         ],
     )
diff --git a/c10/test/core/impl/cow_test.cpp b/c10/test/core/impl/cow_test.cpp
index 5fd30f5..a87bf4d 100644
--- a/c10/test/core/impl/cow_test.cpp
+++ b/c10/test/core/impl/cow_test.cpp
@@ -167,6 +167,72 @@
   ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
 }
 
+TEST(materialize_test, not_copy_on_write_context) {
+  StorageImpl storage(
+      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
+  ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
+
+  void const* original_data = storage.data();
+
+  // Nothing to materialize.
+  ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
+}
+
+TEST(materialize_test, copy_on_write_single_reference) {
+  // A copy-on-write storage with only a single reference can just
+  // drop the copy-on-write context upon materialization.
+  std::unique_ptr<void, DeleterFnPtr> data(
+      new std::byte[4],
+      +[](void* bytes) { delete[] static_cast<std::byte*>(bytes); });
+  void* data_ptr = data.get();
+  StorageImpl storage(
+      {},
+      /*size_bytes=*/4,
+      at::DataPtr(
+          /*data=*/data_ptr,
+          /*ctx=*/new cow::COWDeleterContext(std::move(data)),
+          cow::cow_deleter,
+          Device(Device::Type::CPU)),
+      /*allocator=*/nullptr,
+      /*resizable=*/false);
+
+  ASSERT_THAT(storage, is_copy_on_write());
+
+  ASSERT_THAT(storage.data(), testing::Eq(data_ptr));
+
+  void const* original_data = storage.data();
+
+  // Materializes storage. Only reference, so no new allocation.
+  ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
+  // But it is no longer copy-on-write.
+  ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
+}
+
+TEST(materialize_test, copy_on_write) {
+  StorageImpl original_storage(
+      {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
+  std::memcpy(original_storage.mutable_data(), "abcd", 4);
+  void const* original_data = original_storage.data();
+
+  auto new_storage = cow::lazy_clone_storage(original_storage);
+  ASSERT_THAT(new_storage, testing::NotNull());
+
+  auto context = new_storage->data_ptr().cast_context<cow::COWDeleterContext>(
+      cow::cow_deleter);
+  ASSERT_THAT(context, testing::NotNull());
+
+  // Materialized storage has new copy of data.
+  ASSERT_THAT(new_storage->mutable_data(), testing::Ne(original_data));
+
+  // But the original storage still has the original copy.
+  ASSERT_THAT(original_storage.data(), testing::Eq(original_data));
+
+  // But their data is the same.
+  ASSERT_THAT(
+      static_cast<char const*>(new_storage->data()),
+      testing::StrEq(static_cast<char const*>(original_storage.data())));
+}
+
 } // namespace
 } // namespace c10::impl
 // NOLINTEND(clang-analyzer-cplusplus*)
diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu
index 3359e88..6555b97 100644
--- a/caffe2/core/context_gpu.cu
+++ b/caffe2/core/context_gpu.cu
@@ -336,6 +336,10 @@
     return &Delete;
   }
 
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for PinnedCPUAllocator");
+  }
+
  private:
   static void Delete(void* data) {
     if (!data) {
@@ -581,6 +585,10 @@
     return &Delete;
   }
 
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for DefaultCUDAAllocator");
+  }
+
  private:
   static void Delete(void* ptr) {
     // lock the mutex
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp
index 5818d64..cfe49e0 100644
--- a/test/cpp_extensions/open_registration_extension.cpp
+++ b/test/cpp_extensions/open_registration_extension.cpp
@@ -188,6 +188,10 @@
   at::DeleterFnPtr raw_deleter() const override {
     return &ReportAndDelete;
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
 };
 
 // Register our dummy allocator
diff --git a/test/inductor/extension_backends/extension_device.cpp b/test/inductor/extension_backends/extension_device.cpp
index fa2922a..2e86cde 100644
--- a/test/inductor/extension_backends/extension_device.cpp
+++ b/test/inductor/extension_backends/extension_device.cpp
@@ -81,6 +81,10 @@
   at::DeleterFnPtr raw_deleter() const override {
     return &ReportAndDelete;
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const final {
+    default_copy_data(dest, src, count);
+  }
 };
 
 // Register our dummy allocator
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
index b5bed3c..b304c5e 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
@@ -330,6 +330,14 @@
   return "pluggable";
 }
 
+void CUDAPluggableAllocator::copy_data(
+    void* dest,
+    const void* src,
+    std::size_t count) const {
+  C10_CUDA_CHECK(
+      cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+}
+
 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
     current_custom_allocator;
 
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h
index 6a0f3e0..f2c2e6e 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.h
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.h
@@ -120,6 +120,7 @@
       cudaStream_t stream,
       bool p2p_enabled) override;
   std::string name() override;
+  void copy_data(void* dest, const void* src, std::size_t count) const final;
 
  protected:
   std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;