| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "discovery/public/dns_sd_service_watcher.h" |
| |
| #include <algorithm> |
| |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| |
| using testing::_; |
| using testing::ContainerEq; |
| using testing::IsSubsetOf; |
| using testing::IsSupersetOf; |
| using testing::StrictMock; |
| |
| namespace openscreen { |
| namespace discovery { |
| namespace { |
| |
// Copies every string referenced by |value| into a vector that owns its
// elements. The fixture snapshots the watcher's services before mutating it,
// so it needs owning copies rather than references into watcher storage.
std::vector<std::string> ConvertRefs(
    const std::vector<std::reference_wrapper<const std::string>>& value) {
  std::vector<std::string> owned;
  owned.reserve(value.size());
  for (const auto& ref : value) {
    // Unwrap the reference_wrapper explicitly and copy the referenced string.
    owned.emplace_back(ref.get());
  }
  return owned;
}
| |
| static const IPAddress kAddressV4(192, 168, 0, 0); |
| static const IPEndpoint kEndpointV4{kAddressV4, 0}; |
| constexpr char kCastServiceId[] = "_googlecast._tcp"; |
| constexpr char kCastDomainId[] = "local"; |
| constexpr NetworkInterfaceIndex kNetworkInterface = 0; |
| |
| class MockDnsSdService : public DnsSdService { |
| public: |
| MockDnsSdService() : querier_(this) {} |
| |
| DnsSdQuerier* GetQuerier() override { return &querier_; } |
| DnsSdPublisher* GetPublisher() override { return nullptr; } |
| |
| MOCK_METHOD2(StartQuery, |
| void(const std::string& service, DnsSdQuerier::Callback* cb)); |
| MOCK_METHOD2(StopQuery, |
| void(const std::string& service, DnsSdQuerier::Callback* cb)); |
| MOCK_METHOD1(ReinitializeQueries, void(const std::string& service)); |
| |
| private: |
| class MockQuerier : public DnsSdQuerier { |
| public: |
| explicit MockQuerier(MockDnsSdService* service) : mock_service_(service) { |
| OSP_DCHECK(service); |
| } |
| |
| void StartQuery(const std::string& service, |
| DnsSdQuerier::Callback* cb) override { |
| mock_service_->StartQuery(service, cb); |
| } |
| |
| void StopQuery(const std::string& service, |
| DnsSdQuerier::Callback* cb) override { |
| mock_service_->StopQuery(service, cb); |
| } |
| |
| void ReinitializeQueries(const std::string& service) override { |
| mock_service_->ReinitializeQueries(service); |
| } |
| |
| private: |
| MockDnsSdService* const mock_service_; |
| }; |
| |
| MockQuerier querier_; |
| }; |
| |
| } // namespace |
| |
| class TestServiceWatcher : public DnsSdServiceWatcher<std::string> { |
| public: |
| using DnsSdServiceWatcher<std::string>::ConstRefT; |
| |
| explicit TestServiceWatcher(MockDnsSdService* service) |
| : DnsSdServiceWatcher<std::string>( |
| service, |
| kCastServiceId, |
| [this](const DnsSdInstance& instance) { return Convert(instance); }, |
| [this](std::vector<ConstRefT> ref, ConstRefT service, ServicesUpdatedState state) { |
| Callback(std::move(ref)); |
| }) {} |
| |
| MOCK_METHOD1(Callback, void(std::vector<ConstRefT>)); |
| |
| using DnsSdServiceWatcher<std::string>::OnEndpointCreated; |
| using DnsSdServiceWatcher<std::string>::OnEndpointUpdated; |
| using DnsSdServiceWatcher<std::string>::OnEndpointDeleted; |
| |
| private: |
| std::string Convert(const DnsSdInstance& instance) { |
| return instance.instance_id(); |
| } |
| }; |
| |
| class DnsSdServiceWatcherTests : public testing::Test { |
| public: |
| DnsSdServiceWatcherTests() : watcher_(&service_) { |
| // Start service discovery, since all other tests need it |
| EXPECT_FALSE(watcher_.is_running()); |
| EXPECT_CALL(service_, StartQuery(kCastServiceId, _)); |
| watcher_.StartDiscovery(); |
| testing::Mock::VerifyAndClearExpectations(&service_); |
| } |
| |
| protected: |
| void CreateNewInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| std::vector<std::string> callbacked_services; |
| EXPECT_CALL(watcher_, Callback(_)) |
| .WillOnce([services = &callbacked_services]( |
| std::vector<TestServiceWatcher::ConstRefT> value) { |
| *services = ConvertRefs(value); |
| }); |
| watcher_.OnEndpointCreated(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count + 1); |
| |
| EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); |
| EXPECT_THAT(fetched_services, IsSupersetOf(services_before)); |
| } |
| |
| void CreateExistingInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| std::vector<std::string> callbacked_services; |
| EXPECT_CALL(watcher_, Callback(_)) |
| .WillOnce([services = &callbacked_services]( |
| std::vector<TestServiceWatcher::ConstRefT> value) { |
| *services = ConvertRefs(value); |
| }); |
| watcher_.OnEndpointCreated(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| const std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count); |
| |
| EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); |
| EXPECT_THAT(fetched_services, ContainerEq(services_before)); |
| } |
| |
| void UpdateExistingInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| std::vector<std::string> callbacked_services; |
| EXPECT_CALL(watcher_, Callback(_)) |
| .WillOnce([services = &callbacked_services]( |
| std::vector<TestServiceWatcher::ConstRefT> value) { |
| *services = ConvertRefs(value); |
| }); |
| watcher_.OnEndpointUpdated(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| const std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count); |
| |
| EXPECT_THAT(fetched_services, ContainerEq(callbacked_services)); |
| EXPECT_THAT(fetched_services, ContainerEq(services_before)); |
| } |
| |
| void DeleteExistingInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| std::vector<std::string> callbacked_services; |
| EXPECT_CALL(watcher_, Callback(_)) |
| .WillOnce([services = &callbacked_services]( |
| std::vector<TestServiceWatcher::ConstRefT> value) { |
| *services = ConvertRefs(value); |
| }); |
| watcher_.OnEndpointDeleted(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| const std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count - 1); |
| } |
| |
| void UpdateNonExistingInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| EXPECT_CALL(watcher_, Callback(_)).Times(0); |
| watcher_.OnEndpointUpdated(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| const std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count); |
| |
| EXPECT_THAT(services_before, ContainerEq(fetched_services)); |
| } |
| |
| void DeleteNonExistingInstance(const DnsSdInstanceEndpoint& record) { |
| const std::vector<std::string> services_before = |
| ConvertRefs(watcher_.GetServices()); |
| const size_t count = services_before.size(); |
| |
| EXPECT_CALL(watcher_, Callback(_)).Times(0); |
| watcher_.OnEndpointDeleted(record); |
| testing::Mock::VerifyAndClearExpectations(&watcher_); |
| |
| const std::vector<std::string> fetched_services = |
| ConvertRefs(watcher_.GetServices()); |
| EXPECT_EQ(fetched_services.size(), count); |
| |
| EXPECT_THAT(services_before, ContainerEq(fetched_services)); |
| } |
| |
| bool ContainsService(const DnsSdInstanceEndpoint& record) { |
| const std::string& service = record.instance_id(); |
| const std::vector<TestServiceWatcher::ConstRefT> services = |
| watcher_.GetServices(); |
| return std::find_if(services.begin(), services.end(), |
| [&service](const std::string& ref) { |
| return service == ref; |
| }) != services.end(); |
| } |
| |
| StrictMock<MockDnsSdService> service_; |
| StrictMock<TestServiceWatcher> watcher_; |
| std::vector<std::string> fetched_services; |
| }; |
| |
| TEST_F(DnsSdServiceWatcherTests, StartStopDiscoveryWorks) { |
| EXPECT_TRUE(watcher_.is_running()); |
| EXPECT_CALL(service_, StopQuery(kCastServiceId, _)); |
| watcher_.StopDiscovery(); |
| EXPECT_FALSE(watcher_.is_running()); |
| } |
| |
| TEST(DnsSdServiceWatcherTest, RefreshFailsBeforeDiscoveryStarts) { |
| StrictMock<MockDnsSdService> service; |
| StrictMock<TestServiceWatcher> watcher(&service); |
| EXPECT_FALSE(watcher.DiscoverNow().ok()); |
| EXPECT_FALSE(watcher.ForceRefresh().ok()); |
| } |
| |
| TEST_F(DnsSdServiceWatcherTests, RefreshDiscoveryWorks) { |
| const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, |
| DnsSdTxtRecord{}, kNetworkInterface, |
| kEndpointV4); |
| CreateNewInstance(record); |
| |
| // Refresh services. |
| EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); |
| EXPECT_TRUE(watcher_.DiscoverNow().ok()); |
| EXPECT_EQ(watcher_.GetServices().size(), size_t{1}); |
| testing::Mock::VerifyAndClearExpectations(&service_); |
| |
| EXPECT_CALL(service_, ReinitializeQueries(kCastServiceId)); |
| EXPECT_TRUE(watcher_.ForceRefresh().ok()); |
| EXPECT_EQ(watcher_.GetServices().size(), size_t{0}); |
| testing::Mock::VerifyAndClearExpectations(&service_); |
| } |
| |
| TEST_F(DnsSdServiceWatcherTests, CreatingUpdatingDeletingInstancesWork) { |
| const DnsSdInstanceEndpoint record("Instance", kCastServiceId, kCastDomainId, |
| DnsSdTxtRecord{}, kNetworkInterface, |
| kEndpointV4); |
| const DnsSdInstanceEndpoint record2("Instance2", kCastServiceId, |
| kCastDomainId, DnsSdTxtRecord{}, |
| kNetworkInterface, kEndpointV4); |
| |
| EXPECT_FALSE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| |
| CreateNewInstance(record); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| |
| CreateExistingInstance(record); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| |
| UpdateNonExistingInstance(record2); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| |
| DeleteNonExistingInstance(record2); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| |
| CreateNewInstance(record2); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| UpdateExistingInstance(record2); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| UpdateExistingInstance(record); |
| EXPECT_TRUE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| DeleteExistingInstance(record); |
| EXPECT_FALSE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| UpdateNonExistingInstance(record); |
| EXPECT_FALSE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| DeleteNonExistingInstance(record); |
| EXPECT_FALSE(ContainsService(record)); |
| EXPECT_TRUE(ContainsService(record2)); |
| |
| DeleteExistingInstance(record2); |
| EXPECT_FALSE(ContainsService(record)); |
| EXPECT_FALSE(ContainsService(record2)); |
| } |
| |
| } // namespace discovery |
| } // namespace openscreen |