Fix for get_buffer(): check buffers by name instead of value (#61429)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/61242

Previous code was wrongly checking if a tensor is a buffer in a module by comparing values; fix compares names instead.
Docs need some updating as well- current plan is to bump that to a separate PR, but I'm happy to do it here as well if preferred.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61429

Reviewed By: gchanan

Differential Revision: D29712341

Pulled By: jbschlosser

fbshipit-source-id: 41f29ab746505e60f13de42a9053a6770a3aac22
diff --git a/test/test_nn.py b/test/test_nn.py
index 2d002b4..ebd0aa4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1300,6 +1300,38 @@
         m.register_buffer('buffer_name', buffer3)
         self.assertEqual(m.buffer_name, buffer3)
 
+    def test_get_buffer(self):
+        m = nn.Module()
+        buffer1 = torch.randn(2, 3)
+        buffer2 = torch.randn(4, 5)
+        m.register_buffer('foo', buffer1)
+        m.register_buffer('bar', buffer2)
+        self.assertEqual(buffer1, m.get_buffer('foo'))
+        self.assertEqual(buffer2, m.get_buffer('bar'))
+
+    def test_get_buffer_from_submodules(self):
+        class MyModule(nn.Module):
+            def __init__(self, foo, bar):
+                super().__init__()
+                self.sub = Sub(foo, bar)
+
+        class Sub(nn.Module):
+            def __init__(self, foo, bar):
+                super().__init__()
+                self.register_buffer('foo', foo)
+                self.subsub = SubSub(bar)
+
+        class SubSub(nn.Module):
+            def __init__(self, bar):
+                super().__init__()
+                self.register_buffer('bar', bar)
+
+        foo = torch.randn(2, 3)
+        bar = torch.randn(4, 5)
+        m = MyModule(foo, bar)
+        self.assertEqual(foo, m.get_buffer('sub.foo'))
+        self.assertEqual(bar, m.get_buffer('sub.subsub.bar'))
+
     def test_buffer_not_persistent(self):
         m = nn.Module()
         m.register_buffer('buf', torch.rand(5), persistent=False)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 9181c8c..631a8ea 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -525,7 +525,7 @@
 
         buffer: torch.Tensor = getattr(mod, buffer_name)
 
-        if buffer not in mod._buffers.values():
+        if buffer_name not in mod._buffers:
             raise AttributeError("`" + buffer_name + "` is not a buffer")
 
         return buffer