VaryingShape<Strides>::isComplete() needs to consider whether each Stride is complete (#58510)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58510
In some case that I don't fully understand we're getting a stride that is:
```
{2:1, 1:1, 0:*}
```
(in this debug output, M:N means stride index M, stride value N). This shape
should be considered incomplete, since we don't actually know the values of the
stride, but VaryingShape::isComplete considers it complete because it only
checks the presence of elements in the vector, not whether those elements are
themselves complete.
ghstack-source-id: 129279583
Test Plan:
new unit test in test/cpp/jit
To see the failure in the context of a real model:
```
./fblearner/predictor/loadgen/download-requests.sh 272478342_0 10 ~/local/requests/272478342_0.recordio
buck-out/gen/fblearner/predictor/loadgen/replay_model_requests --model_id=272478342_0 --replay_record_source=recordio:/data/users/bertrand/requests/272478342_0.recordio --remote_port=9119 --output_file=/data/users/bertrand/responses/272478342_0_actual.recordio --output_type=recordio
buck-out/gen/fblearner/predictor/loadgen/replay_model_requests --model_id=272478342_0 --replay_record_source=recordio:/data/users/bertrand/requests/272478342_0.recordio --remote_port=9119 --output_file=/data/users/bertrand/responses/272478342_0_actual.recordio --output_type=recordio
```
Reviewed By: Krovatkin
Differential Revision: D28520062
fbshipit-source-id: 3ca900337d86480a40fbd90349a698cbb2fa5f11
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index b620215..d0a4f62 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -186,6 +186,10 @@
stride_ == b.stride_;
}
+ bool isComplete() const {
+ return stride_index_ && contiguous_ && stride_;
+ }
+
c10::optional<size_t> stride_index_;
c10::optional<bool> contiguous_;
c10::optional<size_t> stride_;
@@ -351,6 +355,17 @@
c10::optional<std::vector<ShapeSymbol>> dims_;
};
+namespace detail {
+inline bool isComplete(const Stride& s) {
+ return s.isComplete();
+}
+
+template<typename T>
+inline bool isComplete(const T& t) {
+ return true;
+}
+}
+
template <typename T>
struct VaryingShape {
using ListOfOptionalElements = std::vector<c10::optional<T>>;
@@ -414,7 +429,7 @@
return false;
}
for (auto d : *dims_) {
- if(!d) {
+ if (!d || !detail::isComplete(*d)) {
return false;
}
}
diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp
index 98fe0da..78c9484 100644
--- a/test/cpp/jit/test_jit_type.cpp
+++ b/test/cpp/jit/test_jit_type.cpp
@@ -8,6 +8,20 @@
namespace torch {
namespace jit {
+TEST(JitTypeTest, IsComplete) {
+ auto tt = c10::TensorType::create(
+ at::kFloat,
+ at::kCPU,
+ c10::SymbolicShape(std::vector<c10::optional<int64_t>>({1, 49})),
+ std::vector<c10::Stride>(
+ {c10::Stride{2, true, 1},
+ c10::Stride{1, true, 1},
+ c10::Stride{0, true, c10::nullopt}}),
+ false);
+ TORCH_INTERNAL_ASSERT(!tt->isComplete());
+ TORCH_INTERNAL_ASSERT(!tt->strides().isComplete());
+}
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(JitTypeTest, UnifyTypes) {
auto bool_tensor = TensorType::get()->withScalarType(at::kBool);