Refactor parameterization logic in net tests

This change refactors the logic to create parameterized tests in the
kernel unit tests. Logic for name generation is left to classes, while
common code for test injection is moved to the utility class

Bug: 66467511
Test: Ran net tests
Change-Id: I7eba57c616145246637beefac3aca16f9e2e899e
(cherry picked from commit ad7a31a77695b60bdcd223df568a2b921acc41b0)
diff --git a/net/test/parameterization_test.py b/net/test/parameterization_test.py
new file mode 100755
index 0000000..8f9e130
--- /dev/null
+++ b/net/test/parameterization_test.py
@@ -0,0 +1,83 @@
+#!/usr/bin/python
+#
+# Copyright 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import unittest
+
+import net_test
+import util
+
+
+def InjectTests():
+  ParmeterizationTest.InjectTests()
+
+
+# This test class ensures that the Parameterized Test generator in utils.py
+# works properly. It injects test methods into itself, and ensures that they
+# are generated as expected, and that the TestClosures being run are properly
+# defined, and running different parameterized tests each time.
+class ParmeterizationTest(net_test.NetworkTest):
+  tests_run_list = []
+
+  @staticmethod
+  def NameGenerator(a, b, c):
+    return str(a) + "_" + str(b) + "_" + str(c)
+
+  @classmethod
+  def InjectTests(cls):
+    PARAMS_A = (1, 2)
+    PARAMS_B = (3, 4)
+    PARAMS_C = (5, 6)
+
+    param_list = itertools.product(PARAMS_A, PARAMS_B, PARAMS_C)
+    util.InjectParameterizedTest(cls, param_list, cls.NameGenerator)
+
+  def ParamTestDummyFunc(self, a, b, c):
+    self.tests_run_list.append(
+        "testDummyFunc_" + ParmeterizationTest.NameGenerator(a, b, c))
+
+  def testParameterization(self):
+    expected = [
+        "testDummyFunc_1_3_5",
+        "testDummyFunc_1_3_6",
+        "testDummyFunc_1_4_5",
+        "testDummyFunc_1_4_6",
+        "testDummyFunc_2_3_5",
+        "testDummyFunc_2_3_6",
+        "testDummyFunc_2_4_5",
+        "testDummyFunc_2_4_6",
+    ]
+
+    actual = [name for name in dir(self) if name.startswith("testDummyFunc")]
+
+    # Check that name and contents are equal
+    self.assertEqual(len(expected), len(actual))
+    self.assertEqual(sorted(expected), sorted(actual))
+
+    # Start a clean list, and run all the tests.
+    self.tests_run_list = list()
+    for test_name in expected:
+      test_method = getattr(self, test_name)
+      test_method()
+
+    # Make sure all tests have been run with the correct parameters
+    for test_name in expected:
+      self.assertTrue(test_name in self.tests_run_list)
+
+
+if __name__ == "__main__":
+  ParmeterizationTest.InjectTests()
+  unittest.main()
diff --git a/net/test/util.py b/net/test/util.py
index bed3e1d..cbcd2d0 100644
--- a/net/test/util.py
+++ b/net/test/util.py
@@ -13,4 +13,59 @@
 # limitations under the License.
 
 def GetPadLength(block_size, length):
-  return (block_size - (length % block_size)) % block_size
\ No newline at end of file
+  return (block_size - (length % block_size)) % block_size
+
+
+def InjectParameterizedTest(cls, param_list, name_generator):
+  """Injects parameterized tests into the provided class
+
+  This method searches for all tests that start with the name "ParamTest",
+  and injects a test method for each set of parameters in param_list. Names
+  are generated via the use of the name_generator.
+
+  Args:
+    cls: the class for which to inject all parameterized tests
+    param_list: a list of tuples, where each tuple is a combination of
+        of parameters to test (i.e. representing a single test case)
+    name_generator: A function that takes a combination of parameters and
+        returns a string that identifies the test case.
+  """
+  param_test_names = [name for name in dir(cls) if name.startswith("ParamTest")]
+
+  # Force param_list to an actual list; otherwise itertools.Product will hit
+  # the end, resulting in only the first ParamTest* method actually being
+  # parameterized
+  param_list = list(param_list)
+
+  # Parameterize each test method starting with "ParamTest"
+  for test_name in param_test_names:
+    func = getattr(cls, test_name)
+
+    for params in param_list:
+      # Give the test method a readable, debuggable name.
+      param_string = name_generator(*params)
+      new_name = "%s_%s" % (func.__name__.replace("ParamTest", "test"),
+                            param_string)
+      new_name = new_name.replace("(", "-").replace(")", "")  # remove parens
+
+      # Inject the test method
+      setattr(cls, new_name, _GetTestClosure(func, params))
+
+
+def _GetTestClosure(func, params):
+  """ Creates a no-argument test method for the given function and parameters.
+
+  This is required to be separate from the InjectParameterizedTest method, due
+  to some interesting scoping issues with internal function declarations. If
+  left in InjectParameterizedTest, all the tests end up using the same
+  instance of TestClosure
+
+  Args:
+    func: the function for which this test closure should run
+    params: the parameters for the run of this test function
+  """
+
+  def TestClosure(self):
+    func(self, *params)
+
+  return TestClosure
diff --git a/net/test/xfrm_algorithm_test.py b/net/test/xfrm_algorithm_test.py
index 6adc461..0176265 100755
--- a/net/test/xfrm_algorithm_test.py
+++ b/net/test/xfrm_algorithm_test.py
@@ -27,6 +27,7 @@
 import multinetwork_base
 import net_test
 from tun_twister import TapTwister
+import util
 import xfrm
 import xfrm_base
 
@@ -72,49 +73,26 @@
 ]
 
 def InjectTests():
-    XfrmAlgorithmTest.InjectTests()
+  XfrmAlgorithmTest.InjectTests()
+
 
 class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest):
   @classmethod
   def InjectTests(cls):
-    """Inject parameterized test cases into this class.
-
-    Because a library for parameterized testing is not availble in
-    net_test.rootfs.20150203, this does a minimal parameterization.
-
-    This finds methods named like "ParamTestFoo" and replaces them with several
-    "testFoo(*)" methods taking different parameter dicts. A set of test
-    parameters is generated from every combination of encryption,
-    authentication, IP version, and TCP/UDP.
-
-    The benefit of this approach is that an individually failing tests have a
-    clearly separated stack trace, and one failed test doesn't prevent the rest
-    from running.
-    """
-    param_test_names = [
-        name for name in dir(cls) if name.startswith("ParamTest")
-    ]
     VERSIONS = (4, 6)
     TYPES = (SOCK_DGRAM, SOCK_STREAM)
 
     # Tests all combinations of auth & crypt. Mutually exclusive with aead.
-    for crypt, auth, version, proto, name in itertools.product(
-        CRYPT_ALGOS, AUTH_ALGOS, VERSIONS, TYPES, param_test_names):
-      XfrmAlgorithmTest.InjectSingleTest(name, version, proto, crypt=crypt, auth=auth)
+    param_list = itertools.product(VERSIONS, TYPES, AUTH_ALGOS, CRYPT_ALGOS,
+                                   [None])
+    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)
 
     # Tests all combinations of aead. Mutually exclusive with auth/crypt.
-    for aead, version, proto, name in itertools.product(
-        AEAD_ALGOS, VERSIONS, TYPES, param_test_names):
-      XfrmAlgorithmTest.InjectSingleTest(name, version, proto, aead=aead)
+    param_list = itertools.product(VERSIONS, TYPES, [None], [None], AEAD_ALGOS)
+    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)
 
-  @classmethod
-  def InjectSingleTest(cls, name, version, proto, crypt=None, auth=None, aead=None):
-    func = getattr(cls, name)
-
-    def TestClosure(self):
-      func(self, {"crypt": crypt, "auth": auth, "aead": aead,
-          "version": version, "proto": proto})
-
+  @staticmethod
+  def TestNameGenerator(version, proto, auth, crypt, aead):
     # Produce a unique and readable name for each test. e.g.
     #     testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP
     param_string = ""
@@ -131,12 +109,9 @@
 
     param_string += "%s_%s" % ("IPv4" if version == 4 else "IPv6",
         "UDP" if proto == SOCK_DGRAM else "TCP")
-    new_name = "%s_%s" % (func.__name__.replace("ParamTest", "test"),
-                          param_string)
-    new_name = new_name.replace("(", "-").replace(")", "")  # remove parens
-    setattr(cls, new_name, TestClosure)
+    return param_string
 
-  def ParamTestSocketPolicySimple(self, params):
+  def ParamTestSocketPolicySimple(self, version, proto, auth, crypt, aead):
     """Test two-way traffic using transport mode and socket policies."""
 
     def AssertEncrypted(packet):
@@ -153,37 +128,21 @@
     # other using transport mode ESP. Because of TapTwister, both sockets
     # perceive each other as owning "remote_addr".
     netid = self.RandomNetid()
-    family = net_test.GetAddressFamily(params["version"])
-    local_addr = self.MyAddress(params["version"], netid)
-    remote_addr = self.GetRemoteSocketAddress(params["version"])
-    crypt_left = (xfrm.XfrmAlgo((
-        params["crypt"].name,
-        params["crypt"].key_len)),
-        os.urandom(params["crypt"].key_len / 8)) if params["crypt"] else None
-    crypt_right = (xfrm.XfrmAlgo((
-        params["crypt"].name,
-        params["crypt"].key_len)),
-        os.urandom(params["crypt"].key_len / 8)) if params["crypt"] else None
-    auth_left = (xfrm.XfrmAlgoAuth((
-        params["auth"].name,
-        params["auth"].key_len,
-        params["auth"].trunc_len)),
-        os.urandom(params["auth"].key_len / 8)) if params["auth"] else None
-    auth_right = (xfrm.XfrmAlgoAuth((
-        params["auth"].name,
-        params["auth"].key_len,
-        params["auth"].trunc_len)),
-        os.urandom(params["auth"].key_len / 8)) if params["auth"] else None
-    aead_left = (xfrm.XfrmAlgoAead((
-        params["aead"].name,
-        params["aead"].key_len,
-        params["aead"].icv_len)),
-        os.urandom(params["aead"].key_len / 8)) if params["aead"] else None
-    aead_right = (xfrm.XfrmAlgoAead((
-        params["aead"].name,
-        params["aead"].key_len,
-        params["aead"].icv_len)),
-        os.urandom(params["aead"].key_len / 8)) if params["aead"] else None
+    family = net_test.GetAddressFamily(version)
+    local_addr = self.MyAddress(version, netid)
+    remote_addr = self.GetRemoteSocketAddress(version)
+    auth_left = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
+                 os.urandom(auth.key_len / 8)) if auth else None
+    auth_right = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
+                  os.urandom(auth.key_len / 8)) if auth else None
+    crypt_left = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
+                  os.urandom(crypt.key_len / 8)) if crypt else None
+    crypt_right = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
+                   os.urandom(crypt.key_len / 8)) if crypt else None
+    aead_left = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
+                 os.urandom(aead.key_len / 8)) if aead else None
+    aead_right = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
+                  os.urandom(aead.key_len / 8)) if aead else None
     spi_left = 0xbeefface
     spi_right = 0xcafed00d
     req_ids = [100, 200, 300, 400]  # Used to match templates and SAs.
@@ -242,20 +201,20 @@
         output_mark=None)
 
     # Make two sockets.
-    sock_left = socket(family, params["proto"], 0)
+    sock_left = socket(family, proto, 0)
     sock_left.settimeout(2.0)
     sock_left.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     self.SelectInterface(sock_left, netid, "mark")
-    sock_right = socket(family, params["proto"], 0)
+    sock_right = socket(family, proto, 0)
     sock_right.settimeout(2.0)
     sock_right.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     self.SelectInterface(sock_right, netid, "mark")
 
     # For UDP, set SO_LINGER to 0, to prevent TCP sockets from hanging around
     # in a TIME_WAIT state.
-    if params["proto"] == SOCK_STREAM:
-        net_test.DisableFinWait(sock_left)
-        net_test.DisableFinWait(sock_right)
+    if proto == SOCK_STREAM:
+      net_test.DisableFinWait(sock_left)
+      net_test.DisableFinWait(sock_right)
 
     # Apply the left outbound socket policy.
     xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
@@ -302,14 +261,14 @@
         sock.close()
 
     # Server and client need to know each other's port numbers in advance.
-    wildcard_addr = net_test.GetWildcardAddress(params["version"])
+    wildcard_addr = net_test.GetWildcardAddress(version)
     sock_left.bind((wildcard_addr, 0))
     sock_right.bind((wildcard_addr, 0))
     left_port = sock_left.getsockname()[1]
     right_port = sock_right.getsockname()[1]
 
     # Start the appropriate server type on sock_right.
-    target = TcpServer if params["proto"] == SOCK_STREAM else UdpServer
+    target = TcpServer if proto == SOCK_STREAM else UdpServer
     server = threading.Thread(
         target=target,
         args=(sock_right, left_port),