Fix _apply in nn.Module (#15305)
Summary:
Fixes an issue that arose from https://github.com/pytorch/pytorch/pull/13481 where `.share_memory()` could no longer be called on an `nn.Module`. Effectively undoes all changes to `nn.Module` from that PR and solves the relevant problem (being able to call `._apply()` on the Python wrapper for a C++ module) in a different way.
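The fix makes `Module._apply` recurse by calling `module._apply(fn)` on each child, and gives the C++ wrapper in `torch/nn/cpp.py` its own `_apply` that maps `fn` over its parameters and buffers. With the previous code, `fn(module)` was called on each child, so an unwrapped function such as the one behind `share_memory()` ended up being applied to the child `Module` itself rather than to its tensors. A minimal sketch of the behavior this restores (assuming the stock `share_memory()` implementation, which dispatches `t.share_memory_()` through `_apply`):

```python
import torch.nn as nn

# Nested modules now work uniformly, because _apply recurses into children
# instead of handing the child module itself to fn.
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))

net.float()         # _apply(lambda t: t.float() if t.is_floating_point() else t)
net.share_memory()  # assumed: _apply(lambda t: t.share_memory_())

assert all(p.storage().is_shared() for p in net.parameters())
```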
soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15305
Differential Revision: D13493937
Pulled By: goldsborough
fbshipit-source-id: 4cb8687f90fc8709a536c5e7eacd0dc8edf6f750
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index 42dec14..73e60a0 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -453,7 +453,7 @@
sequential.to(old_dtype)
self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)
- # Make sure we can access these method recursively.
+ # Make sure we can access these methods recursively.
self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
@@ -556,6 +556,22 @@
self.assertTrue(p.device.index == 0)
self.assertEqual(cpu_parameters[i], p)
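+ # Parameters, buffers, and submodules added to the C++ module after construction should follow subsequent cpu()/cuda() moves.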
+ net.cpu()
+ net.add_new_parameter("a", torch.eye(5))
+ net.add_new_parameter("b", torch.eye(5))
+ net.add_new_buffer("c", torch.eye(5))
+ net.add_new_buffer("d", torch.eye(5))
+ net.add_new_submodule("fc2")
+ net.add_new_submodule("fc3")
+
+ for p in net.parameters():
+ self.assertTrue(p.device.type == "cpu")
+
+ net.cuda()
+
+ for p in net.parameters():
+ self.assertTrue(p.device.type == "cuda")
+
def test_returns_shared_library_path_when_is_python_module_is_true(self):
source = """
#include <torch/script.h>
diff --git a/test/test_nn.py b/test/test_nn.py
index 0283db2..de31d73 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -570,6 +570,28 @@
input = torch.randn(2, 3, dtype=torch.float)
self.assertEqual(m(input).size(), (2, 5))
+ def test_share_memory(self):
+ class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.p = nn.Parameter(torch.eye(5))
+ self.par = nn.ParameterList()
+ self.par.append(nn.Parameter(torch.randn(10)))
+
+ def forward(self, inp):
+ return inp.clone()
+
+ net = Net()
+ for p in net.parameters():
+ self.assertFalse(p.storage().is_shared())
+ for b in net.buffers():
+ self.assertFalse(b.storage().is_shared())
+ net.share_memory()
+ for p in net.parameters():
+ self.assertTrue(p.storage().is_shared())
+ for b in net.buffers():
+ self.assertTrue(b.storage().is_shared())
+
def test_hooks(self):
module = nn.Sigmoid()
input = torch.ones(5, 5, requires_grad=True)
diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py
index 854d488..194c17b 100644
--- a/torch/nn/cpp.py
+++ b/torch/nn/cpp.py
@@ -65,6 +65,19 @@
if not attr.startswith("_"):
setattr(self, attr, getattr(self.cpp_module, attr))
+ def _apply(self, fn):
+ for param in self.parameters():
+ # Tensors stored in modules are graph leaves, and we don't
+ # want to create copy nodes, so we have to unpack the data.
+ param.data = fn(param.data)
+ if param._grad is not None:
+ param._grad.data = fn(param._grad.data)
+
+ for buf in self.buffers():
+ buf.data = fn(buf.data)
+
+ return self
+
@property
def training(self):
return self.cpp_module.training
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 832dff0..4de506d 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -20,18 +20,6 @@
return s
-def _if_float_tensor(fn):
- '''
- Calls `fn` on a value `t` only if `t` is a float tensor, or not a tensor (in
- which case it's a module, as part of a recursive call to apply()).
- '''
- def apply(t):
- if not isinstance(t, torch.Tensor) or t.is_floating_point():
- return fn(t)
- return t
- return apply
-
-
class Module(object):
r"""Base class for all neural network modules.
@@ -196,7 +184,7 @@
def _apply(self, fn):
for module in self.children():
- fn(module)
+ module._apply(fn)
for param in self._parameters.values():
if param is not None:
@@ -296,7 +284,7 @@
Returns:
Module: self
"""
- return self._apply(_if_float_tensor(lambda t: t.float()))
+ return self._apply(lambda t: t.float() if t.is_floating_point() else t)
def double(self):
r"""Casts all floating point parameters and buffers to ``double`` datatype.
@@ -304,7 +292,7 @@
Returns:
Module: self
"""
- return self._apply(_if_float_tensor(lambda t: t.double()))
+ return self._apply(lambda t: t.double() if t.is_floating_point() else t)
def half(self):
r"""Casts all floating point parameters and buffers to ``half`` datatype.
@@ -312,7 +300,7 @@
Returns:
Module: self
"""
- return self._apply(_if_float_tensor(lambda t: t.half()))
+ return self._apply(lambda t: t.half() if t.is_floating_point() else t)
def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
@@ -388,9 +376,7 @@
'dtypes, but got desired dtype={}'.format(dtype))
def convert(t):
- if isinstance(t, torch.Tensor):
- return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
- return t.to(device, dtype, non_blocking)
+ return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
return self._apply(convert)