[PyTorch][jit] Fix excess refcounting in TupleType::compare (#66286)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66286
No need to take refcount bumps on each comparator call.
Test Plan: CI, review
Reviewed By: hlu1, JasonHanwen
Differential Revision: D31487058
fbshipit-source-id: 98d2447ac27a12695cb0ebe1e279a6b50744ff4f
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index f7f6546..f41c9a6 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1062,7 +1062,7 @@
bool compare(
const Type& rhs,
- std::function<bool(const TypePtr, const TypePtr)> fn) const {
+ std::function<bool(const Type&, const Type&)> fn) const {
if (rhs.kind() != kind()) {
return false;
}
@@ -1072,7 +1072,7 @@
if (l_elements.size() != r_elements.size())
return false;
for (size_t i = 0; i < l_elements.size(); ++i) {
- if (!fn(l_elements[i], r_elements[i]))
+ if (!fn(*l_elements[i], *r_elements[i]))
return false;
}
return true;
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 2b4376b..ab943b4 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -1337,8 +1337,8 @@
};
bool names_match = !rhs->schema() || test_names_match(schema(), rhs->schema());
// co-variant rules for tuples
- return names_match && compare(*rhs, [&](const TypePtr a, const TypePtr b) {
- return a->isSubtypeOfExt(*b, why_not);
+ return names_match && compare(*rhs, [&](const Type& a, const Type& b) {
+ return a.isSubtypeOfExt(b, why_not);
});
}
@@ -1354,7 +1354,7 @@
bool TupleType::operator==(const Type& rhs) const {
bool typesSame =
- compare(rhs, [](const TypePtr a, const TypePtr b) { return *a == *b; });
+ compare(rhs, [](const Type& a, const Type& b) { return a == b; });
if (!typesSame) {
return false;
}