Port to_padded_tensor CUDA kernel from pytorch/nestedtensor

This PR ports the to_padded_tensor CUDA kernels from pytorch/nestedtensor: NestedTensors of dimension 2 through 4 (float, double, or half) are padded by a custom CUDA kernel, while other inputs fall back to the generic implementation, now renamed NestedTensor_to_padded_tensor_generic.
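
A minimal usage sketch (illustrative only; the shapes, dtype, and padding value are arbitrary, and a CUDA build with an available device is assumed):

```python
import torch

# Nested tensor whose constituents differ in their first dimension.
ts = [torch.randn(3, 4), torch.randn(5, 4)]
nt = torch.nested_tensor(ts, device="cuda", dtype=torch.half)

# Pads every constituent up to the per-dimension maximum (here 5 x 4),
# filling the gaps with the padding value; the result is a dense (2, 5, 4) tensor.
padded = nt.to_padded_tensor(0.0)
```
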
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76157
Approved by: https://github.com/ngimel
diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp
index 2a0b439..cfabdc4 100644
--- a/aten/src/ATen/NestedTensorImpl.cpp
+++ b/aten/src/ATen/NestedTensorImpl.cpp
@@ -8,6 +8,33 @@
 namespace at {
 namespace native {
 
+inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
+  if (sizes.dim() == 0) {
+    return std::vector<int64_t>();
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
+  std::vector<int64_t> result(1, sizes.sizes()[0]);
+  if (sizes.dim() > 0) {
+    size_t nested_dim = result.size();
+    int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
+    result.resize(nested_dim + sizes.sizes()[1]);
+    int64_t sizes_size_0 = sizes.sizes()[0];
+    int64_t sizes_size_1 = sizes.sizes()[1];
+    for (const auto i : c10::irange(sizes_size_1)) {
+      result[nested_dim + i] = sizes_ptr[i];
+    }
+    for (const auto j : c10::irange(sizes_size_1)) {
+      for (const auto i : c10::irange(sizes_size_0)) {
+        if (result[nested_dim + j] &&
+            (result[nested_dim + j] != sizes_ptr[i * sizes.size(1) + j])) {
+          result[nested_dim + j] = -1;
+        }
+      }
+    }
+  }
+  return result;
+}
+
 NestedTensorImpl::NestedTensorImpl(
     at::Tensor buffer,
     at::Tensor nested_size_tensor)
@@ -17,7 +44,9 @@
           buffer.dtype(),
           buffer.device()),
       buffer_(std::move(buffer)),
-      nested_size_tensor_(std::move(nested_size_tensor)) {
+      nested_size_tensor_(std::move(nested_size_tensor)),
+      opt_sizes_(construct_opt_sizes(nested_size_tensor_))
+{
   TORCH_WARN_ONCE(
       "The PyTorch API of nested tensors is in prototype stage and will change "
       "in the near future.");
@@ -41,5 +70,6 @@
 const char* NestedTensorImpl::tensorimpl_type_name() const {
   return "NestedTensorImpl";
 }
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h
index f64014d..1343ec1 100644
--- a/aten/src/ATen/NestedTensorImpl.h
+++ b/aten/src/ATen/NestedTensorImpl.h
@@ -32,6 +32,16 @@
   const Tensor& get_nested_size_tensor() const {
     return nested_size_tensor_;
   }
+  // Returns nullopt if the ith dimension is irregular. The ith dimension
+  // of a NestedTensor is regular if the unbound tensors match in
+  // size at the (i-1)th dimension.
+  c10::optional<int64_t> opt_size(int64_t d) const {
+    d = at::maybe_wrap_dim(d, dim(), false);
+    if (opt_sizes_[d] == -1) {
+      return c10::nullopt;
+    }
+    return opt_sizes_[d];
+  }
 #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
   IntArrayRef sizes() const override {
     TORCH_CHECK(
@@ -63,6 +73,8 @@
 
   at::Tensor buffer_;
   const at::Tensor nested_size_tensor_;
+  // NOTE: -1 here means the size is missing
+  std::vector<int64_t> opt_sizes_;
 };
 
 inline NestedTensorImpl* get_nested_tensor_impl_or_null(const at::Tensor& tensor) {
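
To make the new size bookkeeping concrete, here is a Python re-statement of construct_opt_sizes / opt_size for exposition only (the real implementation is the C++ above; zero-size corner cases are ignored). A value of -1 marks a dimension whose size varies across the unbound constituents, which opt_size surfaces as nullopt:

```python
from typing import List, Optional

def construct_opt_sizes(shapes: List[List[int]]) -> List[int]:
    # `shapes` plays the role of the nested size tensor: one row per constituent.
    result = [len(shapes)]  # dim 0: number of nested tensors
    for d in range(len(shapes[0])):
        sizes_at_d = {s[d] for s in shapes}
        result.append(sizes_at_d.pop() if len(sizes_at_d) == 1 else -1)
    return result

def opt_size(opt_sizes: List[int], d: int) -> Optional[int]:
    return None if opt_sizes[d] == -1 else opt_sizes[d]

# Constituents of shape (2, 3) and (2, 5): only the last dimension is irregular.
sizes = construct_opt_sizes([[2, 3], [2, 5]])  # [2, 2, -1]
assert opt_size(sizes, 1) == 2
assert opt_size(sizes, 2) is None
```
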
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f7d0c2c..e4c9f8c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11617,7 +11617,8 @@
 - func: to_padded_tensor(Tensor self, float padding) -> Tensor
   variants: method
   dispatch:
-    NestedTensorCPU, NestedTensorCUDA: NestedTensor_to_padded_tensor
+    NestedTensorCPU: NestedTensor_to_padded_tensor_generic
+    NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda
 
 - func: _nested_tensor_layer_norm(Tensor self, Tensor? weight, Tensor? bias, float eps) -> Tensor
   variants: method
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index f28f7ae..b16cdc6 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -8,6 +8,7 @@
 #include <ATen/native/layer_norm.h>
 #include <ATen/NestedTensorImpl.h>
 #include <c10/core/DispatchKey.h>
+#include <ATen/native/nested/NestedTensorMath.h>
 
 namespace at {
 namespace native {
@@ -216,8 +217,7 @@
 }
 
 std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt) {
-  const auto& sizes = nt.get_nested_size_tensor();
-  return NestedTensor_get_max_size_from_size_tensor(sizes);
+  return NestedTensor_get_max_size_from_size_tensor(nt.get_nested_size_tensor());
 }
 
 Tensor NestedTensor_layer_norm(
@@ -303,8 +303,7 @@
       std::move(new_buffer), sizes);
 }
 
-Tensor NestedTensor_to_padded_tensor(const Tensor& t, double padding) {
-  // TODO port CUDA path in pytorch/nestedtensor to_padded_tensor!
+Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding) {
   // TODO: skipped optimization for case of all 1x1 tensors
   auto& nt = *get_nested_tensor_impl(t);
   auto max_size = NestedTensor_get_max_size(nt);
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.h b/aten/src/ATen/native/nested/NestedTensorMath.h
index 0211cd1..0993863 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.h
+++ b/aten/src/ATen/native/nested/NestedTensorMath.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/macros/Macros.h>
+#include <ATen/NestedTensorImpl.h>
 
 #include <vector>
 
@@ -13,5 +14,7 @@
 
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
 
+TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
index ce84e85..2a41b79 100644
--- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
+++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
@@ -52,7 +52,6 @@
 
 Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim);
 
-
 template <typename T>
 void remove_padding_kernelLauncher(
     const T* input,
@@ -72,5 +71,16 @@
     const int* output_sizes,
     int output_dim,
     const int batch_size);
+
+template <typename T>
+void add_padding_kernelLauncher(
+    T* input,
+    T* output,
+    T padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    const std::vector<int64_t>& output_sizes,
+    const int batch_size);
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 896c966..2677d7b 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -10,6 +10,7 @@
 #endif
 
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
+#include <ATen/native/nested/NestedTensorMath.h>
 
 namespace at {
 namespace native {
@@ -107,5 +108,77 @@
     return at::native::nested_from_padded_generic(padded, sizes);
   }
 }
+
+Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
+  int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
+  int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
+  Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);
+  int64_t* offsets_ptr = offsets.data_ptr<int64_t>();
+  offsets_ptr[0] = 0;
+  int64_t ef_sizes_size_1 = ef_sizes.sizes()[1];
+  for (const auto i : c10::irange(ef_sizes_size_0)) {
+    int64_t prod = 1;
+    for (const auto j : c10::irange(ef_sizes_size_1)) {
+      prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
+    }
+    offsets_ptr[i + 1] = offsets_ptr[i] + prod;
+  }
+  return offsets;
+}
+
+Tensor NestedTensor_to_padded_tensor_cuda(const Tensor& t, double padding) {
+  int64_t t_dim = t.dim();
+  if (t_dim >= 2 && t_dim <= 4 &&
+      (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
+       t.dtype() == at::kHalf)) {
+    auto* nt_input = get_nested_tensor_impl(t);
+    TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_input));
+    const auto& nt_buffer = nt_input->get_buffer();
+
+    if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0)) {
+      Tensor nt_sizes = nt_input->get_nested_size_tensor();
+      Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
+      Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1);
+      Tensor result = at::detail::make_tensor<NestedTensorImpl>(
+          nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
+      result = NestedTensor_to_padded_tensor_cuda(result, padding);
+      return result.reshape({result.sizes()[0], -1, *nt_input->opt_size(2)});
+    }
+
+    Tensor nt_sizes = nt_input->get_nested_size_tensor();
+    Tensor offsets = batch_offsets_from_efficient_size(nt_sizes);
+    auto new_size = NestedTensor_get_max_size(*nt_input);
+    new_size.insert(new_size.begin(), nt_sizes.sizes()[0]);
+    Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());
+
+    int64_t input_dim = nt_sizes.sizes()[1];
+    int64_t batch_size = nt_sizes.sizes()[0];
+    // TODO: Remove need for cat here
+    at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
+    metadata = metadata.to(at::Device(kCUDA), at::kInt);
+
+    std::vector<Tensor> split =
+        at::split_with_sizes(metadata, {offsets.numel(), nt_sizes.numel()}, 0);
+
+    offsets = split[0];
+    nt_sizes = split[1];
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        nt_buffer.scalar_type(), "NestedTensor_to_padded_tensor_cuda", [&]() {
+          add_padding_kernelLauncher(
+              nt_buffer.data_ptr<scalar_t>(),
+              output.data_ptr<scalar_t>(),
+              (scalar_t)(padding),
+              offsets.data_ptr<int>(),
+              nt_sizes.data_ptr<int>(),
+              input_dim,
+              new_size,
+              batch_size);
+        });
+    return output;
+  }
+  return NestedTensor_to_padded_tensor_generic(t, padding);
+}
 } // namespace native
 } // namespace at
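
For reference, the offsets that batch_offsets_from_efficient_size hands to the kernel are an exclusive prefix sum of each constituent's numel, i.e. where each constituent starts inside the flat NestedTensor buffer. A small sketch of the same arithmetic (illustrative only, reusing the shapes from the new dim-3 test below):

```python
import torch

# Nested size tensor for constituents of shape (16, 21), (24, 32), (40, 53).
ef_sizes = torch.tensor([[16, 21], [24, 32], [40, 53]])

numels = ef_sizes.prod(dim=1)  # tensor([ 336,  768, 2120])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), numels.cumsum(0)])
print(offsets.tolist())  # [0, 336, 1104, 3224]
```
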
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
index c3f57db..197afe7 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
@@ -226,5 +226,217 @@
     const int* output_sizes,
     int output_dim,
     const int batch_size);
+
+template <typename T>
+__global__ void add_padding_1(
+    const T* input,
+    T* output,
+    T padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    int output_sizes_1,
+    const int batch_size) {
+  const int batch_id = blockIdx.x;
+  const int grid_id = blockIdx.y;
+  const int tid = threadIdx.x + grid_id * 256;
+  const int grainsize = 16 * 256;
+  const int batch_input_offset = offsets[batch_id];
+  const int* sizes_i = input_sizes + batch_id * input_dim;
+  const int batch_output_offset = batch_id * output_sizes_1;
+  for (int ii = 0; ii < (output_sizes_1 / grainsize); ii++) {
+    const int i = ii * grainsize + tid;
+    const int output_offset = batch_output_offset + i;
+    if (i < sizes_i[0]) {
+      output[output_offset] = input[batch_input_offset + i];
+    } else {
+      output[output_offset] = padding_value;
+    }
+  }
+  const int i = (output_sizes_1 / grainsize) * grainsize + tid;
+  if (i < output_sizes_1) {
+    const int output_offset = batch_output_offset + i;
+    if (i < sizes_i[0]) {
+      output[output_offset] = input[batch_input_offset + i];
+    } else {
+      output[output_offset] = padding_value;
+    }
+  }
+}
+
+template <typename T>
+__global__ void add_padding_2(
+    const T* input,
+    T* output,
+    T padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    int output_sizes_1,
+    int output_sizes_2,
+    const int batch_size) {
+  const int batch_id = blockIdx.x;
+  const int grid_id = blockIdx.y;
+  const int tid = threadIdx.x + grid_id * 256;
+  const int grainsize = 16 * 256;
+  const int offset = offsets[batch_id];
+  const int* sizes_i = input_sizes + batch_id * input_dim;
+  const int output_offset = batch_id * output_sizes_1 * output_sizes_2;
+  const int output_numel = output_sizes_1 * output_sizes_2;
+  for (int ii = 0; ii < (output_numel / grainsize); ii++) {
+    const int i = ii * grainsize + tid;
+    const int i0 = i / (output_sizes_2);
+    const int i1 = i - i0 * output_sizes_2;
+    if (i0 < sizes_i[0] && i1 < sizes_i[1]) {
+      const int input_offset = offset + i0 * sizes_i[1] + i1;
+      output[output_offset + i] = input[input_offset];
+    } else {
+      output[output_offset + i] = padding_value;
+    }
+  }
+  const int i = (output_numel / grainsize) * grainsize + tid;
+  if (i < output_numel) {
+    const int i0 = i / (output_sizes_2);
+    const int i1 = i - i0 * output_sizes_2;
+    if (i0 < sizes_i[0] && i1 < sizes_i[1]) {
+      const int input_offset = offset + i0 * sizes_i[1] + i1;
+      output[output_offset + i] = input[input_offset];
+    } else {
+      output[output_offset + i] = padding_value;
+    }
+  }
+}
+
+template <typename T>
+__global__ void add_padding_3(
+    const T* input,
+    T* output,
+    T padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    int output_sizes_1,
+    int output_sizes_2,
+    int output_sizes_3,
+    const int batch_size) {
+  const int batch_id = blockIdx.x;
+  const int grid_id = blockIdx.y;
+  const int tid = threadIdx.x + grid_id * 256;
+  const int grainsize = 16 * 256;
+  const int offset = offsets[batch_id];
+  const int* sizes_i = input_sizes + batch_id * input_dim;
+  const int output_offset =
+      batch_id * output_sizes_1 * output_sizes_2 * output_sizes_3;
+  const int output_numel = output_sizes_1 * output_sizes_2 * output_sizes_3;
+  for (int ii = 0; ii < (output_numel / grainsize); ii++) {
+    const int i = ii * grainsize + tid;
+    const int i0 = i / (output_sizes_2 * output_sizes_3);
+    const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
+    const int i2 = i % output_sizes_3;
+    if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
+      const int input_offset =
+          offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
+      output[output_offset + i] = input[input_offset];
+    } else {
+      output[output_offset + i] = padding_value;
+    }
+  }
+  const int i = (output_numel / grainsize) * grainsize + tid;
+  if (i < output_numel) {
+    const int i0 = i / (output_sizes_2 * output_sizes_3);
+    const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
+    const int i2 = i % output_sizes_3;
+    if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
+      const int input_offset =
+          offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
+      output[output_offset + i] = input[input_offset];
+    } else {
+      output[output_offset + i] = padding_value;
+    }
+  }
+}
+
+template <typename T>
+void add_padding_kernelLauncher(
+    T* input, // [batch_size x None]
+    T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
+    T padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    const std::vector<int64_t>& output_sizes,
+    const int batch_size) {
+  at::cuda::CUDAStream stream = at::cuda::getDefaultCUDAStream();
+  dim3 grid;
+  grid.x = batch_size;
+  grid.y = 16;
+  if (input_dim == 1) {
+    add_padding_1<T><<<grid, 256, 0, stream>>>(
+        input,
+        output,
+        padding_value,
+        offsets,
+        input_sizes,
+        input_dim,
+        output_sizes[1],
+        batch_size);
+  }
+  if (input_dim == 2) {
+    add_padding_2<T><<<grid, 256, 0, stream>>>(
+        input,
+        output,
+        padding_value,
+        offsets,
+        input_sizes,
+        input_dim,
+        output_sizes[1],
+        output_sizes[2],
+        batch_size);
+  }
+  if (input_dim == 3) {
+    add_padding_3<T><<<grid, 256, 0, stream>>>(
+        input,
+        output,
+        padding_value,
+        offsets,
+        input_sizes,
+        input_dim,
+        output_sizes[1],
+        output_sizes[2],
+        output_sizes[3],
+        batch_size);
+  }
+}
+
+template void add_padding_kernelLauncher<double>(
+    double* input,
+    double* output,
+    double padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    const std::vector<int64_t>& output_sizes,
+    const int batch_size);
+
+template void add_padding_kernelLauncher<float>(
+    float* input,
+    float* output,
+    float padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    const std::vector<int64_t>& output_sizes,
+    const int batch_size);
+
+template void add_padding_kernelLauncher<c10::Half>(
+    c10::Half* input,
+    c10::Half* output,
+    c10::Half padding_value,
+    const int* offsets,
+    const int* input_sizes,
+    int input_dim,
+    const std::vector<int64_t>& output_sizes,
+    const int batch_size);
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 2480aeb..697aabb 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -232,7 +232,7 @@
     const Tensor& qkv_bias,
     const int64_t num_head) {
   auto qkv_ = qkv.is_nested()
-    ? c10::MaybeOwned<Tensor>::owned((NestedTensor_to_padded_tensor(qkv, 0)))
+    ? c10::MaybeOwned<Tensor>::owned(qkv.to_padded_tensor(0))
     : c10::MaybeOwned<Tensor>::borrowed(qkv);
   auto B = qkv_->size(0);
   auto T = qkv_->size(1);
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 229f4a2..3cbee6b 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -256,11 +256,12 @@
         for i, inp in enumerate(inputs):
             self.assertEqual(emb(inp), ys[i])
 
-    def test_to_padded_tensor_simple(self, device):
-        t = torch.randn(4, 4, 4, device=device)
+    @dtypes(torch.float, torch.float16)
+    def test_to_padded_tensor_simple(self, device, dtype):
+        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
         ts = list(torch.unbind(t))
         ts[0] = ts[0][:-1]
-        nt = torch.nested_tensor(ts, device=device)
+        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
         for padding_value in (0, 1):
             padded = nt.to_padded_tensor(padding_value)
 
@@ -272,17 +273,59 @@
 
             self.assertEqual(padded, correct_output)
             self.assertEqual(padded.device, torch.device(device))
+            self.assertEqual(padded.dtype, dtype)
 
-    def test_to_padded_tensor_unrelated_shapes(self, device):
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_dim2(self, device, dtype):
         ts = [
-            torch.randn(1, 2, 3, device=device),
-            torch.randn(2, 3, 4, device=device),
-            torch.randn(4, 5, 6, device=device),
+            torch.randn(160, device=device, dtype=dtype),
+            torch.randn(1240, device=device, dtype=dtype),
+            torch.randn(2400, device=device, dtype=dtype),
         ]
-        nt = torch.nested_tensor(ts, device=device)
+        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
         pad = 42
-        correct_output = torch.cat(
-            [torch.nn.ConstantPad3d((0, 6 - x.shape[2], 0, 5 - x.shape[1], 0, 4 - x.shape[0]), pad)(x.unsqueeze(0)) for x in ts])
+        correct_output = []
+        for t in ts:
+            next_output = torch.ones_like(ts[2]) * pad
+            correct_output.append(next_output)
+            next_output[:t.size(0)].copy_(t)
+        correct_output = torch.stack(correct_output)
+        padded = nt.to_padded_tensor(pad)
+        self.assertEqual(padded, correct_output)
+
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_dim3(self, device, dtype):
+        ts = [
+            torch.randn(16, 21, device=device, dtype=dtype),
+            torch.randn(24, 32, device=device, dtype=dtype),
+            torch.randn(40, 53, device=device, dtype=dtype),
+        ]
+        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
+        pad = 42
+        correct_output = []
+        for t in ts:
+            next_output = torch.ones_like(ts[2]) * pad
+            correct_output.append(next_output)
+            next_output[:t.size(0), :t.size(1)].copy_(t)
+        correct_output = torch.stack(correct_output)
+        padded = nt.to_padded_tensor(pad)
+        self.assertEqual(padded, correct_output)
+
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_to_padded_tensor_dim4(self, device, dtype):
+        ts = [
+            torch.randn(16, 21, 13, device=device, dtype=dtype),
+            torch.randn(24, 32, 14, device=device, dtype=dtype),
+            torch.randn(40, 53, 16, device=device, dtype=dtype),
+        ]
+        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
+        pad = 42
+        correct_output = []
+        for t in ts:
+            next_output = torch.ones_like(ts[2]) * pad
+            correct_output.append(next_output)
+            next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
+        correct_output = torch.stack(correct_output)
         padded = nt.to_padded_tensor(pad)
         self.assertEqual(padded, correct_output)