blob: badb1ae5ab94f49b039bfb3b4acd15ec5b694abe [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/utility/local_discovery/service_discovery_message_handler.h"
#include <algorithm>
#include "base/lazy_instance.h"
#include "chrome/common/local_discovery/local_discovery_messages.h"
#include "chrome/common/local_discovery/service_discovery_client_impl.h"
#include "content/public/utility/utility_thread.h"
#include "net/socket/socket_descriptor.h"
#include "net/udp/datagram_server_socket.h"
namespace local_discovery {
namespace {
void ClosePlatformSocket(net::SocketDescriptor socket);
// Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
// keeps single socket that will be returned to the first call to
// |net::CreatePlatformSocket| during object lifetime.
class ScopedSocketFactory : public net::PlatformSocketFactory {
public:
explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) {
net::PlatformSocketFactory::SetInstance(this);
}
virtual ~ScopedSocketFactory() {
net::PlatformSocketFactory::SetInstance(NULL);
ClosePlatformSocket(socket_);
socket_ = net::kInvalidSocket;
}
virtual net::SocketDescriptor CreateSocket(int family, int type,
int protocol) OVERRIDE {
DCHECK_EQ(type, SOCK_DGRAM);
DCHECK(family == AF_INET || family == AF_INET6);
net::SocketDescriptor result = net::kInvalidSocket;
std::swap(result, socket_);
return result;
}
private:
net::SocketDescriptor socket_;
DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory);
};
struct SocketInfo {
SocketInfo(net::SocketDescriptor socket,
net::AddressFamily address_family,
uint32 interface_index)
: socket(socket),
address_family(address_family),
interface_index(interface_index) {
}
net::SocketDescriptor socket;
net::AddressFamily address_family;
uint32 interface_index;
};
// Returns list of sockets preallocated before.
class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory {
public:
PreCreatedMDnsSocketFactory() {}
virtual ~PreCreatedMDnsSocketFactory() {
// Not empty if process exits too fast, before starting mDns code. If
// happened, destructors may crash accessing destroyed global objects.
sockets_.weak_clear();
}
// net::MDnsSocketFactory implementation:
virtual void CreateSockets(
ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
sockets->swap(sockets_);
Reset();
}
void AddSocket(const SocketInfo& socket_info) {
// Takes ownership of socket_info.socket;
ScopedSocketFactory platform_factory(socket_info.socket);
scoped_ptr<net::DatagramServerSocket> socket(
net::CreateAndBindMDnsSocket(socket_info.address_family,
socket_info.interface_index));
if (socket) {
socket->DetachFromThread();
sockets_.push_back(socket.release());
}
}
void Reset() {
sockets_.clear();
}
private:
ScopedVector<net::DatagramServerSocket> sockets_;
DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory);
};
base::LazyInstance<PreCreatedMDnsSocketFactory>
g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;
#if defined(OS_WIN)
void ClosePlatformSocket(net::SocketDescriptor socket) {
::closesocket(socket);
}
void StaticInitializeSocketFactory() {
net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
for (size_t i = 0; i < interfaces.size(); ++i) {
DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
net::SocketDescriptor descriptor =
net::CreatePlatformSocket(
net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
IPPROTO_UDP);
g_local_discovery_socket_factory.Get().AddSocket(
SocketInfo(descriptor, interfaces[i].second, interfaces[i].first));
}
}
#else // OS_WIN
void ClosePlatformSocket(net::SocketDescriptor socket) {
::close(socket);
}
void StaticInitializeSocketFactory() {
}
#endif // OS_WIN
void SendHostMessageOnUtilityThread(IPC::Message* msg) {
content::UtilityThread::Get()->Send(msg);
}
std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) {
switch (update) {
case ServiceWatcher::UPDATE_ADDED:
return "UPDATE_ADDED";
case ServiceWatcher::UPDATE_CHANGED:
return "UPDATE_CHANGED";
case ServiceWatcher::UPDATE_REMOVED:
return "UPDATE_REMOVED";
case ServiceWatcher::UPDATE_INVALIDATED:
return "UPDATE_INVALIDATED";
}
return "Unknown Update";
}
std::string ResolverStatusToString(ServiceResolver::RequestStatus status) {
switch (status) {
case ServiceResolver::STATUS_SUCCESS:
return "STATUS_SUCESS";
case ServiceResolver::STATUS_REQUEST_TIMEOUT:
return "STATUS_REQUEST_TIMEOUT";
case ServiceResolver::STATUS_KNOWN_NONEXISTENT:
return "STATUS_KNOWN_NONEXISTENT";
}
return "Unknown Status";
}
} // namespace
ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
}
ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
DCHECK(!discovery_thread_);
}
void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
StaticInitializeSocketFactory();
}
void ServiceDiscoveryMessageHandler::InitializeMdns() {
if (service_discovery_client_ || mdns_client_)
return;
mdns_client_ = net::MDnsClient::CreateDefault();
bool result =
mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer());
// Close unused sockets.
g_local_discovery_socket_factory.Get().Reset();
if (!result) {
VLOG(1) << "Failed to start MDnsClient";
Send(new LocalDiscoveryHostMsg_Error());
return;
}
service_discovery_client_.reset(
new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
}
bool ServiceDiscoveryMessageHandler::InitializeThread() {
if (discovery_task_runner_)
return true;
if (discovery_thread_)
return false;
utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
if (discovery_thread_->StartWithOptions(thread_options)) {
discovery_task_runner_ = discovery_thread_->message_loop_proxy();
discovery_task_runner_->PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
base::Unretained(this)));
}
return discovery_task_runner_ != NULL;
}
bool ServiceDiscoveryMessageHandler::OnMessageReceived(
const IPC::Message& message) {
bool handled = true;
IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
#if defined(OS_POSIX)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets)
#endif // OS_POSIX
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices,
OnSetActivelyRefreshServices)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
OnResolveLocalDomain)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
OnDestroyLocalDomainResolver)
IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
ShutdownLocalDiscovery)
IPC_MESSAGE_UNHANDLED(handled = false)
IPC_END_MESSAGE_MAP()
return handled;
}
void ServiceDiscoveryMessageHandler::PostTask(
const tracked_objects::Location& from_here,
const base::Closure& task) {
if (!InitializeThread())
return;
discovery_task_runner_->PostTask(from_here, task);
}
#if defined(OS_POSIX)
void ServiceDiscoveryMessageHandler::OnSetSockets(
const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
for (size_t i = 0; i < sockets.size(); ++i) {
g_local_discovery_socket_factory.Get().AddSocket(
SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
sockets[i].interface_index));
}
}
#endif // OS_POSIX
void ServiceDiscoveryMessageHandler::OnStartWatcher(
uint64 id,
const std::string& service_type) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
base::Unretained(this), id, service_type));
}
void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
bool force_update) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
base::Unretained(this), id, force_update));
}
void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
uint64 id, bool actively_refresh_services) {
PostTask(FROM_HERE,
base::Bind(
&ServiceDiscoveryMessageHandler::SetActivelyRefreshServices,
base::Unretained(this), id, actively_refresh_services));
}
void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
base::Unretained(this), id));
}
void ServiceDiscoveryMessageHandler::OnResolveService(
uint64 id,
const std::string& service_name) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
base::Unretained(this), id, service_name));
}
void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
base::Unretained(this), id));
}
void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
uint64 id, const std::string& domain,
net::AddressFamily address_family) {
PostTask(FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
base::Unretained(this), id, domain, address_family));
}
void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
PostTask(FROM_HERE,
base::Bind(
&ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
base::Unretained(this), id));
}
void ServiceDiscoveryMessageHandler::StartWatcher(
uint64 id,
const std::string& service_type) {
VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type;
if (!service_discovery_client_)
return;
DCHECK(!ContainsKey(service_watchers_, id));
scoped_ptr<ServiceWatcher> watcher(
service_discovery_client_->CreateServiceWatcher(
service_type,
base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
base::Unretained(this), id)));
watcher->Start();
service_watchers_[id].reset(watcher.release());
}
void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
bool force_update) {
VLOG(1) << "DiscoverServices, id=" << id;
if (!service_discovery_client_)
return;
DCHECK(ContainsKey(service_watchers_, id));
service_watchers_[id]->DiscoverNewServices(force_update);
}
void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
uint64 id,
bool actively_refresh_services) {
VLOG(1) << "ActivelyRefreshServices, id=" << id;
if (!service_discovery_client_)
return;
DCHECK(ContainsKey(service_watchers_, id));
service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services);
}
void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
VLOG(1) << "DestoryWatcher, id=" << id;
if (!service_discovery_client_)
return;
service_watchers_.erase(id);
}
void ServiceDiscoveryMessageHandler::ResolveService(
uint64 id,
const std::string& service_name) {
VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name;
if (!service_discovery_client_)
return;
DCHECK(!ContainsKey(service_resolvers_, id));
scoped_ptr<ServiceResolver> resolver(
service_discovery_client_->CreateServiceResolver(
service_name,
base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
base::Unretained(this), id)));
resolver->StartResolving();
service_resolvers_[id].reset(resolver.release());
}
void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
VLOG(1) << "DestroyResolver, id=" << id;
if (!service_discovery_client_)
return;
service_resolvers_.erase(id);
}
void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
uint64 id,
const std::string& domain,
net::AddressFamily address_family) {
VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain;
if (!service_discovery_client_)
return;
DCHECK(!ContainsKey(local_domain_resolvers_, id));
scoped_ptr<LocalDomainResolver> resolver(
service_discovery_client_->CreateLocalDomainResolver(
domain, address_family,
base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
base::Unretained(this), id)));
resolver->Start();
local_domain_resolvers_[id].reset(resolver.release());
}
void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
VLOG(1) << "DestroyLocalDomainResolver, id=" << id;
if (!service_discovery_client_)
return;
local_domain_resolvers_.erase(id);
}
void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
if (!discovery_task_runner_)
return;
discovery_task_runner_->PostTask(
FROM_HERE,
base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
base::Unretained(this)));
// This will wait for message loop to drain, so ShutdownOnIOThread will
// definitely be called.
discovery_thread_.reset();
}
void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
VLOG(1) << "ShutdownLocalDiscovery";
service_watchers_.clear();
service_resolvers_.clear();
local_domain_resolvers_.clear();
service_discovery_client_.reset();
mdns_client_.reset();
}
void ServiceDiscoveryMessageHandler::OnServiceUpdated(
uint64 id,
ServiceWatcher::UpdateType update,
const std::string& name) {
VLOG(1) << "OnServiceUpdated, id=" << id
<< ", status=" << WatcherUpdateToString(update) << ", name=" << name;
DCHECK(service_discovery_client_);
Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name));
}
void ServiceDiscoveryMessageHandler::OnServiceResolved(
uint64 id,
ServiceResolver::RequestStatus status,
const ServiceDescription& description) {
VLOG(1) << "OnServiceResolved, id=" << id
<< ", status=" << ResolverStatusToString(status)
<< ", name=" << description.service_name;
DCHECK(service_discovery_client_);
Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description));
}
void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
uint64 id,
bool success,
const net::IPAddressNumber& address_ipv4,
const net::IPAddressNumber& address_ipv6) {
VLOG(1) << "OnLocalDomainResolved, id=" << id
<< ", IPv4=" << (address_ipv4.empty() ? "" :
net::IPAddressToString(address_ipv4))
<< ", IPv6=" << (address_ipv6.empty() ? "" :
net::IPAddressToString(address_ipv6));
DCHECK(service_discovery_client_);
Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
id, success, address_ipv4, address_ipv6));
}
void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) {
utility_task_runner_->PostTask(FROM_HERE,
base::Bind(&SendHostMessageOnUtilityThread,
msg));
}
} // namespace local_discovery