blob: 207c454acfbc3a8b09eb97eb4d8a3cfba89e5862 [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <utility>
#include <vector>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/compiler_specific.h"
#include "base/format_macros.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/message_loop/message_loop.h"
#include "base/message_loop/message_loop_proxy.h"
#include "base/run_loop.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/time/time.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_log.h"
#include "net/base/test_completion_callback.h"
#include "net/server/http_server.h"
#include "net/server/http_server_request_info.h"
#include "net/socket/tcp_client_socket.h"
#include "net/socket/tcp_listen_socket.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_getter.h"
#include "net/url_request/url_request_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
namespace {
// Size, in bytes, of the buffer TestHttpClient allocates for each socket read.
const int kMaxExpectedResponseLength = 2 * 1024;
// Delayed-task body posted by RunLoopWithTimeout(). When the flag behind
// |timed_out| is still alive, records that the wait timed out and quits the
// run loop via |quit_loop_func|. Once the flag's WeakPtrFactory has been
// destroyed, the pointer is invalid and this does nothing.
void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
                            const base::Closure& quit_loop_func) {
  if (!timed_out.get())
    return;
  *timed_out = true;
  quit_loop_func.Run();
}
// Runs |run_loop| until it is quit, or until one second has elapsed.
// Returns true if something else quit the loop first, and false if the
// timeout task fired.
bool RunLoopWithTimeout(base::RunLoop* run_loop) {
  bool did_time_out = false;
  // Weak pointers handed out by this factory are invalidated when it goes out
  // of scope, so a timeout task that runs after we return is a no-op.
  base::WeakPtrFactory<bool> weak_factory(&did_time_out);
  const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(1);
  base::Closure on_timeout = base::Bind(&SetTimedOutAndQuitLoop,
                                        weak_factory.GetWeakPtr(),
                                        run_loop->QuitClosure());
  base::MessageLoop::current()->PostDelayedTask(FROM_HERE,
                                                on_timeout,
                                                kTimeout);
  run_loop->Run();
  return did_time_out == false;
}
// Minimal HTTP client that talks to the server under test over a real
// loopback TCP socket. Connecting and reading spin the current message loop
// until the operation completes. Writes are started by Send() and continue
// as the message loop runs.
class TestHttpClient {
 public:
  TestHttpClient() : connect_result_(OK) {}

  // Opens a TCP connection to |address| and waits for it to complete.
  // Returns OK on success, ERR_TIMED_OUT if an asynchronous connect did not
  // finish before RunLoopWithTimeout()'s deadline, or otherwise the net error
  // reported by the socket.
  int ConnectAndWait(const IPEndPoint& address) {
    AddressList addresses(address);
    NetLog::Source source;
    socket_.reset(new TCPClientSocket(addresses, NULL, source));

    base::RunLoop run_loop;
    connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
                                                  base::Unretained(this),
                                                  run_loop.QuitClosure()));
    // Any result other than ERR_IO_PENDING is final, and OnConnect() will not
    // be called. Write() and ReadInternal() rely on the same convention.
    // Spinning the run loop here would only burn the full timeout and then
    // misreport a synchronous OK as ERR_TIMED_OUT.
    if (connect_result_ != ERR_IO_PENDING)
      return connect_result_;

    if (!RunLoopWithTimeout(&run_loop))
      return ERR_TIMED_OUT;
    return connect_result_;
  }

  // Starts sending |data|. Partial writes are continued from OnWrite(), so
  // the remaining bytes are sent as the message loop runs.
  void Send(const std::string& data) {
    write_buffer_ =
        new DrainableIOBuffer(new StringIOBuffer(data), data.length());
    Write();
  }

  // Replaces |message| with the data from one successful socket read (at
  // least one byte). Returns false if the read returns zero or a net error.
  bool Read(std::string* message) {
    return Read(message, 1);
  }

  // Replaces |message| with data read from the socket, reading until at
  // least |expected_bytes| have arrived. More may be returned. Returns false
  // if any read returns zero or a net error before that point.
  bool Read(std::string* message, int expected_bytes) {
    int total_bytes_received = 0;
    message->clear();
    while (total_bytes_received < expected_bytes) {
      net::TestCompletionCallback callback;
      ReadInternal(callback.callback());
      int bytes_received = callback.WaitForResult();
      if (bytes_received <= 0)
        return false;
      total_bytes_received += bytes_received;
      message->append(read_buffer_->data(), bytes_received);
    }
    return true;
  }

 private:
  // Completion callback for an asynchronous Connect(). Records the result
  // and stops the run loop that ConnectAndWait() is spinning.
  void OnConnect(const base::Closure& quit_loop, int result) {
    connect_result_ = result;
    quit_loop.Run();
  }

  // Writes as much of |write_buffer_| as the socket will currently accept.
  void Write() {
    int result = socket_->Write(
        write_buffer_.get(),
        write_buffer_->BytesRemaining(),
        base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
    if (result != ERR_IO_PENDING)
      OnWrite(result);
  }

  // Handles the result of one write. Consumes the written bytes and keeps
  // writing until the buffer is drained. Fails the test on an error or a
  // zero-byte write.
  void OnWrite(int result) {
    ASSERT_GT(result, 0);
    write_buffer_->DidConsume(result);
    if (write_buffer_->BytesRemaining())
      Write();
  }

  // Issues one read of up to kMaxExpectedResponseLength bytes into a fresh
  // |read_buffer_|. |callback| receives the result whether the read
  // completes synchronously or asynchronously.
  void ReadInternal(const net::CompletionCallback& callback) {
    read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
    int result = socket_->Read(read_buffer_.get(),
                               kMaxExpectedResponseLength,
                               callback);
    if (result != ERR_IO_PENDING)
      callback.Run(result);
  }

  scoped_refptr<IOBufferWithSize> read_buffer_;
  scoped_refptr<DrainableIOBuffer> write_buffer_;
  scoped_ptr<TCPClientSocket> socket_;
  int connect_result_;
};
} // namespace
// Fixture that runs an HttpServer on a loopback port chosen by the OS. It
// records every HTTP request the server delivers, together with the
// connection id it arrived on.
class HttpServerTest : public testing::Test,
                       public HttpServer::Delegate {
 public:
  HttpServerTest() : quit_after_request_count_(0) {}

  virtual void SetUp() OVERRIDE {
    // Port 0 lets the OS pick a free port. The resulting address is read
    // back into |server_address_| so clients know where to connect.
    TCPListenSocketFactory socket_factory("127.0.0.1", 0);
    server_ = new HttpServer(socket_factory, this);
    ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
  }

  // Records the request. If a RunUntilRequestsReceived() call is waiting for
  // exactly this many requests, stops its run loop.
  virtual void OnHttpRequest(int connection_id,
                             const HttpServerRequestInfo& info) OVERRIDE {
    requests_.push_back(std::make_pair(info, connection_id));
    // |run_loop_quit_func_| is null whenever no RunUntilRequestsReceived()
    // loop is active, e.g. after one has timed out. Running a null closure is
    // invalid and would crash instead of letting the test report the failure.
    if (requests_.size() == quit_after_request_count_ &&
        !run_loop_quit_func_.is_null()) {
      run_loop_quit_func_.Run();
    }
  }

  virtual void OnWebSocketRequest(int connection_id,
                                  const HttpServerRequestInfo& info) OVERRIDE {
    NOTREACHED();
  }

  virtual void OnWebSocketMessage(int connection_id,
                                  const std::string& data) OVERRIDE {
    NOTREACHED();
  }

  virtual void OnClose(int connection_id) OVERRIDE {}

  // Spins the message loop until exactly |count| requests have been
  // recorded. Returns true immediately if that is already the case. Returns
  // false if the count was not reached before RunLoopWithTimeout()'s
  // deadline.
  bool RunUntilRequestsReceived(size_t count) {
    quit_after_request_count_ = count;
    if (requests_.size() == count)
      return true;

    base::RunLoop run_loop;
    run_loop_quit_func_ = run_loop.QuitClosure();
    bool success = RunLoopWithTimeout(&run_loop);
    run_loop_quit_func_.Reset();
    return success;
  }

  // Returns a copy of the |request_index|-th recorded request.
  HttpServerRequestInfo GetRequest(size_t request_index) {
    return requests_[request_index].first;
  }

  // Returns the connection id that the |request_index|-th request arrived on.
  int GetConnectionId(size_t request_index) {
    return requests_[request_index].second;
  }

 protected:
  scoped_refptr<HttpServer> server_;
  IPEndPoint server_address_;
  base::Closure run_loop_quit_func_;
  std::vector<std::pair<HttpServerRequestInfo, int> > requests_;

 private:
  size_t quit_after_request_count_;
};
// Variant of HttpServerTest for WebSocket upgrades. Upgrade requests are
// recorded exactly like plain HTTP requests, so RunUntilRequestsReceived()
// and GetRequest() work unchanged. Plain HTTP requests are not expected, and
// incoming WebSocket messages are ignored.
class WebSocketTest : public HttpServerTest {
 private:
  virtual void OnHttpRequest(int connection_id,
                             const HttpServerRequestInfo& info) OVERRIDE {
    NOTREACHED();
  }

  virtual void OnWebSocketRequest(int connection_id,
                                  const HttpServerRequestInfo& info) OVERRIDE {
    // Route into the base fixture's bookkeeping, deliberately bypassing this
    // class's own OnHttpRequest() override.
    HttpServerTest::OnHttpRequest(connection_id, info);
  }

  virtual void OnWebSocketMessage(int connection_id,
                                  const std::string& data) OVERRIDE {
  }
};
// A bare GET with no headers and no body is parsed into its method and path.
// The peer address is reported as the loopback address the client used.
TEST_F(HttpServerTest, Request) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("GET", request.method);
  ASSERT_EQ("/test", request.path);
  ASSERT_EQ("", request.data);
  ASSERT_EQ(0u, request.headers.size());
  const std::string peer = request.peer.ToString();
  ASSERT_TRUE(StartsWithASCII(peer, "127.0.0.1", true));
}
// Each row is {field name, separator, value}. The separator column varies
// the whitespace around ':'. The test expects each header to be stored once,
// keyed by the lowercased field name, with the value column as its value.
TEST_F(HttpServerTest, RequestWithHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[][3] = {
      {"Header", ": ", "1"},
      {"HeaderWithNoWhitespace", ":", "1"},
      {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "},
      {"HeaderWithColon", ": ", "1:1"},
      {"EmptyHeader", ":", ""},
      {"EmptyHeaderWithWhitespace", ": \t ", ""},
      {"HeaderWithNonASCII", ": ", "\xf7"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text += kHeaders[row][0];
    request_text += kHeaders[row][1];
    request_text += kHeaders[row][2];
    request_text += "\r\n";
  }
  request_text += "\r\n";
  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);
  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string field =
        StringToLowerASCII(std::string(kHeaders[row][0]));
    const std::string expected_value(kHeaders[row][2]);
    ASSERT_EQ(1u, request.headers.count(field)) << field;
    ASSERT_EQ(expected_value, request.headers[field]) << kHeaders[row][0];
  }
}
// A field that appears twice is expected to be stored once, with its values
// joined by ',' in the order received. The surrounding headers must be
// unaffected.
TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[][3] = {
      {"FirstHeader", ": ", "1"},
      {"DuplicateHeader", ": ", "2"},
      {"MiddleHeader", ": ", "3"},
      {"DuplicateHeader", ": ", "4"},
      {"LastHeader", ": ", "5"},
  };
  const size_t kNumHeaders = arraysize(kHeaders);

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t row = 0; row < kNumHeaders; ++row) {
    request_text += kHeaders[row][0];
    request_text += kHeaders[row][1];
    request_text += kHeaders[row][2];
    request_text += "\r\n";
  }
  request_text += "\r\n";
  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);
  for (size_t row = 0; row < kNumHeaders; ++row) {
    const std::string field =
        StringToLowerASCII(std::string(kHeaders[row][0]));
    std::string expected_value;
    if (field == "duplicateheader")
      expected_value = "2,4";
    else
      expected_value = kHeaders[row][2];
    ASSERT_EQ(1u, request.headers.count(field)) << field;
    ASSERT_EQ(expected_value, request.headers[field]) << kHeaders[row][0];
  }
}
// Exercises HttpServerRequestInfo::HasHeaderValue(). The expectations below
// show that it matches whole comma-separated tokens (whitespace-trimmed),
// case-insensitively for ASCII, and across duplicated fields. Empty headers
// match nothing.
TEST_F(HttpServerTest, HasHeaderValueTest) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char* kHeaders[] = {
      "Header: Abcd",
      "HeaderWithNoWhitespace:E",
      "HeaderWithWhitespace : \t f \t ",
      "DuplicateHeader: g",
      "HeaderWithComma: h, i ,j",
      "DuplicateHeader: k",
      "EmptyHeader:",
      "EmptyHeaderWithWhitespace: \t ",
      "HeaderWithNonASCII: \xf7",
  };

  std::string request_text = "GET /test HTTP/1.1\r\n";
  for (size_t line = 0; line < arraysize(kHeaders); ++line) {
    request_text += kHeaders[line];
    request_text += "\r\n";
  }
  request_text += "\r\n";
  client.Send(request_text);
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  HttpServerRequestInfo request = GetRequest(0);
  ASSERT_EQ("", request.data);
  ASSERT_TRUE(request.HasHeaderValue("header", "abcd"));
  ASSERT_FALSE(request.HasHeaderValue("header", "bc"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnowhitespace", "e"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithwhitespace", "f"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "g"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "h"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "i"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithcomma", "j"));
  ASSERT_TRUE(request.HasHeaderValue("duplicateheader", "k"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheader", "x"));
  ASSERT_FALSE(request.HasHeaderValue("emptyheaderwithwhitespace", "x"));
  ASSERT_TRUE(request.HasHeaderValue("headerwithnonascii", "\xf7"));
}
// A request with a Content-Length body is delivered with the complete body
// and both headers.
TEST_F(HttpServerTest, RequestWithBody) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  // Distinct first and last characters make truncation at either end of the
  // body easy to spot.
  std::string body = "a" + std::string(1 << 10, 'b') + "c";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  ASSERT_EQ(2u, GetRequest(0).headers.size());

  // Check the body the server actually received. The previous assertions
  // inspected the local |body|, so they could never fail.
  const std::string received = GetRequest(0).data;
  ASSERT_EQ(body.length(), received.length());
  ASSERT_EQ('a', received[0]);
  ASSERT_EQ('c', *received.rbegin());
  ASSERT_EQ(body, received);
}
// An Upgrade: WebSocket request must reach OnWebSocketRequest(); the
// WebSocketTest fixture treats a plain HTTP dispatch as NOTREACHED. "Upgrade"
// is the second token of the Connection header here, so it need not be the
// only token.
TEST_F(WebSocketTest, RequestWebSocket) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  const char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\n";
  client.Send(kUpgradeRequest);
  ASSERT_TRUE(RunUntilRequestsReceived(1));
}
// A request announcing a 1 GiB body is expected to be rejected with a 500
// and never delivered to the delegate.
TEST_F(HttpServerTest, RequestWithTooLargeBody) {
  // Checks the response code when the fetch finishes, then quits the run loop.
  class TestURLFetcherDelegate : public URLFetcherDelegate {
   public:
    // Single-argument constructor: explicit to avoid implicit conversion
    // from base::Closure.
    explicit TestURLFetcherDelegate(const base::Closure& quit_loop_func)
        : quit_loop_func_(quit_loop_func) {}
    virtual ~TestURLFetcherDelegate() {}

    virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE {
      EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
      quit_loop_func_.Run();
    }

   private:
    base::Closure quit_loop_func_;
  };

  base::RunLoop run_loop;
  TestURLFetcherDelegate delegate(run_loop.QuitClosure());

  scoped_refptr<URLRequestContextGetter> request_context_getter(
      new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
  scoped_ptr<URLFetcher> fetcher(
      URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
                                                 server_address_.port())),
                         URLFetcher::GET,
                         &delegate));
  fetcher->SetRequestContext(request_context_getter.get());
  fetcher->AddExtraRequestHeader(
      base::StringPrintf("content-length:%d", 1 << 30));
  fetcher->Start();

  ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
  ASSERT_EQ(0u, requests_.size());
}
// Send200() produces a response that starts with a 200 status line and ends
// with the supplied body.
TEST_F(HttpServerTest, Send200) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Response!", "text/plain");

  std::string response;
  ASSERT_TRUE(client.Read(&response));
  ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(response, "Response!", true));
}
// Consecutive SendRaw() calls on one connection arrive at the client
// back to back, with nothing inserted between them.
TEST_F(HttpServerTest, SendRaw) {
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(1));

  const int connection_id = GetConnectionId(0);
  const char* const kPieces[] = {
      "Raw Data ",
      "More Data",
      "Third Piece of Data",
  };
  for (size_t piece = 0; piece < arraysize(kPieces); ++piece)
    server_->SendRaw(connection_id, kPieces[piece]);

  const std::string expected_response("Raw Data More DataThird Piece of Data");
  std::string response;
  ASSERT_TRUE(client.Read(&response, expected_response.length()));
  ASSERT_EQ(expected_response, response);
}
namespace {

// StreamListenSocket stand-in built on kInvalidSocket. A test can hand it to
// HttpServer::DidAccept() and then push bytes through HttpServer::DidRead()
// directly, without any real network I/O.
class MockStreamListenSocket : public StreamListenSocket {
 public:
  // Single-argument constructor: explicit to avoid implicit conversion from
  // a Delegate pointer.
  explicit MockStreamListenSocket(StreamListenSocket::Delegate* delegate)
      : StreamListenSocket(kInvalidSocket, delegate) {}

  // This socket never accepts connections.
  virtual void Accept() OVERRIDE { NOTREACHED(); }

 private:
  virtual ~MockStreamListenSocket() {}
};

}  // namespace
// Feeds a request to the server in two chunks through a fake socket. The
// request must not be dispatched until the final body bytes arrive.
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
  StreamListenSocket* socket = new MockStreamListenSocket(server_.get());
  server_->DidAccept(NULL, make_scoped_ptr(socket));

  std::string body("body");
  const std::string request_text = base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str());

  // Hold back the last two bytes of the body for a second DidRead().
  const int kTailLength = 2;
  const size_t head_length = request_text.length() - kTailLength;
  const char* const text = request_text.c_str();

  server_->DidRead(socket, text, head_length);
  ASSERT_EQ(0u, requests_.size());

  server_->DidRead(socket, text + head_length, kTailLength);
  ASSERT_EQ(1u, requests_.size());
  ASSERT_EQ(body, GetRequest(0).data);
}
TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
  // Requests with or without bodies, answered with different responses, must
  // not break parsing of the next request on the same connection.
  TestHttpClient client;
  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));

  // Round 1: a request with a body, answered with a 200.
  const std::string body = "body";
  client.Send(base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(),
      body.c_str()));
  ASSERT_TRUE(RunUntilRequestsReceived(1));
  ASSERT_EQ(body, GetRequest(0).data);
  const int connection_id = GetConnectionId(0);
  server_->Send200(connection_id, "Content for /test", "text/plain");
  std::string first_response;
  ASSERT_TRUE(client.Read(&first_response));
  ASSERT_TRUE(StartsWithASCII(first_response, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(first_response, "Content for /test", true));

  // Round 2: a body-less request on the same connection, answered with a 404.
  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(2));
  ASSERT_EQ("/test2", GetRequest(1).path);
  ASSERT_EQ(connection_id, GetConnectionId(1));
  server_->Send404(connection_id);
  std::string second_response;
  ASSERT_TRUE(client.Read(&second_response));
  ASSERT_TRUE(
      StartsWithASCII(second_response, "HTTP/1.1 404 Not Found", true));

  // Round 3: another body-less request, answered with a 200 again.
  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
  ASSERT_TRUE(RunUntilRequestsReceived(3));
  ASSERT_EQ("/test3", GetRequest(2).path);
  ASSERT_EQ(connection_id, GetConnectionId(2));
  server_->Send200(connection_id, "Content for /test3", "text/plain");
  std::string third_response;
  ASSERT_TRUE(client.Read(&third_response));
  ASSERT_TRUE(StartsWithASCII(third_response, "HTTP/1.1 200 OK", true));
  ASSERT_TRUE(EndsWith(third_response, "Content for /test3", true));
}
} // namespace net