Port to_padded_tensor CUDA kernel from pytorch/nestedtensor
This PR ports the custom CUDA kernel from pytorch/nestedtensor that converts contiguous NestedTensors of dimension 2 through 4 (float, double, and half) into padded dense tensors. All other dtypes and dimensions fall back to the generic implementation, which is renamed to NestedTensor_to_padded_tensor_generic and kept as the CPU path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76157
Approved by: https://github.com/ngimel
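
For reference, a minimal usage sketch mirroring the new tests (assuming a CUDA build that contains this change); the half-precision dim-3 case below is handled by the new kernel, while unsupported dtypes or dimensions fall back to the generic path:

    import torch

    # Components differ in both trailing dimensions, so every size entry past
    # the batch dimension is irregular.
    ts = [
        torch.randn(16, 21, device="cuda", dtype=torch.half),
        torch.randn(24, 32, device="cuda", dtype=torch.half),
        torch.randn(40, 53, device="cuda", dtype=torch.half),
    ]
    nt = torch.nested_tensor(ts, device="cuda", dtype=torch.half)
    # Dispatches to NestedTensor_to_padded_tensor_cuda; the result has shape
    # (3, 40, 53), with entries outside each component set to 42.
    padded = nt.to_padded_tensor(42.0)
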
diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp
index 2a0b439..cfabdc4 100644
--- a/aten/src/ATen/NestedTensorImpl.cpp
+++ b/aten/src/ATen/NestedTensorImpl.cpp
@@ -8,6 +8,33 @@
namespace at {
namespace native {
+inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
+ if (sizes.dim() == 0) {
+ return std::vector<int64_t>();
+ }
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
+ std::vector<int64_t> result(1, sizes.sizes()[0]);
+ if (sizes.dim() > 0) {
+ size_t nested_dim = result.size();
+ int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
+ result.resize(nested_dim + sizes.sizes()[1]);
+ int64_t sizes_size_0 = sizes.sizes()[0];
+ int64_t sizes_size_1 = sizes.sizes()[1];
+ for (const auto i : c10::irange(sizes_size_1)) {
+ result[nested_dim + i] = sizes_ptr[i];
+ }
+ for (const auto j : c10::irange(sizes_size_1)) {
+ for (const auto i : c10::irange(sizes_size_0)) {
+ if (result[nested_dim + j] &&
+ (result[nested_dim + j] != sizes_ptr[i * sizes.size(1) + j])) {
+ result[nested_dim + j] = -1;
+ }
+ }
+ }
+ }
+ return result;
+}
+
NestedTensorImpl::NestedTensorImpl(
at::Tensor buffer,
at::Tensor nested_size_tensor)
@@ -17,7 +44,9 @@
buffer.dtype(),
buffer.device()),
buffer_(std::move(buffer)),
- nested_size_tensor_(std::move(nested_size_tensor)) {
+ nested_size_tensor_(std::move(nested_size_tensor)),
+ opt_sizes_(construct_opt_sizes(nested_size_tensor_))
+{
TORCH_WARN_ONCE(
"The PyTorch API of nested tensors is in prototype stage and will change "
"in the near future.");
@@ -41,5 +70,6 @@
const char* NestedTensorImpl::tensorimpl_type_name() const {
return "NestedTensorImpl";
}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h
index f64014d..1343ec1 100644
--- a/aten/src/ATen/NestedTensorImpl.h
+++ b/aten/src/ATen/NestedTensorImpl.h
@@ -32,6 +32,16 @@
const Tensor& get_nested_size_tensor() const {
return nested_size_tensor_;
}
+ // Returns nullopt if the ith dimension is irregular. The ith dimension
+ // of a NestedTensor is regular if the unbound tensors match in
+ // size at the (i-1)th dimension.
+ c10::optional<int64_t> opt_size(int64_t d) const {
+ d = at::maybe_wrap_dim(d, dim(), false);
+ if (opt_sizes_[d] == -1) {
+ return c10::nullopt;
+ }
+ return opt_sizes_[d];
+ }
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
IntArrayRef sizes() const override {
TORCH_CHECK(
@@ -63,6 +73,8 @@
at::Tensor buffer_;
const at::Tensor nested_size_tensor_;
+ // NOTE: -1 here means the size is irregular (it varies across the unbound
+ // components), so opt_size returns nullopt for that dimension.
+ std::vector<int64_t> opt_sizes_;
};
inline NestedTensorImpl* get_nested_tensor_impl_or_null(const at::Tensor& tensor) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f7d0c2c..e4c9f8c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11617,7 +11617,8 @@
- func: to_padded_tensor(Tensor self, float padding) -> Tensor
variants: method
dispatch:
- NestedTensorCPU, NestedTensorCUDA: NestedTensor_to_padded_tensor
+ NestedTensorCPU: NestedTensor_to_padded_tensor_generic
+ NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda
- func: _nested_tensor_layer_norm(Tensor self, Tensor? weight, Tensor? bias, float eps) -> Tensor
variants: method
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index f28f7ae..b16cdc6 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -8,6 +8,7 @@
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
+#include <ATen/native/nested/NestedTensorMath.h>
namespace at {
namespace native {
@@ -216,8 +217,7 @@
}
std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt) {
- const auto& sizes = nt.get_nested_size_tensor();
- return NestedTensor_get_max_size_from_size_tensor(sizes);
+ return NestedTensor_get_max_size_from_size_tensor(nt.get_nested_size_tensor());
}
Tensor NestedTensor_layer_norm(
@@ -303,8 +303,7 @@
std::move(new_buffer), sizes);
}
-Tensor NestedTensor_to_padded_tensor(const Tensor& t, double padding) {
- // TODO port CUDA path in pytorch/nestedtensor to_padded_tensor!
+Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding) {
// TODO: skipped optimization for case of all 1x1 tensors
auto& nt = *get_nested_tensor_impl(t);
auto max_size = NestedTensor_get_max_size(nt);
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.h b/aten/src/ATen/native/nested/NestedTensorMath.h
index 0211cd1..0993863 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.h
+++ b/aten/src/ATen/native/nested/NestedTensorMath.h
@@ -1,6 +1,7 @@
#pragma once
#include <c10/macros/Macros.h>
+#include <ATen/NestedTensorImpl.h>
#include <vector>
@@ -13,5 +14,7 @@
TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
+TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding);
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
index ce84e85..2a41b79 100644
--- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
+++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
@@ -52,7 +52,6 @@
Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim);
-
template <typename T>
void remove_padding_kernelLauncher(
const T* input,
@@ -72,5 +71,16 @@
const int* output_sizes,
int output_dim,
const int batch_size);
+
+template <typename T>
+void add_padding_kernelLauncher(
+ T* input,
+ T* output,
+ T padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ const std::vector<int64_t>& output_sizes,
+ const int batch_size);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 896c966..2677d7b 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -10,6 +10,7 @@
#endif
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
+#include <ATen/native/nested/NestedTensorMath.h>
namespace at {
namespace native {
@@ -107,5 +108,77 @@
return at::native::nested_from_padded_generic(padded, sizes);
}
}
+
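+// Computes (on CPU) the start offset of each component within the contiguous
+// NestedTensor buffer: offsets[i + 1] = offsets[i] + prod(ef_sizes[i, :]).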
+Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
+ int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
+ int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
+ Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);
+ int64_t* offsets_ptr = offsets.data_ptr<int64_t>();
+ offsets_ptr[0] = 0;
+ int64_t ef_sizes_size_1 = ef_sizes.sizes()[1];
+ for (const auto i : c10::irange(ef_sizes_size_0)) {
+ int64_t prod = 1;
+ for (const auto j : c10::irange(ef_sizes_size_1)) {
+ prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
+ }
+ offsets_ptr[i + 1] = offsets_ptr[i] + prod;
+ }
+ return offsets;
+}
+
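+// CUDA path: uses the custom padding kernels for contiguous 2- to 4-dimensional
+// NestedTensors of float, double, or half; anything else falls back to
+// NestedTensor_to_padded_tensor_generic.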
+Tensor NestedTensor_to_padded_tensor_cuda(const Tensor& t, double padding) {
+ int64_t t_dim = t.dim();
+ if (t_dim >= 2 && t_dim <= 4 &&
+ (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
+ t.dtype() == at::kHalf)) {
+ auto* nt_input = get_nested_tensor_impl(t);
+ TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_input));
+ const auto& nt_buffer = nt_input->get_buffer();
+
+ if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0)) {
+ Tensor nt_sizes = nt_input->get_nested_size_tensor();
+ Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
+ Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1);
+ Tensor result = at::detail::make_tensor<NestedTensorImpl>(
+ nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
+ result = NestedTensor_to_padded_tensor_cuda(result, padding);
+ return result.reshape({result.sizes()[0], -1, *nt_input->opt_size(2)});
+ }
+
+ Tensor nt_sizes = nt_input->get_nested_size_tensor();
+ Tensor offsets = batch_offsets_from_efficient_size(nt_sizes);
+ auto new_size = NestedTensor_get_max_size(*nt_input);
+ new_size.insert(new_size.begin(), nt_sizes.sizes()[0]);
+ Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());
+
+ int64_t input_dim = nt_sizes.sizes()[1];
+ int64_t batch_size = nt_sizes.sizes()[0];
+ // TODO: Remove need for cat here
+ at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
+ metadata = metadata.to(at::Device(kCUDA), at::kInt);
+
+ std::vector<Tensor> split =
+ at::split_with_sizes(metadata, {offsets.numel(), nt_sizes.numel()}, 0);
+
+ offsets = split[0];
+ nt_sizes = split[1];
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ nt_buffer.scalar_type(), "NestedTensor_to_padded_tensor_cuda", [&]() {
+ add_padding_kernelLauncher(
+ nt_buffer.data_ptr<scalar_t>(),
+ output.data_ptr<scalar_t>(),
+ (scalar_t)(padding),
+ offsets.data_ptr<int>(),
+ nt_sizes.data_ptr<int>(),
+ input_dim,
+ new_size,
+ batch_size);
+ });
+ return output;
+ }
+ return NestedTensor_to_padded_tensor_generic(t, padding);
+}
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
index c3f57db..197afe7 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
@@ -226,5 +226,217 @@
const int* output_sizes,
int output_dim,
const int batch_size);
+
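+// Padding kernels: blockIdx.x selects the batch component, and the
+// blockIdx.y * 256 threads assigned to it stride over that component's padded
+// output with a grain of 16 * 256 elements, copying values that fall inside
+// the component's extent and writing padding_value everywhere else.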
+template <typename T>
+__global__ void add_padding_1(
+ const T* input,
+ T* output,
+ T padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ int output_sizes_1,
+ const int batch_size) {
+ const int batch_id = blockIdx.x;
+ const int grid_id = blockIdx.y;
+ const int tid = threadIdx.x + grid_id * 256;
+ const int grainsize = 16 * 256;
+ const int batch_input_offset = offsets[batch_id];
+ const int* sizes_i = input_sizes + batch_id * input_dim;
+ const int batch_output_offset = batch_id * output_sizes_1;
+ for (int ii = 0; ii < (output_sizes_1 / grainsize); ii++) {
+ const int i = ii * grainsize + tid;
+ const int output_offset = batch_output_offset + i;
+ if (i < sizes_i[0]) {
+ output[output_offset] = input[batch_input_offset + i];
+ } else {
+ output[output_offset] = padding_value;
+ }
+ }
+ const int i = (output_sizes_1 / grainsize) * grainsize + tid;
+ if (i < output_sizes_1) {
+ const int output_offset = batch_output_offset + i;
+ if (i < sizes_i[0]) {
+ output[output_offset] = input[batch_input_offset + i];
+ } else {
+ output[output_offset] = padding_value;
+ }
+ }
+}
+
+template <typename T>
+__global__ void add_padding_2(
+ const T* input,
+ T* output,
+ T padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ int output_sizes_1,
+ int output_sizes_2,
+ const int batch_size) {
+ const int batch_id = blockIdx.x;
+ const int grid_id = blockIdx.y;
+ const int tid = threadIdx.x + grid_id * 256;
+ const int grainsize = 16 * 256;
+ const int offset = offsets[batch_id];
+ const int* sizes_i = input_sizes + batch_id * input_dim;
+ const int output_offset = batch_id * output_sizes_1 * output_sizes_2;
+ const int output_numel = output_sizes_1 * output_sizes_2;
+ for (int ii = 0; ii < (output_numel / grainsize); ii++) {
+ const int i = ii * grainsize + tid;
+ const int i0 = i / (output_sizes_2);
+ const int i1 = i - i0 * output_sizes_2;
+ if (i0 < sizes_i[0] && i1 < sizes_i[1]) {
+ const int input_offset = offset + i0 * sizes_i[1] + i1;
+ output[output_offset + i] = input[input_offset];
+ } else {
+ output[output_offset + i] = padding_value;
+ }
+ }
+ const int i = (output_numel / grainsize) * grainsize + tid;
+ if (i < output_numel) {
+ const int i0 = i / (output_sizes_2);
+ const int i1 = i - i0 * output_sizes_2;
+ if (i0 < sizes_i[0] && i1 < sizes_i[1]) {
+ const int input_offset = offset + i0 * sizes_i[1] + i1;
+ output[output_offset + i] = input[input_offset];
+ } else {
+ output[output_offset + i] = padding_value;
+ }
+ }
+}
+
+template <typename T>
+__global__ void add_padding_3(
+ const T* input,
+ T* output,
+ T padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ int output_sizes_1,
+ int output_sizes_2,
+ int output_sizes_3,
+ const int batch_size) {
+ const int batch_id = blockIdx.x;
+ const int grid_id = blockIdx.y;
+ const int tid = threadIdx.x + grid_id * 256;
+ const int grainsize = 16 * 256;
+ const int offset = offsets[batch_id];
+ const int* sizes_i = input_sizes + batch_id * input_dim;
+ const int output_offset =
+ batch_id * output_sizes_1 * output_sizes_2 * output_sizes_3;
+ const int output_numel = output_sizes_1 * output_sizes_2 * output_sizes_3;
+ for (int ii = 0; ii < (output_numel / grainsize); ii++) {
+ const int i = ii * grainsize + tid;
+ const int i0 = i / (output_sizes_2 * output_sizes_3);
+ const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
+ const int i2 = i % output_sizes_3;
+ if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
+ const int input_offset =
+ offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
+ output[output_offset + i] = input[input_offset];
+ } else {
+ output[output_offset + i] = padding_value;
+ }
+ }
+ const int i = (output_numel / grainsize) * grainsize + tid;
+ if (i < output_numel) {
+ const int i0 = i / (output_sizes_2 * output_sizes_3);
+ const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
+ const int i2 = i % output_sizes_3;
+ if (i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
+ const int input_offset =
+ offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
+ output[output_offset + i] = input[input_offset];
+ } else {
+ output[output_offset + i] = padding_value;
+ }
+ }
+}
+
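+// Host-side launcher: dispatches to the 1-, 2-, or 3-dimensional kernel based
+// on input_dim, the number of dimensions per component (i.e. the NestedTensor's
+// dim minus the batch dimension).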
+template <typename T>
+void add_padding_kernelLauncher(
+ T* input, // [batch_size x None]
+ T* output, // [batch_size x max(input.nested_size(1)) x inner_size]
+ T padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ const std::vector<int64_t>& output_sizes,
+ const int batch_size) {
+ at::cuda::CUDAStream stream = at::cuda::getDefaultCUDAStream();
+ dim3 grid;
+ grid.x = batch_size;
+ grid.y = 16;
+ if (input_dim == 1) {
+ add_padding_1<T><<<grid, 256, 0, stream>>>(
+ input,
+ output,
+ padding_value,
+ offsets,
+ input_sizes,
+ input_dim,
+ output_sizes[1],
+ batch_size);
+ }
+ if (input_dim == 2) {
+ add_padding_2<T><<<grid, 256, 0, stream>>>(
+ input,
+ output,
+ padding_value,
+ offsets,
+ input_sizes,
+ input_dim,
+ output_sizes[1],
+ output_sizes[2],
+ batch_size);
+ }
+ if (input_dim == 3) {
+ add_padding_3<T><<<grid, 256, 0, stream>>>(
+ input,
+ output,
+ padding_value,
+ offsets,
+ input_sizes,
+ input_dim,
+ output_sizes[1],
+ output_sizes[2],
+ output_sizes[3],
+ batch_size);
+ }
+}
+
+template void add_padding_kernelLauncher<double>(
+ double* input,
+ double* output,
+ double padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ const std::vector<int64_t>& output_sizes,
+ const int batch_size);
+
+template void add_padding_kernelLauncher<float>(
+ float* input,
+ float* output,
+ float padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ const std::vector<int64_t>& output_sizes,
+ const int batch_size);
+
+template void add_padding_kernelLauncher<c10::Half>(
+ c10::Half* input,
+ c10::Half* output,
+ c10::Half padding_value,
+ const int* offsets,
+ const int* input_sizes,
+ int input_dim,
+ const std::vector<int64_t>& output_sizes,
+ const int batch_size);
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 2480aeb..697aabb 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -232,7 +232,7 @@
const Tensor& qkv_bias,
const int64_t num_head) {
auto qkv_ = qkv.is_nested()
- ? c10::MaybeOwned<Tensor>::owned((NestedTensor_to_padded_tensor(qkv, 0)))
+ ? c10::MaybeOwned<Tensor>::owned(qkv.to_padded_tensor(0))
: c10::MaybeOwned<Tensor>::borrowed(qkv);
auto B = qkv_->size(0);
auto T = qkv_->size(1);
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 229f4a2..3cbee6b 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -256,11 +256,12 @@
for i, inp in enumerate(inputs):
self.assertEqual(emb(inp), ys[i])
- def test_to_padded_tensor_simple(self, device):
- t = torch.randn(4, 4, 4, device=device)
+ @dtypes(torch.float, torch.float16)
+ def test_to_padded_tensor_simple(self, device, dtype):
+ t = torch.randn(4, 4, 4, device=device, dtype=dtype)
ts = list(torch.unbind(t))
ts[0] = ts[0][:-1]
- nt = torch.nested_tensor(ts, device=device)
+ nt = torch.nested_tensor(ts, device=device, dtype=dtype)
for padding_value in (0, 1):
padded = nt.to_padded_tensor(padding_value)
@@ -272,17 +273,59 @@
self.assertEqual(padded, correct_output)
self.assertEqual(padded.device, torch.device(device))
+ self.assertEqual(padded.dtype, dtype)
- def test_to_padded_tensor_unrelated_shapes(self, device):
+ @dtypes(torch.float, torch.float16, torch.double)
+ def test_to_padded_tensor_dim2(self, device, dtype):
ts = [
- torch.randn(1, 2, 3, device=device),
- torch.randn(2, 3, 4, device=device),
- torch.randn(4, 5, 6, device=device),
+ torch.randn(160, device=device, dtype=dtype),
+ torch.randn(1240, device=device, dtype=dtype),
+ torch.randn(2400, device=device, dtype=dtype),
]
- nt = torch.nested_tensor(ts, device=device)
+ nt = torch.nested_tensor(ts, device=device, dtype=dtype)
pad = 42
- correct_output = torch.cat(
- [torch.nn.ConstantPad3d((0, 6 - x.shape[2], 0, 5 - x.shape[1], 0, 4 - x.shape[0]), pad)(x.unsqueeze(0)) for x in ts])
+ correct_output = []
+ for t in ts:
+ next_output = torch.ones_like(ts[2]) * pad
+ correct_output.append(next_output)
+ next_output[:t.size(0)].copy_(t)
+ correct_output = torch.stack(correct_output)
+ padded = nt.to_padded_tensor(pad)
+ self.assertEqual(padded, correct_output)
+
+ @dtypes(torch.float, torch.float16, torch.double)
+ def test_to_padded_tensor_dim3(self, device, dtype):
+ ts = [
+ torch.randn(16, 21, device=device, dtype=dtype),
+ torch.randn(24, 32, device=device, dtype=dtype),
+ torch.randn(40, 53, device=device, dtype=dtype),
+ ]
+ nt = torch.nested_tensor(ts, device=device, dtype=dtype)
+ pad = 42
+ correct_output = []
+ for t in ts:
+ next_output = torch.ones_like(ts[2]) * pad
+ correct_output.append(next_output)
+ next_output[:t.size(0), :t.size(1)].copy_(t)
+ correct_output = torch.stack(correct_output)
+ padded = nt.to_padded_tensor(pad)
+ self.assertEqual(padded, correct_output)
+
+ @dtypes(torch.float, torch.float16, torch.double)
+ def test_to_padded_tensor_dim4(self, device, dtype):
+ ts = [
+ torch.randn(16, 21, 13, device=device, dtype=dtype),
+ torch.randn(24, 32, 14, device=device, dtype=dtype),
+ torch.randn(40, 53, 16, device=device, dtype=dtype),
+ ]
+ nt = torch.nested_tensor(ts, device=device, dtype=dtype)
+ pad = 42
+ correct_output = []
+ for t in ts:
+ next_output = torch.ones_like(ts[2]) * pad
+ correct_output.append(next_output)
+ next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
+ correct_output = torch.stack(correct_output)
padded = nt.to_padded_tensor(pad)
self.assertEqual(padded, correct_output)