Add equality comparison to c10::Dict (#34892)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34892
Same rationale and implementation as in
https://github.com/pytorch/pytorch/pull/34856
Test Plan: Imported from OSS
Differential Revision: D20493169
Pulled By: suo
fbshipit-source-id: 46d79a4ff5d4af2964cfaeb2c43f56decadf3201
diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h
index 304afbd..92bcd5f 100644
--- a/aten/src/ATen/core/Dict.h
+++ b/aten/src/ATen/core/Dict.h
@@ -338,6 +338,25 @@
*/
void reserve(size_type count) const;
+ /**
+ * Value equality comparison. This function implements Python-like semantics for
+ * equality: two dicts with the same identity (e.g. same pointer) trivially
+ * compare equal, otherwise each element is compared for equality.
+ */
+ template <class Key_, class Value_>
+ friend bool operator==(
+ const Dict<Key_, Value_>& lhs,
+ const Dict<Key_, Value_>& rhs);
+ template <class Key_, class Value_>
+ friend bool operator!=(
+ const Dict<Key_, Value_>& lhs,
+ const Dict<Key_, Value_>& rhs);
+
+ /**
+ * Identity comparison. Returns true if and only if `rhs` represents the same
+ * Dict object as `this`.
+ */
+ bool is(const Dict& rhs) const;
// private API for now because the return type will change to TypePtr
// instead of optional<TypePtr> once types are mandatory.
diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h
index 1fc2fc9..f9670ba 100644
--- a/aten/src/ATen/core/Dict_inl.h
+++ b/aten/src/ATen/core/Dict_inl.h
@@ -215,4 +215,39 @@
impl_->elementTypes.valueType = std::move(t);
}
+template <class Key_, class Value_>
+bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
+ if (lhs.impl_ == rhs.impl_) {
+ // Dicts with the same identity trivially compare equal.
+ return true;
+ }
+
+ // TODO: when we define equality on IValue, we can just defer to the
+ // operator== implementation of the underlying map.
+ // For now, do the comparison manually to avoid invoking the template
+ // specialization for IValue equality
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (const auto& pr : lhs) {
+ auto it = rhs.find(pr.key());
+ if (it == rhs.end()) {
+ return false;
+ }
+ if (it->value() != pr.value()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+template <class Key_, class Value_>
+bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
+ return !(lhs == rhs);
+}
+
+template <class Key, class Value>
+bool Dict<Key, Value>::is(const Dict& rhs) const {
+ return this->impl_ == rhs.impl_;
+}
}
diff --git a/aten/src/ATen/test/Dict_test.cpp b/aten/src/ATen/test/Dict_test.cpp
index 4e93401..f027445 100644
--- a/aten/src/ATen/test/Dict_test.cpp
+++ b/aten/src/ATen/test/Dict_test.cpp
@@ -496,3 +496,25 @@
EXPECT_EQ(dict.end(), found_nokey1);
EXPECT_EQ(dict.end(), found_nokey2);
}
+
+TEST(DictTest, dictEquality) {
+ Dict<string, int64_t> dict;
+ dict.insert("one", 1);
+ dict.insert("two", 2);
+
+ Dict<string, int64_t> dictSameValue;
+ dictSameValue.insert("one", 1);
+ dictSameValue.insert("two", 2);
+
+ Dict<string, int64_t> dictNotEqual;
+ dictNotEqual.insert("foo", 1);
+ dictNotEqual.insert("bar", 2);
+
+ Dict<string, int64_t> dictRef = dict;
+
+ EXPECT_EQ(dict, dictSameValue);
+ EXPECT_NE(dict, dictNotEqual);
+ EXPECT_NE(dictSameValue, dictNotEqual);
+ EXPECT_FALSE(dict.is(dictSameValue));
+ EXPECT_TRUE(dict.is(dictRef));
+}