Add function to materialize COW storages (#113396)
Part of #109833
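
This adds c10::impl::cow::materialize_cow_storage, which eagerly copies a
COW storage's data and turns it back into an ordinary storage. StorageImpl
now materializes automatically on any mutable access (mutable_data,
mutable_data_ptr, set_data_ptr). To support the copy generically,
c10::Allocator gains a pure virtual copy_data (plus a protected
default_copy_data helper that uses std::memcpy), a clone method built on
top of it, and a virtual is_simple_data_ptr that replaces the
allocator-specific check previously hard-coded in COW.cpp.

A rough sketch of the storage-level flow this enables, modeled on the new
cow_test cases (the sketch itself is not part of this diff):

    c10::StorageImpl storage(
        {}, /*size_bytes=*/6, c10::GetCPUAllocator(), /*resizable=*/false);
    std::memcpy(storage.mutable_data(), "abcde", 6);

    // Lazily clone: both storages now share the same bytes behind a
    // COWDeleterContext.
    c10::intrusive_ptr<c10::StorageImpl> copy =
        c10::impl::cow::lazy_clone_storage(storage);

    // Const access does not copy...
    const void* shared = copy->data();

    // ...but mutable access materializes the storage: if this is the last
    // reference, the COW context is simply dropped; otherwise the
    // allocator's clone()/copy_data() makes a fresh copy of the bytes.
    void* owned = copy->mutable_data();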
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113396
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp
index 459960b..e2518dc 100644
--- a/aten/src/ATen/EmptyTensor.cpp
+++ b/aten/src/ATen/EmptyTensor.cpp
@@ -311,6 +311,7 @@
DeleterFnPtr raw_deleter() const override {
return deleter;
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {}
};
static MetaAllocator g_meta_alloc;
diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp
index 22dbb66..3c63ed7 100644
--- a/aten/src/ATen/cuda/CachingHostAllocator.cpp
+++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp
@@ -301,6 +301,10 @@
}
}
+ void copy_data(void* dest, const void* src, std::size_t count) const {
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for CUDAHostAllocator");
+ }
+
private:
void process_events() {
while (true) {
@@ -496,6 +500,10 @@
&CUDAHostAllocatorDeleter,
at::DeviceType::CPU};
}
+
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ getCUDAHostAllocator().copy_data(dest, src, count);
+ }
};
static CUDAHostAllocatorWrapper cuda_host_allocator;
diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
index c5a607c..e4c0cec 100644
--- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
+++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
@@ -23,6 +23,9 @@
DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ allocator_->copy_data(dest, src, count);
+ }
};
}} // namespace c10::hip
diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm
index fba4c82..32a9b31 100644
--- a/aten/src/ATen/mps/MPSAllocator.mm
+++ b/aten/src/ATen/mps/MPSAllocator.mm
@@ -819,6 +819,10 @@
return _getAllocImpl().format_size(size);
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
+
private:
bool m_has_unified_memory;
uint32_t m_usage;
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index 0369e52..eeef680 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -94,13 +94,14 @@
if (size_bytes != 0) {
new_data = storage->allocator()->allocate(size_bytes);
}
- at::DataPtr old_data = storage->set_data_ptr(std::move(new_data));
+ const at::DataPtr& old_data = storage->data_ptr();
const auto old_capacity = storage->nbytes();
- storage->set_nbytes(size_bytes);
const auto copy_capacity = std::min(size_bytes, old_capacity);
if (old_data != nullptr && copy_capacity > 0) {
- memcpy(storage->mutable_data(), old_data.get(), copy_capacity);
+ memcpy(new_data.mutable_get(), old_data.get(), copy_capacity);
}
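+ // Swap the new allocation in without set_data_ptr(), which would first
+ // materialize a COW storage whose old bytes we only needed to read.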
+ storage->set_data_ptr_noswap(std::move(new_data));
+ storage->set_nbytes(size_bytes);
}
// Call the sparse implementation in SparseTensor.cpp directly.
diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h
index fbb3a11..84de45a 100644
--- a/aten/src/ATen/native/TensorFactories.h
+++ b/aten/src/ATen/native/TensorFactories.h
@@ -129,6 +129,7 @@
DeleterFnPtr raw_deleter() const override {
return deleter;
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {}
at::Device device_;
};
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index aa97677..d81ba78 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -13,6 +13,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/atest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/basic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/cpu_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_rng_test.cpp
@@ -54,6 +55,7 @@
)
list(APPEND ATen_CUDA_TEST_SRCS
+ ${CMAKE_CURRENT_SOURCE_DIR}/cuda_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_atomic_ops_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_caching_host_allocator_test.cpp
diff --git a/aten/src/ATen/test/allocator_clone_test.h b/aten/src/ATen/test/allocator_clone_test.h
new file mode 100644
index 0000000..79a1f5f
--- /dev/null
+++ b/aten/src/ATen/test/allocator_clone_test.h
@@ -0,0 +1,35 @@
+#pragma once
+#include <gtest/gtest.h>
+#include <ATen/ATen.h>
+
+void test_allocator_clone(c10::Allocator* allocator) {
+ ASSERT_TRUE(allocator != nullptr);
+
+ c10::Storage a_storage(c10::make_intrusive<c10::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ 0,
+ allocator,
+ /*resizable=*/true));
+
+ c10::Storage b_storage(c10::make_intrusive<c10::StorageImpl>(
+ c10::StorageImpl::use_byte_size_t(),
+ 0,
+ allocator,
+ /*resizable=*/true));
+
+ at::Tensor a = at::empty({0}, at::TensorOptions().device(a_storage.device())).set_(a_storage);
+ at::Tensor b = at::empty({0}, at::TensorOptions().device(b_storage.device())).set_(b_storage);
+
+ std::vector<int64_t> sizes({13, 4, 5});
+
+ at::rand_out(a, sizes);
+ at::rand_out(b, sizes);
+
+ ASSERT_TRUE(a_storage.nbytes() == static_cast<size_t>(a.numel() * a.element_size()));
+ ASSERT_TRUE(a_storage.nbytes() == b_storage.nbytes());
+
+ void* a_data_ptr = a_storage.mutable_data();
+ b_storage.set_data_ptr(allocator->clone(a_data_ptr, a_storage.nbytes()));
+
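+ // Cloning a's bytes into b's storage should make the tensors equal.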
+ ASSERT_TRUE((a == b).all().item<bool>());
+}
diff --git a/aten/src/ATen/test/cpu_allocator_test.cpp b/aten/src/ATen/test/cpu_allocator_test.cpp
new file mode 100644
index 0000000..db98522
--- /dev/null
+++ b/aten/src/ATen/test/cpu_allocator_test.cpp
@@ -0,0 +1,10 @@
+#include <gtest/gtest.h>
+
+#include <c10/core/CPUAllocator.h>
+#include <ATen/ATen.h>
+
+#include <ATen/test/allocator_clone_test.h>
+
+TEST(AllocatorTestCPU, test_clone) {
+ test_allocator_clone(c10::GetDefaultCPUAllocator());
+}
diff --git a/aten/src/ATen/test/cuda_allocator_test.cpp b/aten/src/ATen/test/cuda_allocator_test.cpp
new file mode 100644
index 0000000..27a352e
--- /dev/null
+++ b/aten/src/ATen/test/cuda_allocator_test.cpp
@@ -0,0 +1,10 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+
+#include <ATen/test/allocator_clone_test.h>
+
+TEST(AllocatorTestCUDA, test_clone) {
+ test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
+}
diff --git a/aten/src/ATen/test/xla_tensor_test.cpp b/aten/src/ATen/test/xla_tensor_test.cpp
index 21f911e..1c9e392 100644
--- a/aten/src/ATen/test/xla_tensor_test.cpp
+++ b/aten/src/ATen/test/xla_tensor_test.cpp
@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
+#include <ATen/test/allocator_clone_test.h>
+
using namespace at;
void XLAFree(void *ptr) {
@@ -22,6 +24,9 @@
at::DeleterFnPtr raw_deleter() const override {
return &XLAFree;
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
};
TEST(XlaTensorTest, TestNoStorage) {
@@ -33,3 +38,11 @@
at::Tensor t(std::move(tensor_impl));
ASSERT_TRUE(t.device() == at::Device(DeviceType::XLA, 0));
}
+
+TEST(XlaTensorTest, test_allocator_clone) {
+ if (!at::hasXLA()) {
+ return;
+ }
+ XLAAllocator allocator;
+ test_allocator_clone(&allocator);
+}
diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp
index dada5bb..c95fdec 100644
--- a/c10/core/Allocator.cpp
+++ b/c10/core/Allocator.cpp
@@ -4,6 +4,23 @@
namespace c10 {
+DataPtr Allocator::clone(const void* data, std::size_t n) const {
+ DataPtr new_data = allocate(n);
+ copy_data(new_data.mutable_get(), data, n);
+ return new_data;
+}
+
+void Allocator::default_copy_data(
+ void* dest,
+ const void* src,
+ std::size_t count) const {
+ std::memcpy(dest, src, count);
+}
+
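+// By default, a DataPtr is "simple" when its context is just the data
+// pointer itself; allocators that attach extra context override this.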
+bool Allocator::is_simple_data_ptr(const DataPtr& data_ptr) const {
+ return data_ptr.get() == data_ptr.get_context();
+}
+
static void deleteInefficientStdFunctionContext(void* ptr) {
delete static_cast<InefficientStdFunctionContext*>(ptr);
}
diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h
index c7c4c3a..2407174 100644
--- a/c10/core/Allocator.h
+++ b/c10/core/Allocator.h
@@ -156,6 +156,21 @@
virtual DataPtr allocate(size_t n) const = 0;
+ // Clones an allocation that came from this allocator.
+ //
+ // To perform the copy, this function calls `copy_data`, which
+ // must be implemented by derived classes.
+ //
+ // Note that this explicitly ignores any context that may have been
+ // attached to the input data.
+ //
+ // Requires: input data was allocated by the same allocator.
+ DataPtr clone(const void* data, std::size_t n) const;
+
+ // Checks whether the DataPtr's context is "simple": the context is just
+ // the data pointer itself, not a wrapper around it.
+ virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const;
+
// If this returns a non nullptr, it means that allocate()
// is guaranteed to return a unique_ptr with this deleter attached;
// it means the rawAllocate and rawDeallocate APIs are safe to use.
@@ -173,6 +188,22 @@
AT_ASSERT(d);
d(ptr);
}
+
+ // Copies data from one allocation to another.
+ // Pure virtual, so derived classes must define behavior.
+ // A derived class implementation can simply call `default_copy_data`
+ // to copy via `std::memcpy`.
+ //
+ // Requires: src and dest were allocated by this allocator
+ // Requires: src and dest both have length >= count
+ virtual void copy_data(void* dest, const void* src, std::size_t count)
+ const = 0;
+
+ protected:
+ // Uses `std::memcpy` to copy data.
+ // Child classes can use this as `copy_data` when an alternative copy
+ // API is not needed.
+ void default_copy_data(void* dest, const void* src, std::size_t count) const;
};
// This context is used to generate DataPtr which have arbitrary
diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp
index c103c42..0475904 100644
--- a/c10/core/CPUAllocator.cpp
+++ b/c10/core/CPUAllocator.cpp
@@ -40,6 +40,10 @@
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
+
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
};
ProfiledCPUMemoryReporter& profiledCPUMemoryReporter() {
@@ -142,6 +146,16 @@
DeleterFnPtr raw_deleter() const override {
return deleter;
}
+
+ bool is_simple_data_ptr(const c10::DataPtr& data_ptr) const final {
+ return reinterpret_cast<const uint8_t*>(data_ptr.get()) ==
+ reinterpret_cast<const uint8_t*>(data_ptr.get_context()) +
+ PreGuardBytes;
+ }
+
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
};
void NoDelete(void*) {}
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index 082b5df..739e3f3 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -3,6 +3,8 @@
#include <c10/core/Allocator.h>
#include <c10/core/SymInt.h>
#include <c10/core/impl/PyObjectSlot.h>
+#include <c10/core/impl/cow/COW.h>
+#include <c10/core/impl/cow/COWDeleter.h>
#include <c10/util/intrusive_ptr.h>
@@ -111,6 +113,7 @@
}
at::DataPtr& mutable_data_ptr() {
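+ // Handing out a mutable DataPtr must first materialize a COW storage.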
+ maybe_materialize_cow();
return data_ptr_;
}
@@ -120,9 +123,10 @@
// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
- at::DataPtr old_data_ptr(std::move(data_ptr_));
- data_ptr_ = std::move(data_ptr);
- return old_data_ptr;
+ // We need to materialize the old COW DataPtr because it is
+ // being returned as mutable.
+ maybe_materialize_cow();
+ return set_data_ptr_no_materialize_cow(std::move(data_ptr));
}
void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
@@ -134,6 +138,7 @@
}
void* mutable_data() {
+ maybe_materialize_cow();
return data_ptr_.mutable_get();
}
@@ -211,7 +216,26 @@
return &pyobj_slot_;
}
+ protected:
+ // materialize_cow_storage needs to call set_data_ptr_no_materialize_cow
+ friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
+
+ // Returns the previous data_ptr. If the old data_ptr was COW,
+ // this avoids materializing it.
+ at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
+ at::DataPtr old_data_ptr(std::move(data_ptr_));
+ data_ptr_ = std::move(data_ptr);
+ return old_data_ptr;
+ }
+
private:
+ // Triggers a copy if this is a copy-on-write storage.
+ void maybe_materialize_cow() {
+ if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
+ impl::cow::materialize_cow_storage(*this);
+ }
+ }
+
DataPtr data_ptr_;
SymInt size_bytes_;
bool size_bytes_is_heap_allocated_;
diff --git a/c10/core/build.bzl b/c10/core/build.bzl
index 781c740..f395b9b 100644
--- a/c10/core/build.bzl
+++ b/c10/core/build.bzl
@@ -58,22 +58,22 @@
[
"*.cpp",
"impl/*.cpp",
+ "impl/cow/*.cpp",
],
exclude = [
"CPUAllocator.cpp",
"impl/alloc_cpu.cpp",
- "impl/cow/*.cpp",
],
),
hdrs = rules.glob(
[
"*.h",
"impl/*.h",
+ "impl/cow/*.h",
],
exclude = [
"CPUAllocator.h",
"impl/alloc_cpu.h",
- "impl/cow/*.h",
],
),
linkstatic = True,
@@ -92,22 +92,6 @@
alwayslink = True,
)
- rules.cc_library(
- name = "impl_cow",
- srcs = rules.glob([
- "impl/cow/*.cpp",
- ]),
- hdrs = rules.glob([
- "impl/cow/*.h",
- ]),
- deps = [
- ":base",
- ":CPUAllocator",
- ],
- visibility = ["//c10/test:__pkg__"],
-
- )
-
rules.filegroup(
name = "headers",
srcs = rules.glob(
diff --git a/c10/core/impl/cow/COW.cpp b/c10/core/impl/cow/COW.cpp
index f32e0ea..e2ed465 100644
--- a/c10/core/impl/cow/COW.cpp
+++ b/c10/core/impl/cow/COW.cpp
@@ -1,7 +1,6 @@
#include <c10/core/impl/cow/COW.h>
#include <c10/core/Allocator.h>
-#include <c10/core/CPUAllocator.h>
#include <c10/core/StorageImpl.h>
#include <c10/core/alignment.h>
#include <c10/core/impl/cow/COWDeleter.h>
@@ -30,24 +29,18 @@
return make_data_ptr(data_ptr, *ctx);
}
-bool is_simple_context(
- const void* context,
- const void* data,
- const at::Allocator* allocator) {
- if (allocator == c10::GetDefaultMobileCPUAllocator()) {
- return reinterpret_cast<size_t>(data) ==
- reinterpret_cast<size_t>(context) + c10::gAlignment;
- } else {
- return data == context;
- }
-}
-
} // namespace
bool has_simple_data_ptr(const c10::StorageImpl& storage) {
const c10::DataPtr& data_ptr = storage.data_ptr();
- return is_simple_context(
- data_ptr.get_context(), data_ptr.get(), storage.allocator());
+ const void* ctx = data_ptr.get_context();
+ const void* data = data_ptr.get();
+ const c10::Allocator* allocator = storage.allocator();
+ if (allocator != nullptr) {
+ return allocator->is_simple_data_ptr(data_ptr);
+ } else {
+ return ctx == data;
+ }
}
bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
@@ -88,8 +81,6 @@
// Case 1) We have a simple data pointer: wrap it.
std::unique_ptr<void, DeleterFnPtr> original_ctx =
storage.mutable_data_ptr().move_context();
- TORCH_INTERNAL_ASSERT(is_simple_context(
- original_ctx.get(), data_ptr.get(), storage.allocator()));
// Save this for the result.
new_data_ptr = make_data_ptr(
@@ -117,4 +108,40 @@
storage.resizable());
}
+C10_API void materialize_cow_storage(StorageImpl& storage) {
+ const at::DataPtr& data_ptr = storage.data_ptr();
+
+ auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
+ TORCH_INTERNAL_ASSERT(ctx != nullptr);
+
+ auto result = ctx->decrement_refcount();
+
+ // This must be set by each branch below.
+ std::optional<DataPtr> new_data_ptr;
+
+ if (std::holds_alternative<cow::COWDeleterContext::LastReference>(result)) {
+ // This is the only reference to the data. If there were any racing writes,
+ // the context ensured they finished before giving us the result.
+ std::unique_ptr<void, DeleterFnPtr> data =
+ std::get<cow::COWDeleterContext::LastReference>(std::move(result));
+ TORCH_INTERNAL_ASSERT(data.get() == data_ptr.get());
+ new_data_ptr = DataPtr(
+ data.release(), data_ptr.get(), data.get_deleter(), data_ptr.device());
+ } else {
+ TORCH_INTERNAL_ASSERT(
+ std::holds_alternative<cow::COWDeleterContext::NotLastReference>(
+ result));
+ // We don't need to consume the result; it's just a shared lock ensuring
+ // that the data stays alive while we copy it.
+ new_data_ptr = storage.allocator()->clone(data_ptr.get(), storage.nbytes());
+ }
+
+ TORCH_INTERNAL_ASSERT(new_data_ptr.has_value());
+ DataPtr old_data_ptr =
+ storage.set_data_ptr_no_materialize_cow(*std::move(new_data_ptr));
+ // The refcount of the context was already decremented above. Release the
+ // reference to the context so the refcount doesn't get decremented again
+ old_data_ptr.release_context();
+}
+
} // namespace c10::impl::cow
diff --git a/c10/core/impl/cow/COW.h b/c10/core/impl/cow/COW.h
index 07ef5e4..1cf81ed 100644
--- a/c10/core/impl/cow/COW.h
+++ b/c10/core/impl/cow/COW.h
@@ -26,4 +26,7 @@
// Check if a DataPtr is COW
C10_API bool is_cow_data_ptr(const c10::DataPtr& data_ptr);
+// Eagerly copies a COW storage's data, turning it into a non-COW storage.
+C10_API void materialize_cow_storage(StorageImpl& storage);
+
} // namespace c10::impl::cow
diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 5f30127..eab2fb4 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -3235,6 +3235,10 @@
std::string name() override {
return "native";
}
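+ // Both pointers are required to come from this allocator, so a plain
+ // device-to-device cudaMemcpy suffices.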
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ C10_CUDA_CHECK(
+ cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+ }
};
NativeCachingAllocator allocator;
diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp
index 222ec45..4982fc6 100644
--- a/c10/cuda/CUDAMallocAsyncAllocator.cpp
+++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp
@@ -875,6 +875,10 @@
std::string name() override {
return "cudaMallocAsync";
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ C10_CUDA_CHECK(
+ cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+ }
};
CudaMallocAsyncAllocator device_allocator;
diff --git a/c10/test/build.bzl b/c10/test/build.bzl
index d9b5897..2f54c8a 100644
--- a/c10/test/build.bzl
+++ b/c10/test/build.bzl
@@ -21,7 +21,6 @@
"//c10/core:base",
"//c10/util:base",
"//c10/core:CPUAllocator",
- "//c10/core:impl_cow",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/c10/test/core/impl/cow_test.cpp b/c10/test/core/impl/cow_test.cpp
index 5fd30f5..a87bf4d 100644
--- a/c10/test/core/impl/cow_test.cpp
+++ b/c10/test/core/impl/cow_test.cpp
@@ -167,6 +167,72 @@
ASSERT_THAT(new_storage->data(), testing::Eq(original_storage.data()));
}
+TEST(materialize_test, not_copy_on_write_context) {
+ StorageImpl storage(
+ {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
+ ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
+
+ void const* original_data = storage.data();
+
+ // Nothing to materialize.
+ ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
+}
+
+TEST(materialize_test, copy_on_write_single_reference) {
+ // A copy-on-write storage with only a single reference can just
+ // drop the copy-on-write context upon materialization.
+ std::unique_ptr<void, DeleterFnPtr> data(
+ new std::byte[4],
+ +[](void* bytes) { delete[] static_cast<std::byte*>(bytes); });
+ void* data_ptr = data.get();
+ StorageImpl storage(
+ {},
+ /*size_bytes=*/4,
+ at::DataPtr(
+ /*data=*/data_ptr,
+ /*ctx=*/new cow::COWDeleterContext(std::move(data)),
+ cow::cow_deleter,
+ Device(Device::Type::CPU)),
+ /*allocator=*/nullptr,
+ /*resizable=*/false);
+
+ ASSERT_THAT(storage, is_copy_on_write());
+
+ ASSERT_THAT(storage.data(), testing::Eq(data_ptr));
+
+ void const* original_data = storage.data();
+
+ // Materializes the storage. This is the only reference, so no new
+ // allocation is made.
+ ASSERT_THAT(storage.mutable_data(), testing::Eq(original_data));
+ // But it is no longer copy-on-write.
+ ASSERT_THAT(storage, testing::Not(is_copy_on_write()));
+}
+
+TEST(materialize_test, copy_on_write) {
+ StorageImpl original_storage(
+ {}, /*size_bytes=*/6, GetCPUAllocator(), /*resizable=*/false);
+ std::memcpy(original_storage.mutable_data(), "abcd", 4);
+ void const* original_data = original_storage.data();
+
+ auto new_storage = cow::lazy_clone_storage(original_storage);
+ ASSERT_THAT(new_storage, testing::NotNull());
+
+ auto context = new_storage->data_ptr().cast_context<cow::COWDeleterContext>(
+ cow::cow_deleter);
+ ASSERT_THAT(context, testing::NotNull());
+
+ // Materialized storage has new copy of data.
+ ASSERT_THAT(new_storage->mutable_data(), testing::Ne(original_data));
+
+ // But the original storage still has the original copy.
+ ASSERT_THAT(original_storage.data(), testing::Eq(original_data));
+
+ // But their data is the same.
+ ASSERT_THAT(
+ static_cast<char const*>(new_storage->data()),
+ testing::StrEq(static_cast<char const*>(original_storage.data())));
+}
+
} // namespace
} // namespace c10::impl
// NOLINTEND(clang-analyzer-cplusplus*)
diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu
index 3359e88..6555b97 100644
--- a/caffe2/core/context_gpu.cu
+++ b/caffe2/core/context_gpu.cu
@@ -336,6 +336,10 @@
return &Delete;
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for PinnedCPUAllocator");
+ }
+
private:
static void Delete(void* data) {
if (!data) {
@@ -581,6 +585,10 @@
return &Delete;
}
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for DefaultCUDAAllocator");
+ }
+
private:
static void Delete(void* ptr) {
// lock the mutex
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp
index 5818d64..cfe49e0 100644
--- a/test/cpp_extensions/open_registration_extension.cpp
+++ b/test/cpp_extensions/open_registration_extension.cpp
@@ -188,6 +188,10 @@
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
+
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
};
// Register our dummy allocator
diff --git a/test/inductor/extension_backends/extension_device.cpp b/test/inductor/extension_backends/extension_device.cpp
index fa2922a..2e86cde 100644
--- a/test/inductor/extension_backends/extension_device.cpp
+++ b/test/inductor/extension_backends/extension_device.cpp
@@ -81,6 +81,10 @@
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
+
+ void copy_data(void* dest, const void* src, std::size_t count) const final {
+ default_copy_data(dest, src, count);
+ }
};
// Register our dummy allocator
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
index b5bed3c..b304c5e 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp
@@ -330,6 +330,14 @@
return "pluggable";
}
+void CUDAPluggableAllocator::copy_data(
+ void* dest,
+ const void* src,
+ std::size_t count) const {
+ C10_CUDA_CHECK(
+ cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
+}
+
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
current_custom_allocator;
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h
index 6a0f3e0..f2c2e6e 100644
--- a/torch/csrc/cuda/CUDAPluggableAllocator.h
+++ b/torch/csrc/cuda/CUDAPluggableAllocator.h
@@ -120,6 +120,7 @@
cudaStream_t stream,
bool p2p_enabled) override;
std::string name() override;
+ void copy_data(void* dest, const void* src, std::size_t count) const final;
protected:
std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;