// blob: 7a7162ff4ccf96041bbd48bc154396ca9ee4efe7 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "discovery/public/dns_sd_service_watcher.h"
#include <algorithm>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using testing::_;
using testing::ContainerEq;
using testing::IsSubsetOf;
using testing::IsSupersetOf;
using testing::StrictMock;
namespace openscreen {
namespace discovery {
namespace {
// Copies every string referred to by |value| into a vector that owns its
// elements, preserving order. An empty |value| yields an empty vector.
std::vector<std::string> ConvertRefs(
    const std::vector<std::reference_wrapper<const std::string>>& value) {
  std::vector<std::string> strings;
  // The final size is known up front, so allocate once instead of regrowing.
  strings.reserve(value.size());
  // This loop is required to unwrap reference_wrapper objects.
  for (const std::string& val : value) {
    strings.push_back(val);
  }
  return strings;
}
// Service and domain identifiers used for every query and record below.
constexpr char kCastServiceId[] = "_googlecast._tcp";
constexpr char kCastDomainId[] = "local";

// Network interface index attached to every test record.
constexpr NetworkInterfaceIndex kNetworkInterface = 0;

// Endpoint shared by every test record: 192.168.0.0 paired with 0
// (presumably the port -- confirm against IPEndpoint).
const IPAddress kAddressV4(192, 168, 0, 0);
const IPEndpoint kEndpointV4{kAddressV4, 0};
class MockDnsSdService : public DnsSdService {
public:
MockDnsSdService() : querier_(this) {}
DnsSdQuerier* GetQuerier() override { return &querier_; }
DnsSdPublisher* GetPublisher() override { return nullptr; }
MOCK_METHOD2(StartQuery,
void(const std::string& service, DnsSdQuerier::Callback* cb));
MOCK_METHOD2(StopQuery,
void(const std::string& service, DnsSdQuerier::Callback* cb));
MOCK_METHOD1(ReinitializeQueries, void(const std::string& service));
private:
class MockQuerier : public DnsSdQuerier {
public:
explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) {
OSP_DCHECK(service);
}
void StartQuery(const std::string& service,
DnsSdQuerier::Callback* cb) override {
mock_service_->StartQuery(service, cb);
}
void StopQuery(const std::string& service,
DnsSdQuerier::Callback* cb) override {
mock_service_->StopQuery(service, cb);
}
void ReinitializeQueries(const std::string& service) override {
mock_service_->ReinitializeQueries(service);
}
private:
MockDnsSdService* const mock_service_;
};
MockQuerier querier_;
};
} // namespace
// DnsSdServiceWatcher<std::string> wired up for tests: each DnsSdInstance is
// converted to its instance id, and every update notification is forwarded to
// the mockable Callback() carrying only the list of services.
class TestServiceWatcher : public DnsSdServiceWatcher<std::string> {
 public:
  using DnsSdServiceWatcher<std::string>::ConstRefT;

  // Made available on this class's public interface so that tests can feed
  // endpoint events to the watcher directly.
  using DnsSdServiceWatcher<std::string>::OnEndpointCreated;
  using DnsSdServiceWatcher<std::string>::OnEndpointDeleted;
  using DnsSdServiceWatcher<std::string>::OnEndpointUpdated;

  explicit TestServiceWatcher(MockDnsSdService* service)
      : DnsSdServiceWatcher<std::string>(
            service, kCastServiceId,
            // Conversion: a service is represented by its instance id.
            [](const DnsSdInstance& instance) -> std::string {
              return instance.instance_id();
            },
            // Notification: only the service list is passed on; the other two
            // arguments are ignored.
            [this](std::vector<ConstRefT> ref, ConstRefT /*service*/,
                   ServicesUpdatedState /*state*/) {
              Callback(std::move(ref));
            }) {}

  MOCK_METHOD1(Callback, void(std::vector<ConstRefT>));
};
class DnsSdServiceWatcherTests : public testing::Test {
public:
DnsSdServiceWatcherTests() : watcher_(&service_) {
// Start service discovery, since all other tests need it
EXPECT_FALSE(watcher_.is_running());
EXPECT_CALL(service_, StartQuery(kCastServiceId, _));
watcher_.StartDiscovery();
testing::Mock::VerifyAndClearExpectations(&service_);
}
protected:
void CreateNewInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
std::vector<std::string> callbacked_services;
EXPECT_CALL(watcher_, Callback(_))
.WillOnce([services = &callbacked_services](
std::vector<TestServiceWatcher::ConstRefT> value) {
*services = ConvertRefs(value);
});
watcher_.OnEndpointCreated(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count + 1);
EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
EXPECT_THAT(fetched_services, IsSupersetOf(services_before));
}
void CreateExistingInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
std::vector<std::string> callbacked_services;
EXPECT_CALL(watcher_, Callback(_))
.WillOnce([services = &callbacked_services](
std::vector<TestServiceWatcher::ConstRefT> value) {
*services = ConvertRefs(value);
});
watcher_.OnEndpointCreated(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
const std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count);
EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
EXPECT_THAT(fetched_services, ContainerEq(services_before));
}
void UpdateExistingInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
std::vector<std::string> callbacked_services;
EXPECT_CALL(watcher_, Callback(_))
.WillOnce([services = &callbacked_services](
std::vector<TestServiceWatcher::ConstRefT> value) {
*services = ConvertRefs(value);
});
watcher_.OnEndpointUpdated(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
const std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count);
EXPECT_THAT(fetched_services, ContainerEq(callbacked_services));
EXPECT_THAT(fetched_services, ContainerEq(services_before));
}
void DeleteExistingInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
std::vector<std::string> callbacked_services;
EXPECT_CALL(watcher_, Callback(_))
.WillOnce([services = &callbacked_services](
std::vector<TestServiceWatcher::ConstRefT> value) {
*services = ConvertRefs(value);
});
watcher_.OnEndpointDeleted(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
const std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count - 1);
}
void UpdateNonExistingInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
EXPECT_CALL(watcher_, Callback(_)).Times(0);
watcher_.OnEndpointUpdated(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
const std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count);
EXPECT_THAT(services_before, ContainerEq(fetched_services));
}
void DeleteNonExistingInstance(const DnsSdInstanceEndpoint& record) {
const std::vector<std::string> services_before =
ConvertRefs(watcher_.GetServices());
const size_t count = services_before.size();
EXPECT_CALL(watcher_, Callback(_)).Times(0);
watcher_.OnEndpointDeleted(record);
testing::Mock::VerifyAndClearExpectations(&watcher_);
const std::vector<std::string> fetched_services =
ConvertRefs(watcher_.GetServices());
EXPECT_EQ(fetched_services.size(), count);
EXPECT_THAT(services_before, ContainerEq(fetched_services));
}
bool ContainsService(const DnsSdInstanceEndpoint& record) {
const std::string& service = record.instance_id();
const std::vector<TestServiceWatcher::ConstRefT> services =
watcher_.GetServices();
return std::find_if(services.begin(), services.end(),
[&service](const std::string& ref) {
return service == ref;
}) != services.end();
}
StrictMock<MockDnsSdService> service_;
StrictMock<TestServiceWatcher> watcher_;
std::vector<std::string> fetched_services;
};
// The fixture constructor has already started discovery. Stopping it must
// issue StopQuery with kCastServiceId on the service and leave the watcher
// reporting that it is no longer running.
TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) {
EXPECT_TRUE(watcher_.is_running());
EXPECT_CALL(service_, StopQuery(kCastServiceId, _));
watcher_.StopDiscovery();
EXPECT_FALSE(watcher_.is_running());
}
// Uses a plain TEST (not the fixture) so that StartDiscovery() is never
// called: on such a watcher both refresh entry points must report failure.
// The strict mocks additionally fail the test if either call reaches the
// service or the watcher callback.
TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) {
StrictMock<MockDnsSdService> service;
StrictMock<TestServiceWatcher> watcher(&service);
EXPECT_FALSE(watcher.DiscoverNow().ok());
EXPECT_FALSE(watcher.ForceRefresh().ok());
}
// With one known service, checks the two refresh flavours: DiscoverNow()
// reinitializes the queries and keeps the service list, while ForceRefresh()
// reinitializes the queries and leaves the service list empty.
TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) {
const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId,
DnsSdTxtRecord{}, kNetworkInterface,
kEndpointV4);
// Seed the watcher with a single service.
CreateNewInstance(record);
// Refresh services.
EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
EXPECT_TRUE(watcher_.DiscoverNow().ok());
// DiscoverNow() must not drop what is already known.
EXPECT_EQ(watcher_.GetServices().size(), size_t{1});
testing::Mock::VerifyAndClearExpectations(&service_);
EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId));
EXPECT_TRUE(watcher_.ForceRefresh().ok());
// ForceRefresh() is expected to leave no services behind.
EXPECT_EQ(watcher_.GetServices().size(), size_t{0});
testing::Mock::VerifyAndClearExpectations(&service_);
}
// Walks two instances through create / update / delete events, including
// events for instances the watcher does not know, and checks after every step
// which of the two instance ids the watcher reports. The steps build on each
// other, so their order matters.
TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) {
const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId,
DnsSdTxtRecord{}, kNetworkInterface,
kEndpointV4);
const DnsSdInstanceEndpoint record2("Instance2", kCastServiceId,
kCastDomainId, DnsSdTxtRecord{},
kNetworkInterface, kEndpointV4);
// Initially neither instance is known.
EXPECT_FALSE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
// Creating |record| makes it (and only it) visible.
CreateNewInstance(record);
EXPECT_TRUE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
// Creating |record| again does not change the set of services.
CreateExistingInstance(record);
EXPECT_TRUE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
// Update and delete events for the still-unknown |record2| change nothing.
UpdateNonExistingInstance(record2);
EXPECT_TRUE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
DeleteNonExistingInstance(record2);
EXPECT_TRUE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
// Creating |record2| makes both instances visible.
CreateNewInstance(record2);
EXPECT_TRUE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
// Updating either known instance keeps both visible.
UpdateExistingInstance(record2);
EXPECT_TRUE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
UpdateExistingInstance(record);
EXPECT_TRUE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
// Deleting |record| removes only that instance.
DeleteExistingInstance(record);
EXPECT_FALSE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
// Once deleted, |record| is treated as unknown again.
UpdateNonExistingInstance(record);
EXPECT_FALSE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
DeleteNonExistingInstance(record);
EXPECT_FALSE(ContainsService(record));
EXPECT_TRUE(ContainsService(record2));
// Deleting |record2| leaves the watcher without either instance.
DeleteExistingInstance(record2);
EXPECT_FALSE(ContainsService(record));
EXPECT_FALSE(ContainsService(record2));
}
} // namespace discovery
} // namespace openscreen