[PyTorch] Make TensorImpl::sizes() customizable and disable it for NestedTensorImpl (#73817)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73817
NestedTensorImpl doesn't have sizes(). Silently getting wrong results back from it is not conducive to efficient software development. Make it throw while allowing sizes() to be inlined in the common case anyway, just like is_contiguous(). Thanks ezyang for the reminder that we could do this.
ghstack-source-id: 151302903
Test Plan: Updated test_nestedtensor.py
Reviewed By: ezyang
Differential Revision: D34660829
fbshipit-source-id: 1289f21127d6a8359893f9174f3c430a290f2c7f
(cherry picked from commit 7098b9fcfbd25a03bac19e1148426ff073810edd)
diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp
index 51e93fc..a7b6d97 100644
--- a/aten/src/ATen/NestedTensorImpl.cpp
+++ b/aten/src/ATen/NestedTensorImpl.cpp
@@ -30,6 +30,7 @@
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
refresh_dim();
+ set_sizes_customization_policy(CustomizableMethodPolicy::NotSupported);
}
void NestedTensorImpl::refresh_dim() {
@@ -38,5 +39,8 @@
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == my_dim);
}
+const char* NestedTensorImpl::tensorimpl_type_name() const {
+ return "NestedTensorImpl";
+}
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h
index 1a3b985..da99667 100644
--- a/aten/src/ATen/NestedTensorImpl.h
+++ b/aten/src/ATen/NestedTensorImpl.h
@@ -53,6 +53,9 @@
return buffer_;
}
+ protected:
+ const char* tensorimpl_type_name() const override;
+
private:
// Must be called after any changes to our dim() to sync the state
// to TensorImpl.
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index fad9dcb..bed86ca 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -420,7 +420,7 @@
bool TensorImpl::is_contiguous_nondefault_policy_impl(
at::MemoryFormat memory_format) const {
if (has_contiguity_ ==
- static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
+ static_cast<uint8_t>(CustomizableMethodPolicy::ContiguityNotSupported)) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Tensors of type ",
@@ -429,7 +429,7 @@
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
has_contiguity_ ==
- static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
+ static_cast<uint8_t>(CustomizableMethodPolicy::CustomBehavior));
return is_contiguous_custom(memory_format);
}
}
@@ -441,6 +441,22 @@
"set_has_contiguity_policy and forget to override is_contiguous_custom?");
}
+IntArrayRef TensorImpl::sizes_nondefault_policy_impl() const {
+ if (sizes_customization_policy_ ==
+ static_cast<uint8_t>(CustomizableMethodPolicy::NotSupported)) {
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "Tensors of type ",
+ tensorimpl_type_name(),
+ " do not have sizes");
+ } else {
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "custom behavior for sizes() is not supported; please add it or file "
+ "an issue.")
+ }
+}
+
static void deletePlacementDeleteContext(void* ptr) {
delete static_cast<PlacementDeleteContext*>(ptr);
}
@@ -572,6 +588,8 @@
dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
dest_impl->reserved_ = src_impl->reserved_;
dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+ dest_impl->sizes_customization_policy_ =
+ src_impl->sizes_customization_policy_;
dest_impl->storage_access_should_throw_ =
src_impl->storage_access_should_throw_;
if (src_impl->named_tensor_meta_ != nullptr) {
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 4f6019a..2ce743f 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -699,16 +699,32 @@
/**
* Return a reference to the sizes of this tensor. This reference remains
* valid as long as the tensor is live and not resized.
+ *
+ * NOTE: sizes() is only `TENSORIMPL_MAYBE_VIRTUAL` for backward
+ * compatibility. See `set_sizes_customization_policy` for the
+ * encouraged customization point.
+ *
+ * NOTE: Currently, CustomizableMethodPolicy::CustomBehavior is not
+ * supported due to a lack of use case, but it can easily be added.
*/
TENSORIMPL_MAYBE_VIRTUAL IntArrayRef sizes() const
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
{
+ if (C10_UNLIKELY(
+ sizes_customization_policy_ !=
+ static_cast<uint8_t>(CustomizableMethodPolicy::Default))) {
+ return sizes_nondefault_policy_impl();
+ }
return sizes_and_strides_.sizes_arrayref();
}
#else
;
#endif
+ private:
+ IntArrayRef sizes_nondefault_policy_impl() const;
+
+ public:
/**
* Return a reference to the strides of this tensor. This reference remains
* valid as long as the tensor is live and not restrided.
@@ -2408,24 +2424,33 @@
}
protected:
- // Policy for adjusting the behavior of is_contiguous(). Allows
- // subclass customization while still being able to inline
- // is_contiguous() in the common case.
- enum class HasContiguityPolicy : uint8_t {
- // Default behavior: check is_contiguous_ and similar bitflags.
+ // Policy for adjusting the behavior of customizable methods like
+ // is_contiguous() and sizes(). Allows subclass customization while
+ // still being able to inline the methods in the common case.
+ enum class CustomizableMethodPolicy : uint8_t {
+ // Default behavior.
Default,
// Throw a generic error message that this tensor type does not
- // support is_contiguous.
- ContiguityNotSupported,
- // Call virtual is_contiguous_custom method to implement custom
- // is_contiguous behavior.
+ // support the method in question.
+ NotSupported,
+ // For backward compatibility.
+ ContiguityNotSupported = NotSupported,
+ // Call virtual foo_custom method to implement custom foo
+ // behavior.
CustomBehavior,
};
- void set_has_contiguity_policy(HasContiguityPolicy p) {
+ // For backward compatibility.
+ using HasContiguityPolicy = CustomizableMethodPolicy;
+
+ void set_has_contiguity_policy(CustomizableMethodPolicy p) {
has_contiguity_ = static_cast<uint8_t>(p);
}
+ void set_sizes_customization_policy(CustomizableMethodPolicy p) {
+ sizes_customization_policy_ = static_cast<uint8_t>(p);
+ }
+
Storage storage_;
private:
@@ -2536,7 +2561,7 @@
// or -std=gnu++2a
inline void init_bitfields() {
is_contiguous_ = true;
- has_contiguity_ = static_cast<uint8_t>(HasContiguityPolicy::Default);
+ has_contiguity_ = static_cast<uint8_t>(CustomizableMethodPolicy::Default);
is_channels_last_ = false;
is_channels_last_contiguous_ = false;
@@ -2547,6 +2572,8 @@
allow_tensor_metadata_change_ = true;
reserved_ = false;
owns_pyobj_ = false;
+ sizes_customization_policy_ =
+ static_cast<uint8_t>(CustomizableMethodPolicy::Default);
storage_access_should_throw_ = false;
}
@@ -2607,6 +2634,9 @@
// direction (to make sure the pyobj stays live).
bool owns_pyobj_ : 1;
+ // Customization policy for the sizes() virtual method.
+ /* CustomizableMethodPolicy */ uint8_t sizes_customization_policy_ : 2;
+
// The set of DispatchKeys which describe this tensor. NB: this
// does NOT include Autograd (historically, it did, but
// not anymore!)
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index cf868f2..368a510 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -133,14 +133,15 @@
RuntimeError, "numel is disabled", lambda: a1.numel(),
)
- @unittest.skipIf(IS_FBCODE, "size is not virtual in fbcode.")
@torch.inference_mode()
def test_size(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
- "NestedTensorImpl doesn't support sizes",
+ "Tensors of type NestedTensorImpl do not have sizes"
+ if IS_FBCODE
+ else "NestedTensorImpl doesn't support sizes",
lambda: a1.size(),
)