[Nested Tensor] Add xpu device in assertion for nested tensor creation (#114664)
Add xpu device checking in nested tensor creation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114664
Approved by: https://github.com/jgong5, https://github.com/xunnanxu
diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp
index 99da6fb..223a95e 100644
--- a/aten/src/ATen/NestedTensorImpl.cpp
+++ b/aten/src/ATen/NestedTensorImpl.cpp
@@ -179,8 +179,8 @@
"in the near future.");
auto storage_device = storage_.device();
TORCH_INTERNAL_ASSERT(
- storage_device.is_cpu() || storage_device.is_cuda() || storage_device.is_privateuseone(),
- "NestedTensorImpl storage must be either CUDA, CPU or ", get_privateuse1_backend(), " but got ",
+ storage_device.is_cpu() || storage_device.is_cuda() || storage_device.is_xpu() || storage_device.is_privateuseone(),
+ "NestedTensorImpl storage must be either CUDA, CPU, XPU or ", get_privateuse1_backend(), " but got ",
storage_device);
validate_nested_tensor_metadata(nested_sizes_, nested_strides_, storage_offsets_);
refresh_dim();