blob: 895eb1f6585f42ea01e9dc0651e166a2bb722039 [file] [log] [blame]
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.okhttp.ws;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.internal.SslContextBuilder;
import com.squareup.okhttp.mockwebserver.MockResponse;
import com.squareup.okhttp.mockwebserver.MockWebServer;
import com.squareup.okhttp.testing.RecordingHostnameVerifier;
import java.io.IOException;
import java.net.ProtocolException;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT;
/**
 * Tests for {@link WebSocketCall} run against a {@link MockWebServer}.
 *
 * <p>Each test enqueues a canned response, opens a web socket through {@link #awaitWebSocket()},
 * and then asserts on the events captured either by the client-side {@link #listener} or by a
 * server-side recorder. {@link #tearDown()} verifies after every test that no client-side event
 * was left unasserted.
 */
public final class WebSocketCallTest {
  @Rule public final MockWebServer server = new MockWebServer();

  private final SSLContext sslContext = SslContextBuilder.localhost();

  /** Receives the client-side message, pong, close and failure callbacks of the call under test. */
  private final WebSocketRecorder listener = new WebSocketRecorder();

  private final OkHttpClient client = new OkHttpClient();

  /**
   * Fixed-seed source of randomness handed to every {@link WebSocketCall}.
   *
   * <p>NOTE(review): the hard-coded {@code Sec-WebSocket-Accept} values in the handshake tests
   * presumably depend on this seed; confirm before changing it.
   */
  private final Random random = new Random(0);

  /** Fails the test if the client-side recorder still holds events that were never asserted. */
  @After public void tearDown() {
    listener.assertExhausted();
  }

  /** A ping sent by the client results in a pong with the same payload on the client listener. */
  @Test public void clientPingPong() throws IOException {
    WebSocketListener serverListener = new EmptyWebSocketListener();
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    WebSocket webSocket = awaitWebSocket();
    webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!"));
    listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!"));
  }

  /** A text message sent by the client is observed by the server-side recorder. */
  @Test public void clientMessage() throws IOException {
    WebSocketRecorder serverListener = new WebSocketRecorder();
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    WebSocket webSocket = awaitWebSocket();
    webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
    serverListener.assertTextMessage("Hello, WebSockets!");
  }

  /** A text message sent by the server once the socket opens is observed by the client. */
  @Test public void serverMessage() throws IOException {
    WebSocketListener serverListener = new EmptyWebSocketListener() {
      @Override public void onOpen(final WebSocket webSocket, Response response) {
        // The send is performed on a separate thread rather than inside the onOpen callback.
        new Thread() {
          @Override public void run() {
            try {
              webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
            } catch (IOException e) {
              throw new AssertionError(e);
            }
          }
        }.start();
      }
    };
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    awaitWebSocket();
    listener.assertTextMessage("Hello, WebSockets!");
  }

  /** A client message written through a sink in two flushed parts arrives as one text message. */
  @Test public void clientStreamingMessage() throws IOException {
    WebSocketRecorder serverListener = new WebSocketRecorder();
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    WebSocket webSocket = awaitWebSocket();
    BufferedSink sink = webSocket.newMessageSink(TEXT);
    sink.writeUtf8("Hello, ").flush();
    sink.writeUtf8("WebSockets!").flush();
    sink.close();

    serverListener.assertTextMessage("Hello, WebSockets!");
  }

  /** A server message written through a sink in two flushed parts arrives as one text message. */
  @Test public void serverStreamingMessage() throws IOException {
    WebSocketListener serverListener = new EmptyWebSocketListener() {
      @Override public void onOpen(final WebSocket webSocket, Response response) {
        // The write is performed on a separate thread rather than inside the onOpen callback.
        new Thread() {
          @Override public void run() {
            try {
              BufferedSink sink = webSocket.newMessageSink(TEXT);
              sink.writeUtf8("Hello, ").flush();
              sink.writeUtf8("WebSockets!").flush();
              sink.close();
            } catch (IOException e) {
              throw new AssertionError(e);
            }
          }
        }.start();
      }
    };
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    awaitWebSocket();
    listener.assertTextMessage("Hello, WebSockets!");
  }

  /** A default (non-upgrade) mock response is reported as a protocol failure. */
  @Test public void okButNotOk() {
    server.enqueue(new MockResponse());
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '200 OK'");
  }

  /** A 404 response is reported as a protocol failure. */
  @Test public void notFound() {
    server.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found"));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected HTTP 101 response but was '404 Not Found'");
  }

  /** A 101 response without a {@code Connection} header is rejected. */
  @Test public void missingConnectionHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Upgrade", "websocket")
        .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Connection' header value 'Upgrade' but was 'null'");
  }

  /** A 101 response whose {@code Connection} header is not {@code Upgrade} is rejected. */
  @Test public void wrongConnectionHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Upgrade", "websocket")
        .setHeader("Connection", "Downgrade")
        .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");
  }

  /** A 101 response without an {@code Upgrade} header is rejected. */
  @Test public void missingUpgradeHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Connection", "Upgrade")
        .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Upgrade' header value 'websocket' but was 'null'");
  }

  /** A 101 response whose {@code Upgrade} header is not {@code websocket} is rejected. */
  @Test public void wrongUpgradeHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Connection", "Upgrade")
        .setHeader("Upgrade", "Pepsi")
        .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");
  }

  /** A 101 response without a {@code Sec-WebSocket-Accept} header is rejected. */
  @Test public void missingMagicHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Connection", "Upgrade")
        .setHeader("Upgrade", "websocket"));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");
  }

  /** A 101 response with an unexpected {@code Sec-WebSocket-Accept} value is rejected. */
  @Test public void wrongMagicHeader() {
    server.enqueue(new MockResponse()
        .setResponseCode(101)
        .setHeader("Connection", "Upgrade")
        .setHeader("Upgrade", "websocket")
        .setHeader("Sec-WebSocket-Accept", "magic"));
    awaitWebSocket();
    listener.assertFailure(ProtocolException.class,
        "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");
  }

  /** A {@code ws://} URL can be used to open a web socket. */
  @Test public void wsScheme() throws IOException {
    websocketScheme("ws");
  }

  /** The scheme is accepted in upper case as well. */
  @Test public void wsUppercaseScheme() throws IOException {
    websocketScheme("WS");
  }

  /** A {@code wss://} URL can be used against an HTTPS-enabled server. */
  @Test public void wssScheme() throws IOException {
    enableTls();
    websocketScheme("wss");
  }

  /** An {@code https://} URL can be used against an HTTPS-enabled server. */
  @Test public void httpsScheme() throws IOException {
    enableTls();
    websocketScheme("https");
  }

  /** Switches the mock server to HTTPS and configures the client with the same SSL context. */
  private void enableTls() {
    server.useHttps(sslContext.getSocketFactory(), false);
    client.setSslSocketFactory(sslContext.getSocketFactory());
    client.setHostnameVerifier(new RecordingHostnameVerifier());
  }

  /**
   * Opens a web socket to the mock server using a URL built with {@code scheme}, sends the text
   * message {@code "abc"}, and asserts that the server-side recorder received it.
   *
   * @param scheme URL scheme to put in front of the server's host name and port
   */
  private void websocketScheme(String scheme) throws IOException {
    WebSocketRecorder serverListener = new WebSocketRecorder();
    server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));

    Request request = new Request.Builder()
        .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/")
        .build();

    WebSocket webSocket = awaitWebSocket(request);
    webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc"));
    serverListener.assertTextMessage("abc");
  }

  /** Opens a web socket with a GET request for the mock server's root path. */
  private WebSocket awaitWebSocket() {
    return awaitWebSocket(new Request.Builder().get().url(server.url("/")).build());
  }

  /**
   * Enqueues a {@link WebSocketCall} for {@code request} and blocks until the call either opens
   * or fails, for at most ten seconds.
   *
   * <p>Message, pong, close and failure callbacks are forwarded to {@link #listener}; the open
   * callback is not forwarded and only captures the socket.
   *
   * @param request the request used to create the call
   * @return the opened web socket, or {@code null} if the call failed before opening
   * @throws AssertionError if neither callback arrives within ten seconds, or if the waiting
   *     thread is interrupted
   */
  private WebSocket awaitWebSocket(Request request) {
    WebSocketCall call = new WebSocketCall(client, request, random);

    final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>();
    // Released by whichever of onOpen / onFailure fires first.
    final CountDownLatch latch = new CountDownLatch(1);
    call.enqueue(new WebSocketListener() {
      @Override public void onOpen(WebSocket webSocket, Response response) {
        webSocketRef.set(webSocket);
        latch.countDown();
      }

      @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
          throws IOException {
        listener.onMessage(payload, type);
      }

      @Override public void onPong(Buffer payload) {
        listener.onPong(payload);
      }

      @Override public void onClose(int code, String reason) {
        listener.onClose(code, reason);
      }

      @Override public void onFailure(IOException e, Response response) {
        // NOTE(review): the response is deliberately not forwarded (null is passed instead);
        // kept as-is because the recorder's handling of a non-null response is not visible here.
        listener.onFailure(e, null);
        latch.countDown();
      }
    });

    try {
      if (!latch.await(10, TimeUnit.SECONDS)) {
        throw new AssertionError("Timed out.");
      }
    } catch (InterruptedException e) {
      // Preserve the interrupt status for the test runner before failing the test.
      Thread.currentThread().interrupt();
      throw new AssertionError(e);
    }

    return webSocketRef.get();
  }

  /** A {@link WebSocketListener} whose callbacks all do nothing; tests override what they need. */
  private static class EmptyWebSocketListener implements WebSocketListener {
    @Override public void onOpen(WebSocket webSocket, Response response) {
    }

    @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
        throws IOException {
    }

    @Override public void onPong(Buffer payload) {
    }

    @Override public void onClose(int code, String reason) {
    }

    @Override public void onFailure(IOException e, Response response) {
    }
  }
}