blob: c52bf6dd5ce64ca386efad81e0396d4e05ac32cc [file] [log] [blame]
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_EXTENSIONS_API_SOCKET_SOCKET_API_H_
#define CHROME_BROWSER_EXTENSIONS_API_SOCKET_SOCKET_API_H_
#include "base/gtest_prod_util.h"
#include "base/memory/ref_counted.h"
#include "chrome/browser/extensions/api/api_function.h"
#include "chrome/browser/extensions/api/api_resource_manager.h"
#include "chrome/common/extensions/api/socket.h"
#include "extensions/browser/extension_function.h"
#include "net/base/address_list.h"
#include "net/dns/host_resolver.h"
#include "net/socket/tcp_client_socket.h"
#include <string>
class IOThread;
namespace net {
class IOBuffer;
}
namespace extensions {
class Socket;
// A simple interface to ApiResourceManager<Socket>, or to ApiResourceManager<T>
// for a class T derived from Socket. The goal of this interface is to allow
// Socket API functions to use distinct instances of ApiResourceManager<>
// depending on the type of socket (the old version in the "socket" namespace
// vs. the new versions in the "socket.xxx" namespaces).
class SocketResourceManagerInterface {
public:
virtual ~SocketResourceManagerInterface() {}
// Binds this object to the resource manager associated with |profile|.
// Returns whether a manager was found (SocketResourceManager<T>::SetProfile
// returns false when ApiResourceManager<T>::Get() yields NULL).
virtual bool SetProfile(Profile* profile) = 0;
// Registers |socket| with the underlying manager and returns the int that
// manager reports, presumably the new api_resource_id.
// NOTE(review): whether ownership of |socket| transfers is decided by
// ApiResourceManager<T>::Add -- confirm there.
virtual int Add(Socket *socket) = 0;
// Returns the socket registered under |api_resource_id| for |extension_id|.
// Presumably NULL when there is no such socket -- confirm in
// ApiResourceManager<T>::Get.
virtual Socket* Get(const std::string& extension_id,
int api_resource_id) = 0;
// Unregisters the socket with |api_resource_id| belonging to |extension_id|.
virtual void Remove(const std::string& extension_id,
int api_resource_id) = 0;
// Returns the set of resource ids known for |extension_id|.
// NOTE(review): lifetime and ownership of the returned set (and whether it
// can be NULL) are defined by ApiResourceManager<T> -- confirm there.
virtual base::hash_set<int>* GetResourceIds(
const std::string& extension_id) = 0;
};
// Implementation of SocketResourceManagerInterface using an
// ApiResourceManager<T> instance (where T derives from Socket).
template<typename T>
class SocketResourceManager : public SocketResourceManagerInterface {
 public:
  // Starts unbound. SetProfile() must succeed before any other method is
  // called, because every other method dereferences |manager_| unchecked.
  SocketResourceManager() : manager_(NULL) {}

  // Caches the ApiResourceManager<T> registered for |profile| and reports
  // whether one was found. A missing manager also trips the DCHECK below.
  virtual bool SetProfile(Profile* profile) OVERRIDE {
    manager_ = ApiResourceManager<T>::Get(profile);
    DCHECK(manager_) << "There is no socket manager. "
        "If this assertion is failing during a test, then it is likely that "
        "TestExtensionSystem is failing to provide an instance of "
        "ApiResourceManager<Socket>.";
    const bool bound = (manager_ != NULL);
    return bound;
  }

  // Hands |socket| to the typed manager. The static downcast is needed
  // because the manager stores T objects and T may be a subclass of Socket;
  // callers must pass an object that really is a T.
  virtual int Add(Socket* socket) OVERRIDE {
    T* typed_socket = static_cast<T*>(socket);
    return manager_->Add(typed_socket);
  }

  // Forwards the lookup to the typed manager and returns the result as a
  // Socket*.
  virtual Socket* Get(const std::string& extension_id,
                      int api_resource_id) OVERRIDE {
    Socket* found = manager_->Get(extension_id, api_resource_id);
    return found;
  }

  // Forwards the removal to the typed manager.
  virtual void Remove(const std::string& extension_id,
                      int api_resource_id) OVERRIDE {
    manager_->Remove(extension_id, api_resource_id);
  }

  // Forwards to the typed manager and returns its id set unchanged.
  virtual base::hash_set<int>* GetResourceIds(
      const std::string& extension_id) OVERRIDE {
    base::hash_set<int>* ids = manager_->GetResourceIds(extension_id);
    return ids;
  }

 private:
  // Obtained from ApiResourceManager<T>::Get(); this class never deletes it.
  ApiResourceManager<T>* manager_;
};
// Common base class for the socket.* extension functions. It owns a
// SocketResourceManagerInterface and exposes helpers to add, look up and
// remove sockets.
class SocketAsyncApiFunction : public AsyncApiFunction {
public:
SocketAsyncApiFunction();
protected:
virtual ~SocketAsyncApiFunction();
// AsyncApiFunction:
virtual bool PrePrepare() OVERRIDE;
virtual bool Respond() OVERRIDE;
// Factory for |manager_|. It is virtual so that subclasses can supply a
// manager backed by a different ApiResourceManager<T> (presumably how the
// "socket.xxx" APIs reuse this class -- confirm in the overriding classes).
virtual scoped_ptr<SocketResourceManagerInterface>
CreateSocketResourceManager();
// Socket bookkeeping helpers. They are presumably thin wrappers over
// |manager_| scoped to the calling extension; see the .cc for exact behavior.
int AddSocket(Socket* socket);
Socket* GetSocket(int api_resource_id);
void RemoveSocket(int api_resource_id);
base::hash_set<int>* GetSocketIds();
private:
// Resource manager presumably used by the helpers above; owned via scoped_ptr.
scoped_ptr<SocketResourceManagerInterface> manager_;
};
// Base class for socket functions that must resolve a hostname before they
// can do their work. Subclasses call StartDnsLookup() and continue in
// AfterDnsLookup().
class SocketExtensionWithDnsLookupFunction : public SocketAsyncApiFunction {
protected:
SocketExtensionWithDnsLookupFunction();
virtual ~SocketExtensionWithDnsLookupFunction();
// Begins resolving |hostname|.
void StartDnsLookup(const std::string& hostname);
// Subclass hook that receives the lookup's result code. It is presumably
// invoked from OnDnsLookup() once resolution finishes, and |lookup_result| is
// presumably a net error code -- confirm in the .cc.
virtual void AfterDnsLookup(int lookup_result) = 0;
// Textual form of the resolved address. Presumably filled in on a successful
// lookup before AfterDnsLookup() runs -- confirm in the .cc.
std::string resolved_address_;
private:
// Completion handler for the resolver request (presumably forwards to
// AfterDnsLookup()).
void OnDnsLookup(int resolve_result);
// This instance is widely available through BrowserProcess, but we need to
// acquire it on the UI thread and then use it on the IO thread, so we keep a
// plain pointer to it here as we move from thread to thread.
IOThread* io_thread_;
// Handle for the resolver request, owned via scoped_ptr. It is presumably
// heap-allocated so its address can be passed to net::HostResolver as an
// out-parameter -- confirm.
scoped_ptr<net::HostResolver::RequestHandle> request_handle_;
// Address list owned via scoped_ptr; presumably the resolver's output.
scoped_ptr<net::AddressList> addresses_;
};
// Implements "socket.create": creates a socket of one of the SocketType kinds
// below (TCP or UDP).
class SocketCreateFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.create", SOCKET_CREATE)
SocketCreateFunction();
protected:
virtual ~SocketCreateFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
FRIEND_TEST_ALL_PREFIXES(SocketUnitTest, Create);
// Kind of socket requested. kSocketTypeInvalid (-1) is presumably used for
// an unrecognized type string -- confirm in Prepare().
enum SocketType {
kSocketTypeInvalid = -1,
kSocketTypeTCP,
kSocketTypeUDP
};
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::Create::Params> params_;
// Socket kind derived from |params_| (presumably in Prepare()).
SocketType socket_type_;
};
// Implements "socket.destroy".
class SocketDestroyFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.destroy", SOCKET_DESTROY)

  // Initializes |socket_id_| so it never holds an indeterminate value, even
  // before Prepare() has filled it in.
  SocketDestroyFunction() : socket_id_(0) {}

 protected:
  virtual ~SocketDestroyFunction() {}

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Id of the socket to destroy; presumably read from the call arguments in
  // Prepare().
  int socket_id_;
};
// Implements "socket.connect". Because it derives from
// SocketExtensionWithDnsLookupFunction, it presumably resolves |hostname_|
// first and then connects from AfterDnsLookup() via StartConnect().
class SocketConnectFunction : public SocketExtensionWithDnsLookupFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.connect", SOCKET_CONNECT)
SocketConnectFunction();
protected:
virtual ~SocketConnectFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
// SocketExtensionWithDnsLookupFunction:
virtual void AfterDnsLookup(int lookup_result) OVERRIDE;
private:
// Starts the connect attempt on |socket_|.
void StartConnect();
// Completion handler for the connect attempt.
void OnConnect(int result);
int socket_id_;
std::string hostname_;
int port_;
// Socket being connected. It is presumably obtained from GetSocket() and not
// owned by this function -- confirm in the .cc.
Socket* socket_;
};
// Implements "socket.disconnect".
class SocketDisconnectFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.disconnect", SOCKET_DISCONNECT)

  // Initializes |socket_id_| so it never holds an indeterminate value, even
  // before Prepare() has filled it in.
  SocketDisconnectFunction() : socket_id_(0) {}

 protected:
  virtual ~SocketDisconnectFunction() {}

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Id of the socket to disconnect; presumably read from the call arguments
  // in Prepare().
  int socket_id_;
};
// Implements "socket.bind".
class SocketBindFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.bind", SOCKET_BIND)

  // Gives the scalar members deterministic values instead of leaving them
  // indeterminate until Prepare() runs. The initializer order follows the
  // declaration order below (address_ is default-constructed).
  SocketBindFunction() : socket_id_(0), port_(0) {}

 protected:
  virtual ~SocketBindFunction() {}

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Call arguments; presumably populated in Prepare().
  int socket_id_;
  std::string address_;
  int port_;
};
// Implements "socket.listen".
class SocketListenFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.listen", SOCKET_LISTEN)
SocketListenFunction();
protected:
virtual ~SocketListenFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::Listen::Params> params_;
};
// Implements "socket.accept". AsyncWorkStart() presumably begins an
// asynchronous accept whose completion is delivered to OnAccept().
class SocketAcceptFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.accept", SOCKET_ACCEPT)
SocketAcceptFunction();
protected:
virtual ~SocketAcceptFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
private:
// Receives the accept result code and the accepted client socket.
// NOTE(review): whether |socket| may be NULL on failure, and who takes
// ownership of it, is decided in the .cc -- confirm there.
void OnAccept(int result_code, net::TCPClientSocket *socket);
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::Accept::Params> params_;
};
// Implements "socket.read".
class SocketReadFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.read", SOCKET_READ)
SocketReadFunction();
protected:
virtual ~SocketReadFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
// Completion handler for the read. It receives the result code and the
// buffer holding the data; |result| is presumably a byte count or a negative
// net error -- confirm in the .cc.
void OnCompleted(int result, scoped_refptr<net::IOBuffer> io_buffer);
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::Read::Params> params_;
};
// Implements "socket.write".
class SocketWriteFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.write", SOCKET_WRITE)
SocketWriteFunction();
protected:
virtual ~SocketWriteFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
// Completion handler for the write; receives the write's result code.
void OnCompleted(int result);
private:
int socket_id_;
// Data to write and its size; presumably filled in by Prepare().
scoped_refptr<net::IOBuffer> io_buffer_;
size_t io_buffer_size_;
};
// Implements "socket.recvFrom".
class SocketRecvFromFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.recvFrom", SOCKET_RECVFROM)
SocketRecvFromFunction();
protected:
virtual ~SocketRecvFromFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
// Completion handler for the receive. It gets the result code, the data
// buffer, and an address/port pair (presumably the sender's).
void OnCompleted(int result,
scoped_refptr<net::IOBuffer> io_buffer,
const std::string& address,
int port);
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::RecvFrom::Params> params_;
};
// Implements "socket.sendTo". Because it derives from
// SocketExtensionWithDnsLookupFunction, it presumably resolves |hostname_|
// first and then sends from AfterDnsLookup() via StartSendTo().
class SocketSendToFunction : public SocketExtensionWithDnsLookupFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.sendTo", SOCKET_SENDTO)
SocketSendToFunction();
protected:
virtual ~SocketSendToFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void AsyncWorkStart() OVERRIDE;
// Completion handler for the send; receives its result code.
void OnCompleted(int result);
// SocketExtensionWithDnsLookupFunction:
virtual void AfterDnsLookup(int lookup_result) OVERRIDE;
private:
// Starts the send on |socket_|.
void StartSendTo();
int socket_id_;
// Data to send and its size; presumably filled in by Prepare().
scoped_refptr<net::IOBuffer> io_buffer_;
size_t io_buffer_size_;
std::string hostname_;
int port_;
// Socket used for the send. It is presumably obtained from GetSocket() and
// not owned by this function -- confirm in the .cc.
Socket* socket_;
};
// Implements "socket.setKeepAlive".
class SocketSetKeepAliveFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.setKeepAlive", SOCKET_SETKEEPALIVE)
SocketSetKeepAliveFunction();
protected:
virtual ~SocketSetKeepAliveFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::SetKeepAlive::Params> params_;
};
// Implements "socket.setNoDelay".
class SocketSetNoDelayFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.setNoDelay", SOCKET_SETNODELAY)
SocketSetNoDelayFunction();
protected:
virtual ~SocketSetNoDelayFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::SetNoDelay::Params> params_;
};
// Implements "socket.getInfo".
class SocketGetInfoFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.getInfo", SOCKET_GETINFO)
SocketGetInfoFunction();
protected:
virtual ~SocketGetInfoFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::GetInfo::Params> params_;
};
// Implements "socket.getNetworkList". Unlike the other functions in this
// file, it derives directly from AsyncExtensionFunction rather than from
// SocketAsyncApiFunction, so it does not use a socket resource manager.
class SocketGetNetworkListFunction : public AsyncExtensionFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.getNetworkList", SOCKET_GETNETWORKLIST)
protected:
virtual ~SocketGetNetworkListFunction() {}
virtual bool RunImpl() OVERRIDE;
private:
// Helpers whose names indicate the thread each runs on (FILE vs. UI thread).
// NOTE(review): the thread hops are implemented in the .cc -- confirm there.
void GetNetworkListOnFileThread();
void HandleGetNetworkListError();
void SendResponseOnUIThread(const net::NetworkInterfaceList& interface_list);
};
// Implements "socket.joinGroup" (multicast).
class SocketJoinGroupFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.joinGroup", SOCKET_MULTICAST_JOIN_GROUP)
SocketJoinGroupFunction();
protected:
virtual ~SocketJoinGroupFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::JoinGroup::Params> params_;
};
// Implements "socket.leaveGroup" (multicast).
class SocketLeaveGroupFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.leaveGroup", SOCKET_MULTICAST_LEAVE_GROUP)
SocketLeaveGroupFunction();
protected:
virtual ~SocketLeaveGroupFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::LeaveGroup::Params> params_;
};
// Implements "socket.setMulticastTimeToLive".
class SocketSetMulticastTimeToLiveFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.setMulticastTimeToLive",
SOCKET_MULTICAST_SET_TIME_TO_LIVE)
SocketSetMulticastTimeToLiveFunction();
protected:
virtual ~SocketSetMulticastTimeToLiveFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::SetMulticastTimeToLive::Params> params_;
};
// Implements "socket.setMulticastLoopbackMode".
class SocketSetMulticastLoopbackModeFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.setMulticastLoopbackMode",
SOCKET_MULTICAST_SET_LOOPBACK_MODE)
SocketSetMulticastLoopbackModeFunction();
protected:
virtual ~SocketSetMulticastLoopbackModeFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::SetMulticastLoopbackMode::Params> params_;
};
// Implements "socket.getJoinedGroups" (multicast).
class SocketGetJoinedGroupsFunction : public SocketAsyncApiFunction {
public:
DECLARE_EXTENSION_FUNCTION("socket.getJoinedGroups",
SOCKET_MULTICAST_GET_JOINED_GROUPS)
SocketGetJoinedGroupsFunction();
protected:
virtual ~SocketGetJoinedGroupsFunction();
// AsyncApiFunction:
virtual bool Prepare() OVERRIDE;
virtual void Work() OVERRIDE;
private:
// Parsed call arguments; presumably populated in Prepare().
scoped_ptr<api::socket::GetJoinedGroups::Params> params_;
};
} // namespace extensions
#endif // CHROME_BROWSER_EXTENSIONS_API_SOCKET_SOCKET_API_H_