More tests on xt_qtaguid owner match function

Added another series of tests about inverted uid and gid match test on
qtaguid module. It ensures the behavior of the owner match module is
consistent with new sk_uid replacement.

Bug: 37524657
Test: Test passes on all common kernel branches
Signed-off-by: Chenbo Feng <fengc@google.com>
Change-Id: I1af2243a326284e5b65eaf223e4b3edf14126eb9
diff --git a/net/test/net_test.py b/net/test/net_test.py
index 7f0b511..d1d62b2 100755
--- a/net/test/net_test.py
+++ b/net/test/net_test.py
@@ -343,25 +343,35 @@
 except ValueError:
   HAVE_IPV6 = False
 
-
-class RunAsUid(object):
+class RunAsUidGid(object):
   """Context guard to run a code block as a given UID."""
 
-  def __init__(self, uid):
+  def __init__(self, uid, gid):
     self.uid = uid
+    self.gid = gid
 
   def __enter__(self):
     if self.uid:
       self.saved_uid = os.geteuid()
       self.saved_groups = os.getgroups()
-      if self.uid:
-        os.setgroups(self.saved_groups + [AID_INET])
-        os.seteuid(self.uid)
+      os.setgroups(self.saved_groups + [AID_INET])
+      os.seteuid(self.uid)
+    if self.gid:
+      self.saved_gid = os.getgid()
+      os.setgid(self.gid)
 
   def __exit__(self, unused_type, unused_value, unused_traceback):
     if self.uid:
       os.seteuid(self.saved_uid)
       os.setgroups(self.saved_groups)
+    if self.gid:
+      os.setgid(self.saved_gid)
+
+class RunAsUid(RunAsUidGid):
+  """Context guard to run a code block as a given GID and UID."""
+
+  def __init__(self, uid):
+    RunAsUidGid.__init__(self, uid, 0)
 
 
 class NetworkTest(unittest.TestCase):
diff --git a/net/test/qtaguid_test.py b/net/test/qtaguid_test.py
index 0e3fca8..08f51e7 100755
--- a/net/test/qtaguid_test.py
+++ b/net/test/qtaguid_test.py
@@ -27,6 +27,19 @@
 
 class QtaguidTest(net_test.NetworkTest):
 
+  def RunIptablesCommand(self, args):
+    self.assertFalse(net_test.RunIptablesCommand(4, args))
+    self.assertFalse(net_test.RunIptablesCommand(6, args))
+
+  def setUp(self):
+    self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
+    self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")
+
+  def tearDown(self):
+    self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
+    self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
+    self.RunIptablesCommand("-X qtaguid_test_OUTPUT")
+
   def WriteToCtrl(self, command):
     ctrl_file = open(CTRL_PROCPATH, 'w')
     ctrl_file.write(command)
@@ -38,27 +51,69 @@
         return True
     return False
 
-  def SetIptablesRule(self, version, is_add, is_gid, my_id):
+  def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
     add_del = "-A" if is_add else "-D"
     uid_gid = "--gid-owner" if is_gid else "--uid-owner"
-    args = "%s OUTPUT -m owner %s %d -j DROP" % (add_del, uid_gid, my_id)
+    if inverted:
+      args = "%s qtaguid_test_OUTPUT -m owner ! %s %d -j DROP" % (add_del, uid_gid, my_id)
+    else:
+      args = "%s qtaguid_test_OUTPUT -m owner %s %d -j DROP" % (add_del, uid_gid, my_id)
     self.assertFalse(net_test.RunIptablesCommand(version, args))
 
+  def AddIptablesRule(self, version, is_gid, myId):
+    self.SetIptablesRule(version, True, is_gid, myId, False)
+
+  def AddIptablesInvertedRule(self, version, is_gid, myId):
+    self.SetIptablesRule(version, True, is_gid, myId, True)
+
+  def DelIptablesRule(self, version, is_gid, myId):
+    self.SetIptablesRule(version, False, is_gid, myId, False)
+
+  def DelIptablesInvertedRule(self, version, is_gid, myId):
+    self.SetIptablesRule(version, False, is_gid, myId, True)
+
   def CheckSocketOutput(self, version, is_gid):
     myId = os.getgid() if is_gid else os.getuid()
-    self.SetIptablesRule(version, True, is_gid, myId);
+    self.AddIptablesRule(version, is_gid, myId)
     family = {4: AF_INET, 6: AF_INET6}[version]
     s = socket(family, SOCK_DGRAM, 0)
     addr = {4: "127.0.0.1", 6: "::1"}[version]
     s.bind((addr, 0))
     addr = s.getsockname()
     self.assertRaisesErrno(errno.EPERM, s.sendto, "foo", addr)
-    self.SetIptablesRule(version, False, is_gid, myId)
+    self.DelIptablesRule(version, is_gid, myId)
     s.sendto("foo", addr)
     data, sockaddr = s.recvfrom(4096)
     self.assertEqual("foo", data)
     self.assertEqual(sockaddr, addr)
 
+  def CheckSocketOutputInverted(self, version, is_gid):
+    # Load a inverted iptable rule on current uid/gid 0, traffic from other
+    # uid/gid should be blocked and traffic from current uid/gid should pass.
+    myId = os.getgid() if is_gid else os.getuid()
+    self.AddIptablesInvertedRule(version, is_gid, myId)
+    family = {4: AF_INET, 6: AF_INET6}[version]
+    s = socket(family, SOCK_DGRAM, 0)
+    addr1 = {4: "127.0.0.1", 6: "::1"}[version]
+    s.bind((addr1, 0))
+    addr1 = s.getsockname()
+    s.sendto("foo", addr1)
+    data, sockaddr = s.recvfrom(4096)
+    self.assertEqual("foo", data)
+    self.assertEqual(sockaddr, addr1)
+    with net_test.RunAsUidGid(0 if is_gid else 12345,
+                              12345 if is_gid else 0):
+      s2 = socket(family, SOCK_DGRAM, 0)
+      addr2 = {4: "127.0.0.1", 6: "::1"}[version]
+      s2.bind((addr2, 0))
+      addr2 = s2.getsockname()
+      self.assertRaisesErrno(errno.EPERM, s2.sendto, "foo", addr2)
+    self.DelIptablesInvertedRule(version, is_gid, myId)
+    s.sendto("foo", addr1)
+    data, sockaddr = s.recvfrom(4096)
+    self.assertEqual("foo", data)
+    self.assertEqual(sockaddr, addr1)
+
   def testCloseWithoutUntag(self):
     self.dev_file = open("/dev/xt_qtaguid", "r");
     sk = socket(AF_INET, SOCK_DGRAM, 0)
@@ -88,6 +143,10 @@
     self.CheckSocketOutput(6, False)
     self.CheckSocketOutput(4, True)
     self.CheckSocketOutput(6, True)
+    self.CheckSocketOutputInverted(4, True)
+    self.CheckSocketOutputInverted(6, True)
+    self.CheckSocketOutputInverted(4, False)
+    self.CheckSocketOutputInverted(6, False)
 
   @unittest.skip("does not pass on current kernels")
   def testCheckNotMatchGid(self):