// blob: 0e8b947805ab0a186888e1215ff14c3ae348d1d9 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Resize.h>
namespace at {
namespace native {
// Allocator used for tensors on the Meta device. Whatever size is requested,
// it never hands out real memory: every allocation yields a DataPtr built
// from null pointers and tagged with the Meta device.
struct MetaAllocator final : public at::Allocator {
  MetaAllocator() = default;
  ~MetaAllocator() override = default;

  // Deleter attached to every DataPtr produced below. Since allocate() only
  // ever passes nullptr, receiving anything else is an internal error.
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(pointer == nullptr);
  }

  // The requested byte count is deliberately ignored.
  DataPtr allocate(const size_t /*nbytes*/) const override {
    const at::Device meta_device(DeviceType::Meta);
    return {nullptr, nullptr, &deleter, meta_device};
  }

  // Exposes the same deleter that allocate() installs.
  DeleterFnPtr raw_deleter() const override {
    return &deleter;
  }
};
// File-local MetaAllocator instance handed out by GetMetaAllocator().
static MetaAllocator g_meta_alloc;

// Returns a non-owning pointer to the file-local meta allocator instance.
at::Allocator* GetMetaAllocator() {
  at::Allocator* const meta_allocator = &g_meta_alloc;
  return meta_allocator;
}
// Creates an empty tensor for the Meta device by delegating to
// at::detail::empty_generic with the meta allocator and the Meta dispatch key.
//
// size              - requested tensor sizes.
// dtype_opt         - element type; resolved through dtype_or_default.
// layout_opt        - layout; resolved through layout_or_default and required
//                     to be Layout::Strided.
// device_opt        - device; resolved through device_or_default and expected
//                     to be a Meta device (checked in debug builds only).
// pin_memory_opt    - not consulted by this function.
// memory_format_opt - forwarded unchanged to empty_generic.
//
// Fails the TORCH_CHECK_NOT_IMPLEMENTED check when the resolved layout is not
// strided.
Tensor empty_meta(
    IntArrayRef size,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt,
    c10::optional<c10::MemoryFormat> memory_format_opt
) {
  auto device = device_or_default(device_opt);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
  // NB: because there is no SparseMeta (yet), a non-strided layout can
  // actually reach this function, so it has to be rejected explicitly here.
  // The message describes the failing case (non-strided), not the accepted one.
  TORCH_CHECK_NOT_IMPLEMENTED(
      layout_or_default(layout_opt) == Layout::Strided,
      "non-strided meta tensors not supported yet"
  );
  auto* allocator = GetMetaAllocator();
  auto dtype = dtype_or_default(dtype_opt);
  return at::detail::empty_generic(
      size, allocator, at::DispatchKey::Meta, dtype, device, memory_format_opt);
}
// Creates a meta tensor with caller-specified sizes and strides.
//
// A placeholder meta tensor of size {0} is created through empty_meta first
// and then reshaped in place to the requested size/stride.
Tensor empty_strided_meta(
    IntArrayRef size,
    IntArrayRef stride,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Layout> layout_opt,
    c10::optional<Device> device_opt,
    c10::optional<bool> pin_memory_opt
) {
  auto result = at::native::empty_meta({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
  // The CPU resize routine is reused on purpose: resizing is mostly generic,
  // and its one device-specific step, the memcpy, is skipped when the source
  // storage is nullptr -- which is always the case for meta tensors.
  auto* const impl = result.unsafeGetTensorImpl();
  at::native::resize_impl_cpu_(impl, size, stride);
  return result;
}
} // namespace native
} // namespace at