Add equality comparison to c10::Dict (#34892)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34892

Same rationale and implementation as in
https://github.com/pytorch/pytorch/pull/34856

Test Plan: Imported from OSS

Differential Revision: D20493169

Pulled By: suo

fbshipit-source-id: 46d79a4ff5d4af2964cfaeb2c43f56decadf3201
diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h
index 304afbd..92bcd5f 100644
--- a/aten/src/ATen/core/Dict.h
+++ b/aten/src/ATen/core/Dict.h
@@ -338,6 +338,25 @@
    */
   void reserve(size_type count) const;
 
+  /**
+   * Value equality comparison. This function implements Python-like semantics for
+   * equality: two dicts with the same identity (e.g. same pointer) trivially
+   * compare equal, otherwise each element is compared for equality.
+   */
+  template <class Key_, class Value_>
+  friend bool operator==(
+      const Dict<Key_, Value_>& lhs,
+      const Dict<Key_, Value_>& rhs);
+  template <class Key_, class Value_>
+  friend bool operator!=(
+      const Dict<Key_, Value_>& lhs,
+      const Dict<Key_, Value_>& rhs);
+
+  /**
+   * Identity comparison. Returns true if and only if `rhs` represents the same
+   * Dict object as `this`.
+   */
+  bool is(const Dict& rhs) const;
 
   // private API for now because the return type will change to TypePtr
   // instead of optional<TypePtr> once types are mandatory.
diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h
index 1fc2fc9..f9670ba 100644
--- a/aten/src/ATen/core/Dict_inl.h
+++ b/aten/src/ATen/core/Dict_inl.h
@@ -215,4 +215,39 @@
   impl_->elementTypes.valueType = std::move(t);
 }
 
+template <class Key_, class Value_>
+bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
+  if (lhs.impl_ == rhs.impl_) {
+    // Dicts with the same identity trivially compare equal.
+    return true;
+  }
+
+  // TODO: when we define equality on IValue, we can just defer to the
+  // operator== implementation of the underlying map.
+  // For now, do the comparison manually to avoid invoking the template
+  // specialization for IValue equality
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+  for (const auto& pr : lhs) {
+    auto it = rhs.find(pr.key());
+    if (it == rhs.end()) {
+      return false;
+    }
+    if (it->value() != pr.value()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <class Key_, class Value_>
+bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
+  return !(lhs == rhs);
+}
+
+template <class Key, class Value>
+bool Dict<Key, Value>::is(const Dict& rhs) const {
+  return this->impl_ == rhs.impl_;
+}
 }
diff --git a/aten/src/ATen/test/Dict_test.cpp b/aten/src/ATen/test/Dict_test.cpp
index 4e93401..f027445 100644
--- a/aten/src/ATen/test/Dict_test.cpp
+++ b/aten/src/ATen/test/Dict_test.cpp
@@ -496,3 +496,25 @@
   EXPECT_EQ(dict.end(), found_nokey1);
   EXPECT_EQ(dict.end(), found_nokey2);
 }
+
+TEST(DictTest, dictEquality) {
+  Dict<string, int64_t> dict;
+  dict.insert("one", 1);
+  dict.insert("two", 2);
+
+  Dict<string, int64_t> dictSameValue;
+  dictSameValue.insert("one", 1);
+  dictSameValue.insert("two", 2);
+
+  Dict<string, int64_t> dictNotEqual;
+  dictNotEqual.insert("foo", 1);
+  dictNotEqual.insert("bar", 2);
+
+  Dict<string, int64_t> dictRef = dict;
+
+  EXPECT_EQ(dict, dictSameValue);
+  EXPECT_NE(dict, dictNotEqual);
+  EXPECT_NE(dictSameValue, dictNotEqual);
+  EXPECT_FALSE(dict.is(dictSameValue));
+  EXPECT_TRUE(dict.is(dictRef));
+}