blob: ff99b987240281d0cdddb4b10a0aed39bb70ab7f [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <linux/netfilter/nfnetlink_log.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "NetlinkManager.h"
#include "WakeupController.h"
using ::testing::StrictMock;
using ::testing::Test;
using ::testing::DoAll;
using ::testing::SaveArg;
using ::testing::Return;
using ::testing::_;
namespace android {
namespace net {
using netdutils::status::ok;
// Copy range the fixture expects WakeupController::init() to pass to
// subscribe() for the NFLOG wakeup group.
const uint32_t kDefaultPacketCopyRange = WakeupController::kDefaultPacketCopyRange;
// gmock stand-in for the wakeup-event sink. The fixture's report callback
// unpacks every WakeupController::ReportArgs into onWakeupEvent(), so tests can
// set expectations on each decoded field individually.
struct MockNetdEventListener {
    MOCK_METHOD10(onWakeupEvent,
                  void(const std::string& prefix, int uid, int ethertype, int ipNextHeader,
                       const std::vector<uint8_t>& dstHw, const std::string& srcIp,
                       const std::string& dstIp, int srcPort, int dstPort,
                       uint64_t timestampNs));
};
// gmock implementation of IptablesRestoreInterface. Tests use it to assert on
// the exact iptables-restore script that WakeupController hands to execute().
struct MockIptablesRestore : public IptablesRestoreInterface {
    ~MockIptablesRestore() override = default;

    MOCK_METHOD3(execute,
                 int(const IptablesTarget target, const std::string& script,
                     std::string* out));
};
// gmock implementation of NFLogListenerInterface. Both subscribe() overloads
// and unsubscribe() are mocked; the fixture uses the three-argument
// subscribe() to capture the dispatch callback WakeupController registers.
struct MockNFLogListener : public NFLogListenerInterface {
    ~MockNFLogListener() override = default;

    MOCK_METHOD1(unsubscribe, netdutils::Status(uint16_t group));
    MOCK_METHOD2(subscribe, netdutils::Status(uint16_t group, const DispatchFn& handler));
    MOCK_METHOD3(subscribe, netdutils::Status(uint16_t group, uint32_t copyRange,
                                              const DispatchFn& handler));
};
// Test fixture that wires a real WakeupController to strict gmock collaborators.
//
// Member order is significant: members are constructed in declaration order and
// destroyed in reverse, so every mock is alive for the whole lifetime of
// mController (whose report callback forwards into mEventListener, which is
// handed &mIptables at construction and &mListener in init()).
class WakeupControllerTest : public Test {
  protected:
    // Registers the NFLOG expectations, then runs init():
    //  - subscribe() must be called once for the wakeup group with the default
    //    copy range; the dispatch callback it receives is stashed in
    //    mMessageHandler so tests can inject hand-built netlink messages.
    //  - unsubscribe() must be called once for the same group; presumably this
    //    happens when mController is destroyed (NOTE(review): confirm in
    //    WakeupController).
    WakeupControllerTest() {
        EXPECT_CALL(mListener, unsubscribe(NetlinkManager::NFLOG_WAKEUP_GROUP))
                .WillOnce(Return(ok));
        EXPECT_CALL(mListener,
                    subscribe(NetlinkManager::NFLOG_WAKEUP_GROUP, kDefaultPacketCopyRange, _))
                .WillOnce(DoAll(SaveArg<2>(&mMessageHandler), Return(ok)));
        EXPECT_OK(mController.init(&mListener));
    }

    StrictMock<MockNetdEventListener> mEventListener;
    StrictMock<MockIptablesRestore> mIptables;
    StrictMock<MockNFLogListener> mListener;
    // Unpacks each report into the mock's per-field onWakeupEvent() call.
    WakeupController mController{
            [this](const WakeupController::ReportArgs& report) {
                mEventListener.onWakeupEvent(report.prefix, report.uid, report.ethertype,
                                             report.ipNextHeader, report.dstHw, report.srcIp,
                                             report.dstIp, report.srcPort, report.dstPort,
                                             report.timestampNs);
            },
            &mIptables};
    // Dispatch callback captured by the subscribe() expectation in the constructor.
    NFLogListenerInterface::DispatchFn mMessageHandler;
};
// Only UID, GID, timestamp and prefix attributes are present. Every field that
// would come from packet-header, hardware-address or payload attributes is
// expected back with its "unset" value (-1 / empty).
TEST_F(WakeupControllerTest, msgHandlerWithPartialAttributes) {
    const char kPrefix[] = "test:prefix";
    const uid_t kUid = 8734;
    const gid_t kGid = 2222;
    const uint64_t kNsPerS = 1000000000ULL;
    const uint64_t kTsNs = 9999 + (34 * kNsPerS);

    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr uidAttr;
        uid_t uid;
        nlattr gidAttr;
        gid_t gid;
        nlattr tsAttr;
        timespec ts;
        nlattr prefixAttr;
        char prefix[sizeof(kPrefix)];
    } msg = {};

    // Writes an attribute header describing |payloadLen| bytes of payload.
    const auto initAttr = [](nlattr& attr, uint16_t type, uint16_t payloadLen) {
        attr.nla_type = type;
        attr.nla_len = static_cast<uint16_t>(sizeof(nlattr) + payloadLen);
    };

    initAttr(msg.prefixAttr, NFULA_PREFIX, sizeof(msg.prefix));
    memcpy(msg.prefix, kPrefix, sizeof(kPrefix));

    // Seconds and nanoseconds are each written as big-endian 32-bit values.
    initAttr(msg.tsAttr, NFULA_TIMESTAMP, sizeof(msg.ts));
    msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);

    initAttr(msg.gidAttr, NFULA_GID, sizeof(msg.gid));
    msg.gid = htonl(kGid);

    initAttr(msg.uidAttr, NFULA_UID, sizeof(msg.uid));
    msg.uid = htonl(kUid);

    // The handler receives only the attribute stream that follows the headers.
    auto attrs = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));

    EXPECT_CALL(mEventListener,
                onWakeupEvent(kPrefix, kUid, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, kTsNs));
    mMessageHandler(msg.nlmsg, msg.nfmsg, attrs);
}
// Full message: UID/GID/timestamp/prefix plus packet header, hardware address
// and an IPv4+TCP payload. Every field of the report is expected to be decoded.
TEST_F(WakeupControllerTest, msgHandler) {
    const char kPrefix[] = "test:prefix";
    const uid_t kUid = 8734;
    const gid_t kGid = 2222;
    const std::vector<uint8_t> kMacAddr = {11, 22, 33, 44, 55, 66};
    const char* kSrcIpAddr = "192.168.2.1";
    const char* kDstIpAddr = "192.168.2.23";
    const uint16_t kEthertype = 0x800;  // IPv4 ethertype.
    // Single source of truth for the L4 protocol: it is written into the IPv4
    // header below and expected back in the report, so the two cannot diverge.
    const uint8_t kIpNextHeader = IPPROTO_TCP;
    const uint16_t kSrcPort = 1238;
    const uint16_t kDstPort = 4567;
    const uint64_t kNsPerS = 1000000000ULL;
    const uint64_t kTsNs = 9999 + (34 * kNsPerS);
    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr uidAttr;
        uid_t uid;
        nlattr gidAttr;
        gid_t gid;
        nlattr tsAttr;
        timespec ts;
        nlattr prefixAttr;
        char prefix[sizeof(kPrefix)];
        nlattr packetHeaderAttr;
        struct nfulnl_msg_packet_hdr packetHeader;
        nlattr hardwareAddrAttr;
        struct nfulnl_msg_packet_hw hardwareAddr;
        nlattr packetPayloadAttr;
        struct iphdr ipHeader;
        struct tcphdr tcpHeader;
    } msg = {};
    msg.prefixAttr.nla_type = NFULA_PREFIX;
    msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
    msg.uidAttr.nla_type = NFULA_UID;
    msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
    msg.uid = htonl(kUid);
    msg.gidAttr.nla_type = NFULA_GID;
    msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
    msg.gid = htonl(kGid);
    // Seconds and nanoseconds are each written as big-endian 32-bit values.
    msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    msg.tsAttr.nla_len = sizeof(msg.tsAttr) + sizeof(msg.ts);
    msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
    msg.packetHeaderAttr.nla_type = NFULA_PACKET_HDR;
    msg.packetHeaderAttr.nla_len = sizeof(msg.packetHeaderAttr) + sizeof(msg.packetHeader);
    msg.packetHeader.hw_protocol = htons(kEthertype);
    msg.hardwareAddrAttr.nla_type = NFULA_HWADDR;
    msg.hardwareAddrAttr.nla_len = sizeof(msg.hardwareAddrAttr) + sizeof(msg.hardwareAddr);
    msg.hardwareAddr.hw_addrlen = htons(kMacAddr.size());
    std::copy(kMacAddr.begin(), kMacAddr.end(), msg.hardwareAddr.hw_addr);
    msg.packetPayloadAttr.nla_type = NFULA_PAYLOAD;
    msg.packetPayloadAttr.nla_len =
        sizeof(msg.packetPayloadAttr) + sizeof(msg.ipHeader) + sizeof(msg.tcpHeader);
    // Build a well-formed IPv4 header (version 4, no options).
    msg.ipHeader.version = 4;
    msg.ipHeader.protocol = kIpNextHeader;
    msg.ipHeader.ihl = sizeof(msg.ipHeader) / 4; // ipv4 IHL counts 32 bit words.
    inet_pton(AF_INET, kSrcIpAddr, &msg.ipHeader.saddr);
    inet_pton(AF_INET, kDstIpAddr, &msg.ipHeader.daddr);
    msg.tcpHeader.th_sport = htons(kSrcPort);
    msg.tcpHeader.th_dport = htons(kDstPort);
    // The handler receives only the attribute stream that follows the headers.
    auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    EXPECT_CALL(mEventListener, onWakeupEvent(kPrefix, kUid, kEthertype, kIpNextHeader, kMacAddr,
                                              kSrcIpAddr, kDstIpAddr, kSrcPort, kDstPort, kTsNs));
    mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
}
// Feeds the handler a message made of malformed or mislabelled attributes:
//  - uidAttr carries an unknown type (999), so its value must not be reported;
//  - invalid0 / invalid1 declare lengths (0 and 1) smaller than an nlattr header;
//  - tsAttr declares a length smaller than its own header, so no timestamp is
//    expected (0);
//  - prefixAttr is labelled NFULA_UID, so the leading prefix bytes are expected
//    to be decoded as the UID and no prefix is reported.
TEST_F(WakeupControllerTest, badAttr) {
    const char kPrefix[] = "test:prefix";
    const uid_t kUid = 8734;
    const gid_t kGid = 2222;
    const uint64_t kNsPerS = 1000000000ULL;
    const uint64_t kTsNs = 9999 + (34 * kNsPerS);
    // The first four bytes of kPrefix ("test") read as a big-endian 32-bit
    // integer: the UID the mislabelled prefix attribute should decode to.
    constexpr int kPrefixAsUid = ('t' << 24) | ('e' << 16) | ('s' << 8) | 't';
    static_assert(kPrefixAsUid == 1952805748, "\"test\" read as a big-endian uint32");
    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr uidAttr;
        uid_t uid;
        nlattr invalid0;
        nlattr invalid1;
        nlattr gidAttr;
        gid_t gid;
        nlattr tsAttr;
        timespec ts;
        nlattr prefixAttr;
        char prefix[sizeof(kPrefix)];
    } msg = {};
    msg.uidAttr.nla_type = 999;
    msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
    msg.uid = htonl(kUid);
    msg.invalid0.nla_type = 0;
    msg.invalid0.nla_len = 0;
    msg.invalid1.nla_type = 0;
    msg.invalid1.nla_len = 1;
    msg.gidAttr.nla_type = NFULA_GID;
    msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
    msg.gid = htonl(kGid);
    msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    msg.tsAttr.nla_len = sizeof(msg.tsAttr) - 2;
    msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
    msg.prefixAttr.nla_type = NFULA_UID;
    msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
    // The handler receives only the attribute stream that follows the headers.
    auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    EXPECT_CALL(mEventListener,
                onWakeupEvent("", kPrefixAsUid, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
}
// The prefix attribute's payload is a run of 0x01 bytes with no NUL terminator.
// The reported prefix is expected to hold every byte except the last one, and
// all other fields are expected to be unset.
TEST_F(WakeupControllerTest, unterminatedString) {
    char ones[20] = {};
    memset(ones, 1, sizeof(ones));
    const std::string expected(ones, sizeof(ones) - 1);

    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr prefixAttr;
        char prefix[sizeof(ones)];
    } msg = {};
    memcpy(msg.prefix, ones, sizeof(ones));
    msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    msg.prefixAttr.nla_type = NFULA_PREFIX;

    // The handler receives only the attribute stream that follows the headers.
    auto attrs = drop(netdutils::makeSlice(msg), offsetof(Msg, prefixAttr));

    EXPECT_CALL(mEventListener,
                onWakeupEvent(expected, -1, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    mMessageHandler(msg.nlmsg, msg.nfmsg, attrs);
}
// addInterface(ifName, prefix, mark, mask) must append a single NFLOG rule to the
// wakeup mangle chain. The rule matches packets arriving on |ifName| (-i) whose
// mark matches |mark|/|mask|, and tags logged packets with |prefix|. It is
// applied to both IPv4 and IPv6 (V4V6).
TEST_F(WakeupControllerTest, addInterface) {
    const char kPrefix[] = "test:prefix";
    const char kIfName[] = "wlan8";
    const uint32_t kMark = 0x12345678;
    const uint32_t kMask = 0x0F0F0F0F;
    const char kExpected[] =
        "*mangle\n-A wakeupctrl_mangle_INPUT -i wlan8"
        " -j NFLOG --nflog-prefix test:prefix --nflog-group 3 --nflog-threshold 8"
        " -m mark --mark 0x12345678/0x0f0f0f0f -m limit --limit 10/s\nCOMMIT\n";
    EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    EXPECT_OK(mController.addInterface(kIfName, kPrefix, kMark, kMask));
}
// delInterface(ifName, prefix, mark, mask) must emit the same NFLOG rule that
// addInterface() would, but with -D (delete) instead of -A (append). The
// interface is matched with -i and |prefix| is the NFLOG prefix.
TEST_F(WakeupControllerTest, delInterface) {
    const char kPrefix[] = "test:prefix";
    const char kIfName[] = "wlan8";
    const uint32_t kMark = 0x12345678;
    const uint32_t kMask = 0xF0F0F0F0;
    const char kExpected[] =
        "*mangle\n-D wakeupctrl_mangle_INPUT -i wlan8"
        " -j NFLOG --nflog-prefix test:prefix --nflog-group 3 --nflog-threshold 8"
        " -m mark --mark 0x12345678/0xf0f0f0f0 -m limit --limit 10/s\nCOMMIT\n";
    EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    EXPECT_OK(mController.delInterface(kIfName, kPrefix, kMark, kMask));
}
} // namespace net
} // namespace android