[PyTorch][Static Runtime] Separate overlap checks for easier debugging (#66637)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66637
We can give more information when verify_no_memory_overlap would fail by separating the DCHECK.
ghstack-source-id: 142226105
Test Plan: fitsships
Reviewed By: d1jang
Differential Revision: D31517151
fbshipit-source-id: 8cbc324c27f6b4db4489d1bd469d37b1d8ae6ce1
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index fc5df28..0b3f0dd 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -1575,6 +1575,11 @@
}
bool ProcessedNode::verify_no_memory_overlap() const {
+ return verify_outputs_dont_overlap_each_other() &&
+ verify_inputs_dont_overlap_outputs();
+}
+
+bool ProcessedNode::verify_outputs_dont_overlap_each_other() const {
for (const auto i : c10::irange(outputs_size_)) {
if (!outputs_[i].isTensor()) {
continue;
@@ -1592,7 +1597,10 @@
}
}
}
+ return true;
+}
+bool ProcessedNode::verify_inputs_dont_overlap_outputs() const {
auto schema = node()->maybeSchema();
// skip memory overlap check for mutable ops with only one output
if (!schema || (schema->is_mutable() && outputs_size_ == 1)) {
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index c453b12..b6d0262 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -506,6 +506,10 @@
}
private:
+ C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const;
+
+ C10_NODISCARD bool verify_inputs_dont_overlap_outputs() const;
+
Node* node_;
enum class FunctionKind {
kOutVariant,