Add throttling support to MockWebServer

This adds non-blocking throttling support to MockWebServer.
Most of the changes are patched across from OkHttp's version
(minus SPDY changes).

The motivation is to make an upstream OkHttp change easier
to apply, but having fewer differences with the OkHttp version
should be beneficial.

Bug: 18083851
(cherry picked from commit d8f241c21b3e2e8f94648040b1b62d7b12491d4d)

Change-Id: I7e23675cf0366028392e87c851b30e7d6dddb989
diff --git a/src/main/java/com/google/mockwebserver/Dispatcher.java b/src/main/java/com/google/mockwebserver/Dispatcher.java
index 0456025..48541a4 100644
--- a/src/main/java/com/google/mockwebserver/Dispatcher.java
+++ b/src/main/java/com/google/mockwebserver/Dispatcher.java
@@ -26,11 +26,13 @@
     public abstract MockResponse dispatch(RecordedRequest request) throws InterruptedException;
 
     /**
-     * Returns the socket policy of the next request.  Default implementation
-     * returns {@link SocketPolicy#KEEP_OPEN}. Mischievous implementations can
-     * return other values to test HTTP edge cases.
+     * Returns an early guess of the next response, used for policy on how an
+     * incoming request should be received. The default implementation returns an
+     * empty response. Mischievous implementations can return other values to test
+     * HTTP edge cases, such as unhappy socket policies or throttled request
+     * bodies.
      */
-    public SocketPolicy peekSocketPolicy() {
-        return SocketPolicy.KEEP_OPEN;
+    public MockResponse peek() {
+        return new MockResponse().setSocketPolicy(SocketPolicy.KEEP_OPEN);
     }
 }
diff --git a/src/main/java/com/google/mockwebserver/MockResponse.java b/src/main/java/com/google/mockwebserver/MockResponse.java
index 7bca741..665d85a 100644
--- a/src/main/java/com/google/mockwebserver/MockResponse.java
+++ b/src/main/java/com/google/mockwebserver/MockResponse.java
@@ -16,7 +16,6 @@
 
 package com.google.mockwebserver;
 
-import static com.google.mockwebserver.MockWebServer.ASCII;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
@@ -25,6 +24,9 @@
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
 
 /**
  * A scripted response to be replayed by the mock web server.
@@ -40,9 +42,14 @@
     /** The response body content, or null if {@code body} is set. */
     private InputStream bodyStream;
 
-    private int bytesPerSecond = Integer.MAX_VALUE;
+    private int throttleBytesPerPeriod = Integer.MAX_VALUE;
+    private long throttlePeriod = 1;
+    private TimeUnit throttleUnit = TimeUnit.SECONDS;
+
     private SocketPolicy socketPolicy = SocketPolicy.KEEP_OPEN;
 
+    private int bodyDelayTimeMs = 0;
+
     /**
      * Creates a new mock response with an empty body.
      */
@@ -185,13 +192,13 @@
             int pos = 0;
             while (pos < body.length) {
                 int chunkSize = Math.min(body.length - pos, maxChunkSize);
-                bytesOut.write(Integer.toHexString(chunkSize).getBytes(ASCII));
-                bytesOut.write("\r\n".getBytes(ASCII));
+                bytesOut.write(Integer.toHexString(chunkSize).getBytes(US_ASCII));
+                bytesOut.write("\r\n".getBytes(US_ASCII));
                 bytesOut.write(body, pos, chunkSize);
-                bytesOut.write("\r\n".getBytes(ASCII));
+                bytesOut.write("\r\n".getBytes(US_ASCII));
                 pos += chunkSize;
             }
-            bytesOut.write("0\r\n\r\n".getBytes(ASCII)); // last chunk + empty trailer + crlf
+            bytesOut.write("0\r\n\r\n".getBytes(US_ASCII)); // last chunk + empty trailer + crlf
 
             this.body = bytesOut.toByteArray();
             return this;
@@ -221,19 +228,43 @@
         return this;
     }
 
-    public int getBytesPerSecond() {
-        return bytesPerSecond;
+    /**
+     * Throttles the response body writer to sleep for the given period after each
+     * series of {@code bytesPerPeriod} bytes are written. Use this to simulate
+     * network behavior.
+     */
+    public MockResponse throttleBody(int bytesPerPeriod, long period, TimeUnit unit) {
+        this.throttleBytesPerPeriod = bytesPerPeriod;
+        this.throttlePeriod = period;
+        this.throttleUnit = unit;
+        return this;
+    }
+
+    public int getThrottleBytesPerPeriod() {
+        return throttleBytesPerPeriod;
+    }
+
+    public long getThrottlePeriod() {
+        return throttlePeriod;
+    }
+
+    public TimeUnit getThrottleUnit() {
+        return throttleUnit;
     }
 
     /**
-     * Set simulated network speed, in bytes per second. This applies to the
-     * response body only; response headers are not throttled.
+     * Set the delayed time of the response body to {@code delay}. This applies to the
+     * response body only; response headers are not affected.
      */
-    public MockResponse setBytesPerSecond(int bytesPerSecond) {
-        this.bytesPerSecond = bytesPerSecond;
+    public MockResponse setBodyDelayTimeMs(int delay) {
+        bodyDelayTimeMs = delay;
         return this;
     }
 
+    public int getBodyDelayTimeMs() {
+        return bodyDelayTimeMs;
+    }
+
     @Override public String toString() {
         return "MockResponse{" + status + "}";
     }
diff --git a/src/main/java/com/google/mockwebserver/MockWebServer.java b/src/main/java/com/google/mockwebserver/MockWebServer.java
index afcacc5..13a4597 100644
--- a/src/main/java/com/google/mockwebserver/MockWebServer.java
+++ b/src/main/java/com/google/mockwebserver/MockWebServer.java
@@ -33,11 +33,14 @@
 import java.net.SocketException;
 import java.net.URL;
 import java.net.UnknownHostException;
+import java.nio.charset.StandardCharsets;
+import java.security.SecureRandom;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
@@ -58,12 +61,26 @@
  * replays them upon request in sequence.
  */
 public final class MockWebServer {
+    private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
+        @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
+                throws CertificateException {
+            throw new CertificateException();
+        }
 
-    static final String ASCII = "US-ASCII";
+        @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
+            throw new AssertionError();
+        }
+
+        @Override public X509Certificate[] getAcceptedIssuers() {
+            throw new AssertionError();
+        }
+    };
 
     private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
+
     private final BlockingQueue<RecordedRequest> requestQueue
             = new LinkedBlockingQueue<RecordedRequest>();
+
     /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
     private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
     private final AtomicInteger requestCount = new AtomicInteger();
@@ -78,7 +95,6 @@
     private int port = -1;
     private int workerThreads = Integer.MAX_VALUE;
 
-
     public int getPort() {
         if (port == -1) {
             throw new IllegalStateException("Cannot retrieve port before calling play()");
@@ -90,7 +106,7 @@
         try {
             return InetAddress.getLocalHost().getHostName();
         } catch (UnknownHostException e) {
-            throw new AssertionError();
+            throw new AssertionError(e);
         }
     }
 
@@ -250,7 +266,7 @@
                     } catch (SocketException e) {
                         return;
                     }
-                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
                     if (socketPolicy == DISCONNECT_AT_START) {
                         dispatchBookkeepingRequest(0, socket);
                         socket.close();
@@ -288,16 +304,20 @@
                     if (tunnelProxy) {
                         createTunnel();
                     }
-                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
                     if (socketPolicy == FAIL_HANDSHAKE) {
                         dispatchBookkeepingRequest(sequenceNumber, raw);
-                        processHandshakeFailure(raw, sequenceNumber++);
+                        processHandshakeFailure(raw);
                         return;
                     }
                     socket = sslSocketFactory.createSocket(
                             raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
-                    ((SSLSocket) socket).setUseClientMode(false);
+                    SSLSocket sslSocket = (SSLSocket) socket;
+                    sslSocket.setUseClientMode(false);
                     openClientSockets.put(socket, true);
+
+                    sslSocket.startHandshake();
+
                     openClientSockets.remove(raw);
                 } else {
                     socket = raw;
@@ -325,13 +345,11 @@
              */
             private void createTunnel() throws IOException, InterruptedException {
                 while (true) {
-                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
+                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
                     if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
                         throw new IllegalStateException("Tunnel without any CONNECT!");
                     }
-                    if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
-                        return;
-                    }
+                    if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
                 }
             }
 
@@ -341,7 +359,7 @@
              */
             private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
                     throws IOException, InterruptedException {
-                RecordedRequest request = readRequest(socket, in, sequenceNumber);
+                RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
                 if (request == null) {
                     return false;
                 }
@@ -385,21 +403,9 @@
         }));
     }
 
-    private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
-        X509TrustManager untrusted = new X509TrustManager() {
-            @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
-                    throws CertificateException {
-                throw new CertificateException();
-            }
-            @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
-                throw new AssertionError();
-            }
-            @Override public X509Certificate[] getAcceptedIssuers() {
-                throw new AssertionError();
-            }
-        };
+    private void processHandshakeFailure(Socket raw) throws Exception {
         SSLContext context = SSLContext.getInstance("TLS");
-        context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
+        context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
         SSLSocketFactory sslSocketFactory = context.getSocketFactory();
         SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
                 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
@@ -416,14 +422,11 @@
         RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
                 socket);
         dispatcher.dispatch(request);
-        requestQueue.add(request);
     }
 
-    /**
-     * @param sequenceNumber the index of this request on this connection.
-     */
-    private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
-            throws IOException {
+    /** @param sequenceNumber the index of this request on this connection. */
+    private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
+            int sequenceNumber) throws IOException {
         String request;
         try {
             request = readAsciiUntilCrlf(in);
@@ -435,27 +438,40 @@
         }
 
         List<String> headers = new ArrayList<String>();
-        int contentLength = -1;
+        long contentLength = -1;
         boolean chunked = false;
+        boolean expectContinue = false;
         String header;
         while ((header = readAsciiUntilCrlf(in)).length() != 0) {
             headers.add(header);
-            String lowercaseHeader = header.toLowerCase();
+            String lowercaseHeader = header.toLowerCase(Locale.US);
             if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
-                contentLength = Integer.parseInt(header.substring(15).trim());
+                contentLength = Long.parseLong(header.substring(15).trim());
             }
-            if (lowercaseHeader.startsWith("transfer-encoding:") &&
-                    lowercaseHeader.substring(18).trim().equals("chunked")) {
+            if (lowercaseHeader.startsWith("transfer-encoding:")
+                    && lowercaseHeader.substring(18).trim().equals("chunked")) {
                 chunked = true;
             }
+            if (lowercaseHeader.startsWith("expect:")
+                    && lowercaseHeader.substring(7).trim().equals("100-continue")) {
+                expectContinue = true;
+            }
+        }
+
+        if (expectContinue) {
+            out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII));
+            out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII));
+            out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
+            out.flush();
         }
 
         boolean hasBody = false;
         TruncatingOutputStream requestBody = new TruncatingOutputStream();
         List<Integer> chunkSizes = new ArrayList<Integer>();
+        MockResponse throttlePolicy = dispatcher.peek();
         if (contentLength != -1) {
             hasBody = true;
-            transfer(contentLength, in, requestBody);
+            throttledTransfer(throttlePolicy, in, requestBody, contentLength);
         } else if (chunked) {
             hasBody = true;
             while (true) {
@@ -465,79 +481,75 @@
                     break;
                 }
                 chunkSizes.add(chunkSize);
-                transfer(chunkSize, in, requestBody);
+                throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
                 readEmptyLine(in);
             }
         }
 
-        if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
-                || request.startsWith("HEAD ") || request.startsWith("DELETE ")
-                || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) {
+        if (request.startsWith("OPTIONS ")
+                || request.startsWith("GET ")
+                || request.startsWith("HEAD ")
+                || request.startsWith("TRACE ")
+                || request.startsWith("CONNECT ")) {
             if (hasBody) {
                 throw new IllegalArgumentException("Request must not have a body: " + request);
             }
-        } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) {
+        } else if (!request.startsWith("POST ")
+                && !request.startsWith("PUT ")
+                && !request.startsWith("PATCH ")
+                && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
             throw new UnsupportedOperationException("Unexpected method: " + request);
         }
 
-        return new RecordedRequest(request, headers, chunkSizes,
-                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
+        return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
+                requestBody.toByteArray(), sequenceNumber, socket);
     }
 
     private void writeResponse(OutputStream out, MockResponse response) throws IOException {
-        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
-        for (String header : response.getHeaders()) {
-            out.write((header + "\r\n").getBytes(ASCII));
+        out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII));
+        List<String> headers = response.getHeaders();
+        for (int i = 0, size = headers.size(); i < size; i++) {
+            String header = headers.get(i);
+            out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII));
         }
-        out.write(("\r\n").getBytes(ASCII));
+        out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
         out.flush();
 
-        final InputStream in = response.getBodyStream();
-        if (in == null) {
-            return;
-        }
-        final int bytesPerSecond = response.getBytesPerSecond();
-
-        // Stream data in MTU-sized increments
-        final byte[] buffer = new byte[1452];
-        final long delayMs;
-        if (bytesPerSecond == Integer.MAX_VALUE) {
-            delayMs = 0;
-        } else {
-            delayMs = (1000 * buffer.length) / bytesPerSecond;
-        }
-
-        int read;
-        long sinceDelay = 0;
-        while ((read = in.read(buffer)) != -1) {
-            out.write(buffer, 0, read);
-            out.flush();
-
-            sinceDelay += read;
-            if (sinceDelay >= buffer.length && delayMs > 0) {
-                sinceDelay %= buffer.length;
-                try {
-                    Thread.sleep(delayMs);
-                } catch (InterruptedException e) {
-                    throw new AssertionError();
-                }
-            }
-        }
+        InputStream in = response.getBodyStream();
+        if (in == null) return;
+        throttledTransfer(response, in, out, Long.MAX_VALUE);
     }
 
     /**
      * Transfer bytes from {@code in} to {@code out} until either {@code length}
-     * bytes have been transferred or {@code in} is exhausted.
+     * bytes have been transferred or {@code in} is exhausted. The transfer is
+     * throttled according to {@code throttlePolicy}.
      */
-    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
+    private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
+            long limit) throws IOException {
         byte[] buffer = new byte[1024];
-        while (length > 0) {
-            int count = in.read(buffer, 0, Math.min(buffer.length, length));
-            if (count == -1) {
-                return;
+        int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
+        long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
+
+        while (true) {
+            for (int b = 0; b < bytesPerPeriod; ) {
+                int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
+                int read = in.read(buffer, 0, toRead);
+                if (read == -1) return;
+
+                out.write(buffer, 0, read);
+                out.flush();
+                b += read;
+                limit -= read;
+
+                if (limit == 0) return;
             }
-            out.write(buffer, 0, count);
-            length -= count;
+
+            try {
+                if (delayMs != 0) Thread.sleep(delayMs);
+            } catch (InterruptedException e) {
+                throw new AssertionError();
+            }
         }
     }
 
diff --git a/src/main/java/com/google/mockwebserver/QueueDispatcher.java b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
index bc26694..a95089b 100644
--- a/src/main/java/com/google/mockwebserver/QueueDispatcher.java
+++ b/src/main/java/com/google/mockwebserver/QueueDispatcher.java
@@ -45,14 +45,11 @@
         return responseQueue.take();
     }
 
-    @Override public SocketPolicy peekSocketPolicy() {
+    @Override public MockResponse peek() {
         MockResponse peek = responseQueue.peek();
-        if (peek == null) {
-            return failFastResponse != null
-                    ? failFastResponse.getSocketPolicy()
-                    : SocketPolicy.KEEP_OPEN;
-        }
-        return peek.getSocketPolicy();
+        if (peek != null) return peek;
+        if (failFastResponse != null) return failFastResponse;
+        return super.peek();
     }
 
     public void enqueueResponse(MockResponse response) {