Fix for get_buffer(): check buffers by name instead of value (#61429)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61242
Previous code was wrongly checking if a tensor is a buffer in a module by comparing values; fix compares names instead.
Docs need some updating as well- current plan is to bump that to a separate PR, but I'm happy to do it here as well if preferred.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61429
Reviewed By: gchanan
Differential Revision: D29712341
Pulled By: jbschlosser
fbshipit-source-id: 41f29ab746505e60f13de42a9053a6770a3aac22
diff --git a/test/test_nn.py b/test/test_nn.py
index 2d002b4..ebd0aa4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1300,6 +1300,38 @@
m.register_buffer('buffer_name', buffer3)
self.assertEqual(m.buffer_name, buffer3)
+ def test_get_buffer(self):
+ m = nn.Module()
+ buffer1 = torch.randn(2, 3)
+ buffer2 = torch.randn(4, 5)
+ m.register_buffer('foo', buffer1)
+ m.register_buffer('bar', buffer2)
+ self.assertEqual(buffer1, m.get_buffer('foo'))
+ self.assertEqual(buffer2, m.get_buffer('bar'))
+
+ def test_get_buffer_from_submodules(self):
+ class MyModule(nn.Module):
+ def __init__(self, foo, bar):
+ super().__init__()
+ self.sub = Sub(foo, bar)
+
+ class Sub(nn.Module):
+ def __init__(self, foo, bar):
+ super().__init__()
+ self.register_buffer('foo', foo)
+ self.subsub = SubSub(bar)
+
+ class SubSub(nn.Module):
+ def __init__(self, bar):
+ super().__init__()
+ self.register_buffer('bar', bar)
+
+ foo = torch.randn(2, 3)
+ bar = torch.randn(4, 5)
+ m = MyModule(foo, bar)
+ self.assertEqual(foo, m.get_buffer('sub.foo'))
+ self.assertEqual(bar, m.get_buffer('sub.subsub.bar'))
+
def test_buffer_not_persistent(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 9181c8c..631a8ea 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -525,7 +525,7 @@
buffer: torch.Tensor = getattr(mod, buffer_name)
- if buffer not in mod._buffers.values():
+ if buffer_name not in mod._buffers:
raise AttributeError("`" + buffer_name + "` is not a buffer")
return buffer