Refactor IpTables class to remove duplication.

This CL removes as much duplicated code from the IpTables class as
possible. The basic construct of running the same command with
different executables or IP versions is extracted into a helper function.

Moreover, the unit tests are simplified by mocking the combined
IPv4/IPv6 methods instead of the per-executable calls, which removes a
lot of set-up duplication.

Bug: 26911013
Change-Id: Iecdacab2ef6ffa5631c877835bdfb0bf7191536c
diff --git a/iptables.cc b/iptables.cc
index 55756a1..62201dd 100644
--- a/iptables.cc
+++ b/iptables.cc
@@ -19,6 +19,9 @@
 #include <string>
 #include <vector>
 
+#include <base/bind.h>
+#include <base/bind_helpers.h>
+#include <base/callback.h>
 #include <base/logging.h>
 #include <base/strings/string_number_conversions.h>
 #include <base/strings/string_util.h>
@@ -27,6 +30,9 @@
 #include <brillo/process.h>
 
 namespace {
+// Command run once per argument as Run(argument, add); true on success.
+using IpTablesCallback = base::Callback<bool(const std::string&, bool)>;
+
 #if defined(__ANDROID__)
 const char kIpTablesPath[] = "/system/bin/iptables";
 const char kIp6TablesPath[] = "/system/bin/ip6tables";
@@ -38,6 +44,9 @@
 const char kUnprivilegedUser[] = "nobody";
 #endif  // __ANDROID__
 
+const char kIPv4[] = "IPv4";  // IP version labels; also used in logs.
+const char kIPv6[] = "IPv6";
+
 const uint64_t kIpTablesCapMask =
     CAP_TO_MASK(CAP_NET_ADMIN) | CAP_TO_MASK(CAP_NET_RAW);
 
@@ -70,6 +79,24 @@
   }
   return true;
 }
+
+bool RunForAllArguments(const IpTablesCallback& iptables_cmd,
+                        const std::vector<std::string>& arguments,
+                        bool add) {
+  bool success = true;
+  for (const auto& argument : arguments) {
+    if (!iptables_cmd.Run(argument, add)) {
+      // On failure, only abort if rules are being added.
+      // If removing a rule fails, attempt the remaining removals but still
+      // return 'false'.
+      success = false;
+      if (add)
+        break;
+    }
+  }
+  return success;
+}
+
 }  // namespace
 
 namespace firewalld {
@@ -233,6 +260,84 @@
   return ip4_success && ip6_success;
 }
 
+bool IpTables::ApplyVpnSetup(const std::vector<std::string>& usernames,
+                             const std::string& interface,
+                             bool add) {
+  // When adding, any failure undoes what was applied so far and returns false.
+  // When removing, every step is attempted; false is returned if any failed.
+  bool success = true;
+  std::vector<std::string> added_usernames;
+
+  if (!ApplyRuleForUserTraffic(add)) {
+    if (add) {
+      ApplyRuleForUserTraffic(false /* remove */);
+      return false;
+    }
+    success = false;
+  }
+
+  if (!ApplyMasquerade(interface, add)) {
+    if (add) {
+      ApplyVpnSetup(added_usernames, interface, false /* remove */);
+      return false;
+    }
+    success = false;
+  }
+
+  for (const auto& username : usernames) {
+    if (!ApplyMarkForUserTraffic(username, add)) {
+      if (add) {
+        // Only usernames whose marks were fully added are rolled back here;
+        // ApplyMarkForUserTraffic() cleans up its own partial state.
+        ApplyVpnSetup(added_usernames, interface, false /* remove */);
+        return false;
+      }
+      success = false;
+    }
+    if (add) {
+      added_usernames.push_back(username);
+    }
+  }
+
+  return success;
+}
+
+bool IpTables::ApplyMasquerade(const std::string& interface, bool add) {
+  const IpTablesCallback apply_masquerade =
+      base::Bind(&IpTables::ApplyMasqueradeWithExecutable,
+                 base::Unretained(this),
+                 interface);
+
+  return RunForAllArguments(
+      apply_masquerade, {kIpTablesPath, kIp6TablesPath}, add);
+}
+
+bool IpTables::ApplyMarkForUserTraffic(const std::string& username, bool add) {
+  const IpTablesCallback apply_mark =
+      base::Bind(&IpTables::ApplyMarkForUserTrafficWithExecutable,
+                 base::Unretained(this),
+                 username);
+  const std::vector<std::string> executables{kIpTablesPath, kIp6TablesPath};
+
+  if (RunForAllArguments(apply_mark, executables, add))
+    return true;
+
+  // When adding, RunForAllArguments() stops at the first failure and may
+  // leave the mark installed for an executable that already succeeded.
+  // ApplyVpnSetup() only rolls back usernames that were fully added, so undo
+  // any partial state here to avoid leaking a mark rule.
+  if (add)
+    RunForAllArguments(apply_mark, executables, false /* remove */);
+  return false;
+}
+
+bool IpTables::ApplyRuleForUserTraffic(bool add) {
+  const IpTablesCallback apply_rule = base::Bind(
+      &IpTables::ApplyRuleForUserTrafficWithVersion, base::Unretained(this));
+
+  return RunForAllArguments(apply_rule, {kIPv4, kIPv6}, add);
+}
+
 bool IpTables::AddAcceptRule(const std::string& executable_path,
                              ProtocolEnum protocol,
                              uint16_t port,
@@ -281,98 +372,9 @@
   return ExecvNonRoot(argv, kIpTablesCapMask) == 0;
 }
 
-bool IpTables::ApplyMasquerade46(const std::string& interface, bool add) {
-  bool return_value = true;
-
-  if (!ApplyMasquerade(kIpTablesPath, interface, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " masquerade failed for interface " << interface
-               << " using '" << kIpTablesPath << "'";
-    return_value = false;
-    if (add)
-      return false;
-  }
-  if (!ApplyMasquerade(kIp6TablesPath, interface, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " masquerade failed for interface " << interface
-               << " using '" << kIp6TablesPath << "'";
-    return_value = false;
-  }
-  return return_value;
-}
-
-bool IpTables::ApplyMarkForUserTraffic46(const std::string& username,
-                                         bool add) {
-  bool return_value = true;
-
-  if (!ApplyMarkForUserTraffic(kIpTablesPath, username, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " mark failed for user " << username
-               << " using '" << kIpTablesPath << "'";
-    return_value = false;
-    if (add)
-      return false;
-  }
-  if (!ApplyMarkForUserTraffic(kIp6TablesPath, username, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " mark failed for user " << username
-               << " using '" << kIp6TablesPath << "'";
-    return_value = false;
-  }
-  return return_value;
-}
-
-bool IpTables::ApplyVpnSetup(const std::vector<std::string>& usernames,
-                             const std::string& interface,
-                             bool add) {
-  bool return_value = true;
-  std::vector<std::string> added_usernames;
-
-  if (!ApplyRuleForUserTraffic(kIPv4, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " rule for IPv4 user traffic failed.";
-    if (add)
-      return false;
-    return_value = false;
-  }
-
-  if (!ApplyRuleForUserTraffic(kIPv6, add)) {
-    LOG(ERROR) << (add ? "Adding" : "Removing")
-               << " rule for IPv6 user traffic failed.";
-    if (add) {
-      ApplyVpnSetup(added_usernames, interface, false);
-      return false;
-    }
-    return_value = false;
-  }
-
-  if (!ApplyMasquerade46(interface, add)) {
-    if (add) {
-      ApplyVpnSetup(added_usernames, interface, false);
-      return false;
-    }
-    return_value = false;
-  }
-
-  for (const auto& username : usernames) {
-    if (!ApplyMarkForUserTraffic46(username, add)) {
-      if (add) {
-        ApplyVpnSetup(added_usernames, interface, false);
-        return false;
-      }
-      return_value = false;
-    }
-    if (add) {
-      added_usernames.push_back(username);
-    }
-  }
-
-  return return_value;
-}
-
-bool IpTables::ApplyMasquerade(const std::string& executable_path,
-                               const std::string& interface,
-                               bool add) {
+bool IpTables::ApplyMasqueradeWithExecutable(const std::string& interface,
+                                             const std::string& executable_path,
+                                             bool add) {
   std::vector<std::string> argv;
   argv.push_back(executable_path);
   argv.push_back("-t");  // table
@@ -385,12 +387,18 @@
   argv.push_back("MASQUERADE");
 
   // Use CAP_NET_ADMIN|CAP_NET_RAW.
-  return ExecvNonRoot(argv, kIpTablesCapMask) == 0;
+  bool success = ExecvNonRoot(argv, kIpTablesCapMask) == 0;
+  // Report the executable actually used so IPv4/IPv6 failures are distinct.
+  if (!success) {
+    LOG(ERROR) << (add ? "Adding" : "Removing")
+               << " masquerade failed for interface " << interface
+               << " using '" << executable_path << "'";
+  }
+  return success;
 }
 
-bool IpTables::ApplyMarkForUserTraffic(const std::string& executable_path,
-                                       const std::string& user_name,
-                                       bool add) {
+bool IpTables::ApplyMarkForUserTrafficWithExecutable(
+    const std::string& username, const std::string& executable_path, bool add) {
   std::vector<std::string> argv;
   argv.push_back(executable_path);
   argv.push_back("-t");  // table
@@ -400,17 +408,25 @@
   argv.push_back("-m");
   argv.push_back("owner");
   argv.push_back("--uid-owner");
-  argv.push_back(user_name);
+  argv.push_back(username);
   argv.push_back("-j");
   argv.push_back("MARK");
   argv.push_back("--set-mark");
   argv.push_back(kMarkForUserTraffic);
 
   // Use CAP_NET_ADMIN|CAP_NET_RAW.
-  return ExecvNonRoot(argv, kIpTablesCapMask) == 0;
+  bool success = ExecvNonRoot(argv, kIpTablesCapMask) == 0;
+  // Report the executable actually used so IPv4/IPv6 failures are distinct.
+  if (!success) {
+    LOG(ERROR) << (add ? "Adding" : "Removing")
+               << " mark failed for user " << username
+               << " using '" << executable_path << "'";
+  }
+  return success;
 }
 
-bool IpTables::ApplyRuleForUserTraffic(IPVersionEnum ip_version, bool add) {
+bool IpTables::ApplyRuleForUserTrafficWithVersion(const std::string& ip_version,
+                                                  bool add) {
   brillo::ProcessImpl ip;
   ip.AddArg(kIpPath);
   if (ip_version == kIPv6)
@@ -422,7 +438,13 @@
   ip.AddArg("table");
   ip.AddArg(kTableIdForUserTraffic);
 
-  return ip.Run() == 0;
+  bool success = ip.Run() == 0;
+  // Include the IP version label in the log to tell IPv4/IPv6 failures apart.
+  if (!success) {
+    LOG(ERROR) << (add ? "Adding" : "Removing") << " rule for " << ip_version
+               << " user traffic failed";
+  }
+  return success;
 }
 
 int IpTables::ExecvNonRoot(const std::vector<std::string>& argv,
diff --git a/iptables.h b/iptables.h
index fcd7571..74b9acb 100644
--- a/iptables.h
+++ b/iptables.h
@@ -30,7 +30,6 @@
 namespace firewalld {
 
 enum ProtocolEnum { kProtocolTcp, kProtocolUdp };
-enum IPVersionEnum { kIPv4, kIPv6 };
 
 class IpTables : public org::chromium::FirewalldInterface {
  public:
@@ -91,21 +90,21 @@
                      const std::string& interface,
                      bool add);
 
-  virtual bool ApplyMasquerade(const std::string& executable_path,
-                               const std::string& interface,
-                               bool add);
-  virtual bool ApplyMasquerade46(const std::string& interface,
-                                 bool add);
-  virtual bool ApplyMarkForUserTraffic(const std::string& executable_path,
-                                       const std::string& user_name,
-                                       bool add);
-  virtual bool ApplyMarkForUserTraffic46(const std::string& username,
-                                         bool add);
-  virtual bool ApplyRuleForUserTraffic(IPVersionEnum ip_version,
-                                       bool add);
+  virtual bool ApplyMasquerade(const std::string& interface, bool add);
+  bool ApplyMasqueradeWithExecutable(const std::string& interface,
+                                     const std::string& executable_path,
+                                     bool add);
 
-  int ExecvNonRoot(const std::vector<std::string>& argv,
-                   uint64_t capmask);
+  virtual bool ApplyMarkForUserTraffic(const std::string& username, bool add);
+  bool ApplyMarkForUserTrafficWithExecutable(const std::string& username,
+                                             const std::string& executable_path,
+                                             bool add);
+
+  virtual bool ApplyRuleForUserTraffic(bool add);  // Both IPv4 and IPv6.
+  bool ApplyRuleForUserTrafficWithVersion(const std::string& ip_version,
+                                          bool add);
+
+  int ExecvNonRoot(const std::vector<std::string>& argv, uint64_t capmask);
 
   // Keep track of firewall holes to avoid adding redundant firewall rules.
   std::set<Hole> tcp_holes_;
diff --git a/iptables_unittest.cc b/iptables_unittest.cc
index 7bd59fe..615b4f8 100644
--- a/iptables_unittest.cc
+++ b/iptables_unittest.cc
@@ -193,28 +193,15 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, add))
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, add))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, add))
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[0], add))
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(usernames[0], add))
       .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIp6TablesPath, usernames[0], add))
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(usernames[1], add))
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[1], add))
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIp6TablesPath, usernames[1], add))
-      .WillOnce(Return(true));
-
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add))
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, add))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add))  // Both IP versions.
       .WillOnce(Return(true));
 
   ASSERT_TRUE(
@@ -228,56 +215,36 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, add))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, add))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, add))
       .Times(1)
       .WillOnce(Return(true));
 
   EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[0], add))
+              ApplyMarkForUserTraffic(usernames[0], add))
       .Times(1)
       .WillOnce(Return(true));
   EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIp6TablesPath, usernames[0], add))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[1], add))
+              ApplyMarkForUserTraffic(usernames[1], add))
       .Times(1)
       .WillOnce(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, add))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add))
       .Times(1)
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, remove))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, remove))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, remove))
       .Times(1)
       .WillOnce(Return(true));
 
   EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[0], remove))
+              ApplyMarkForUserTraffic(usernames[0], remove))
       .Times(1)
       .WillOnce(Return(false));
   EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIp6TablesPath, usernames[0], remove))
-      .Times(1)
-      .WillOnce(Return(false));
-  EXPECT_CALL(mock_iptables,
-              ApplyMarkForUserTraffic(kIpTablesPath, usernames[1], remove))
+              ApplyMarkForUserTraffic(usernames[1], remove))
               .Times(0);
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, remove))
-      .Times(1)
-      .WillOnce(Return(false));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, remove))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(remove))
       .Times(1)
       .WillOnce(Return(false));
 
@@ -292,30 +259,21 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, add))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, add))
       .Times(1)
       .WillOnce(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, _)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _)).Times(0);
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, add))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add))
       .Times(1)
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, remove))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, remove))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, remove))
       .Times(1)
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, remove))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, remove))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(remove))
       .Times(1)
       .WillOnce(Return(true));
 
@@ -330,17 +288,15 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, _))
-      .Times(0);
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, _)).Times(0);
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, _)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add))
       .Times(1)
       .WillOnce(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, remove)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(remove)).Times(1);
 
-  ASSERT_FALSE(
-      mock_iptables.ApplyVpnSetup(usernames, interface, add));
+  ASSERT_FALSE(mock_iptables.ApplyVpnSetup(usernames, interface, add));
 }
 
 TEST_F(IpTablesTest, ApplyVpnSetupRemove_Success) {
@@ -350,33 +306,21 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, remove))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, remove))
       .Times(1)
       .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, remove))
-      .Times(1)
-      .WillOnce(Return(true));
-
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, remove))
-      .Times(4)
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, remove))
+      .Times(2)
       .WillRepeatedly(Return(true));
-
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, remove))
-      .Times(1)
-      .WillOnce(Return(true));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, remove))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(remove))
       .Times(1)
       .WillOnce(Return(true));
 
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, add))
-      .Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, add)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, add)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add)).Times(0);
 
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, add))
-      .Times(0);
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add)).Times(0);
-
-  ASSERT_TRUE(
-      mock_iptables.ApplyVpnSetup(usernames, interface, remove));
+  ASSERT_TRUE(mock_iptables.ApplyVpnSetup(usernames, interface, remove));
 }
 
 TEST_F(IpTablesTest, ApplyVpnSetupRemove_Failure) {
@@ -386,33 +330,24 @@
   const bool add = true;
 
   MockIpTables mock_iptables;
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, remove))
-      .Times(1)
-      .WillRepeatedly(Return(false));
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIp6TablesPath, interface, remove))
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, remove))
       .Times(1)
       .WillRepeatedly(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, remove))
-      .Times(4)
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, remove))
+      .Times(2)
       .WillRepeatedly(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, remove))
-      .Times(1)
-      .WillRepeatedly(Return(false));
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv6, remove))
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(remove))
       .Times(1)
       .WillRepeatedly(Return(false));
 
-  EXPECT_CALL(mock_iptables, ApplyMasquerade(kIpTablesPath, interface, add))
-      .Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMasquerade(interface, add)).Times(0);
 
-  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, _, add))
-      .Times(0);
-  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(kIPv4, add)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyMarkForUserTraffic(_, add)).Times(0);
+  EXPECT_CALL(mock_iptables, ApplyRuleForUserTraffic(add)).Times(0);
 
-  ASSERT_FALSE(
-      mock_iptables.ApplyVpnSetup(usernames, interface, remove));
+  ASSERT_FALSE(mock_iptables.ApplyVpnSetup(usernames, interface, remove));
 }
 
 }  // namespace firewalld
diff --git a/mock_iptables.h b/mock_iptables.h
index cb14801..54aaa25 100644
--- a/mock_iptables.h
+++ b/mock_iptables.h
@@ -37,15 +37,9 @@
       DeleteAcceptRule,
       bool(const std::string&, ProtocolEnum, uint16_t, const std::string&));
 
-  MOCK_METHOD3(ApplyMasquerade, bool(const std::string&,
-                                     const std::string&,
-                                     bool));
-
-  MOCK_METHOD3(ApplyMarkForUserTraffic, bool(const std::string&,
-                                             const std::string&,
-                                             bool));
-
-  MOCK_METHOD2(ApplyRuleForUserTraffic, bool(IPVersionEnum, bool));
+  MOCK_METHOD2(ApplyMasquerade, bool(const std::string&, bool));
+  MOCK_METHOD2(ApplyMarkForUserTraffic, bool(const std::string&, bool));
+  MOCK_METHOD1(ApplyRuleForUserTraffic, bool(bool));  // Covers IPv4 and IPv6.
 
  private:
   DISALLOW_COPY_AND_ASSIGN(MockIpTables);