[nn] Add support for +=, * and *= operations for nn.Sequential objects (#81279)
Fixes 71329
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81279
Approved by: https://github.com/albanD
diff --git a/test/test_nn.py b/test/test_nn.py
index 00c4ea1..c8ca423 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1589,6 +1589,50 @@
other = nn.Sequential(l3, l4)
self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))
+ def test_Sequential_iadd(self):
+ l1 = nn.Linear(10, 20)
+ l2 = nn.Linear(20, 30)
+ l3 = nn.Linear(30, 40)
+ l4 = nn.Linear(40, 50)
+ n = nn.Sequential(l1, l2, l3)
+ n2 = nn.Sequential(l4)
+ n += n2
+ n2 += n
+ self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
+ self.assertEqual(n2, nn.Sequential(l4, l1, l2, l3, l4))
+
+ def test_Sequential_mul(self):
+ l1 = nn.Linear(10, 20)
+ l2 = nn.Linear(20, 30)
+ l3 = nn.Linear(30, 40)
+ l4 = nn.Linear(40, 50)
+ n = nn.Sequential(l1, l2, l3, l4)
+ n2 = n * 2
+ self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
+
+ def test_Sequential_rmul(self):
+ l1 = nn.Linear(10, 20)
+ l2 = nn.Linear(20, 30)
+ l3 = nn.Linear(30, 40)
+ l4 = nn.Linear(40, 50)
+ n = nn.Sequential(l1, l2, l3, l4)
+ n2 = 2 * n
+ self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
+
+ def test_Sequential_imul(self):
+ l1 = nn.Linear(10, 20)
+ l2 = nn.Linear(20, 30)
+ l3 = nn.Linear(30, 40)
+ l4 = nn.Linear(40, 50)
+ n = nn.Sequential(l1, l2, l3, l4)
+ n *= 2
+ self.assertEqual(n, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
+ n *= 2
+ self.assertEqual(
+ n,
+ nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4)
+ )
+
def test_Sequential_append(self):
l1 = nn.Linear(10, 20)
l2 = nn.Linear(20, 30)
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index d0d44eb..b1b0129 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -135,6 +135,48 @@
'of Sequential class, but {} is given.'.format(
str(type(other))))
+ def __iadd__(self, other) -> 'Sequential':
+ if isinstance(other, Sequential):
+ offset = len(self)
+ for i, module in enumerate(other):
+ self.add_module(str(i + offset), module)
+ return self
+ else:
+ raise ValueError('add operator supports only objects '
+ 'of Sequential class, but {} is given.'.format(
+ str(type(other))))
+
+ def __mul__(self, other: int) -> 'Sequential':
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif (other <= 0):
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ combined = Sequential()
+ offset = 0
+ for _ in range(other):
+ for module in self:
+ combined.add_module(str(offset), module)
+ offset += 1
+ return combined
+
+ def __rmul__(self, other: int) -> 'Sequential':
+ return self.__mul__(other)
+
+ def __imul__(self, other: int) -> 'Sequential':
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif (other <= 0):
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ len_original = len(self)
+ offset = len(self)
+ for _ in range(other - 1):
+ for i in range(len_original):
+ self.add_module(str(i + offset), self._modules[str(i)])
+ offset += len_original
+ return self
+
@_copy_to_script_wrapper
def __dir__(self):
keys = super(Sequential, self).__dir__()