Merge "Retry more errors in Volley's BasicNetwork."
diff --git a/src/main/java/com/android/volley/ClientError.java b/src/main/java/com/android/volley/ClientError.java
new file mode 100644
index 0000000..a8c8141
--- /dev/null
+++ b/src/main/java/com/android/volley/ClientError.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2015 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.volley;
+
+/**
+ * Signals that the server answered with an error response attributing the failure to the client
+ * (for example, an HTTP 4xx status code).
+ *
+ * <p>Extends {@link ServerError} for backwards compatibility: ServerError used to be thrown for
+ * every error status, including the 4xx codes that indicate a client error.
+ */
+@SuppressWarnings("serial")
+public class ClientError extends ServerError {
+    /** Creates a client error that carries no network response. */
+    public ClientError() {
+        super();
+    }
+
+    /**
+     * Creates a client error for the given error response.
+     *
+     * @param networkResponse the response received from the server
+     */
+    public ClientError(NetworkResponse networkResponse) {
+        super(networkResponse);
+    }
+}
+
diff --git a/src/main/java/com/android/volley/Request.java b/src/main/java/com/android/volley/Request.java
index 5b42d1f..8200f6e 100644
--- a/src/main/java/com/android/volley/Request.java
+++ b/src/main/java/com/android/volley/Request.java
@@ -20,7 +20,6 @@
 import android.net.Uri;
 import android.os.Handler;
 import android.os.Looper;
-import android.os.SystemClock;
 import android.text.TextUtils;
 
 import com.android.volley.VolleyLog.MarkerLog;
@@ -90,6 +89,9 @@
     /** Whether or not a response has been delivered for this request yet. */
     private boolean mResponseDelivered = false;
 
+    /** Whether a 5xx (server error) response should trigger a retry; off unless enabled. */
+    private boolean mShouldRetryServerErrors = false;
+
     /** The retry policy for this request. */
     private RetryPolicy mRetryPolicy;
 
@@ -474,6 +476,26 @@
     }
 
     /**
+     * Sets whether this request should be retried when the server responds with an HTTP 5xx error.
+     *
+     * @param shouldRetryServerErrors {@code true} to retry 5xx responses, {@code false} otherwise
+     * @return This Request object to allow for chaining.
+     */
+    public final Request<?> setShouldRetryServerErrors(boolean shouldRetryServerErrors) {
+        this.mShouldRetryServerErrors = shouldRetryServerErrors;
+        return this;
+    }
+
+    /**
+     * Reports whether this request should be retried when the server responds with an HTTP 5xx error.
+     *
+     * @return {@code true} if 5xx responses should be retried; {@code false} by default
+     */
+    public final boolean shouldRetryServerErrors() {
+        return this.mShouldRetryServerErrors;
+    }
+
+    /**
      * Priority values.  Requests will be processed from higher priorities to
      * lower priorities, in FIFO order.
      */
diff --git a/src/main/java/com/android/volley/toolbox/BasicNetwork.java b/src/main/java/com/android/volley/toolbox/BasicNetwork.java
index 4b1603b..37c35ec 100644
--- a/src/main/java/com/android/volley/toolbox/BasicNetwork.java
+++ b/src/main/java/com/android/volley/toolbox/BasicNetwork.java
@@ -21,6 +21,7 @@
 import com.android.volley.AuthFailureError;
 import com.android.volley.Cache;
 import com.android.volley.Cache.Entry;
+import com.android.volley.ClientError;
 import com.android.volley.Network;
 import com.android.volley.NetworkError;
 import com.android.volley.NetworkResponse;
@@ -143,14 +144,14 @@
             } catch (MalformedURLException e) {
                 throw new RuntimeException("Bad URL " + request.getUrl(), e);
             } catch (IOException e) {
-                int statusCode = 0;
-                NetworkResponse networkResponse = null;
+                int statusCode;
                 if (httpResponse != null) {
                     statusCode = httpResponse.getStatusLine().getStatusCode();
                 } else {
                     throw new NoConnectionError(e);
                 }
                 VolleyLog.e("Unexpected response code %d for %s", statusCode, request.getUrl());
+                NetworkResponse networkResponse;
                 if (responseContents != null) {
                     networkResponse = new NetworkResponse(statusCode, responseContents,
                             responseHeaders, false, SystemClock.elapsedRealtime() - requestStart);
@@ -158,12 +159,24 @@
                             statusCode == HttpStatus.SC_FORBIDDEN) {
                         attemptRetryOnException("auth",
                                 request, new AuthFailureError(networkResponse));
+                    } else if (statusCode >= 400 && statusCode <= 499) {
+                        // Don't retry other client errors.
+                        throw new ClientError(networkResponse);
+                    } else if (statusCode >= 500 && statusCode <= 599) {
+                        if (request.shouldRetryServerErrors()) {
+                            attemptRetryOnException("server",
+                                    request, new ServerError(networkResponse));
+                        } else {
+                            throw new ServerError(networkResponse);
+                        }
                     } else {
-                        // TODO: Only throw ServerError for 5xx status codes.
+                        // 3xx or any other unexpected status: no reason to retry.
                         throw new ServerError(networkResponse);
                     }
                 } else {
-                    throw new NetworkError(networkResponse);
+                    // A status code was received but no response body was read. Treat it as a
+                    // network failure for the retry policy, keeping the IOException as the cause.
+                    attemptRetryOnException("network", request, new NetworkError(e));
                 }
             }
         }
diff --git a/src/test/java/com/android/volley/mock/MockHttpStack.java b/src/test/java/com/android/volley/mock/MockHttpStack.java
index 9594fde..91872d3 100644
--- a/src/test/java/com/android/volley/mock/MockHttpStack.java
+++ b/src/test/java/com/android/volley/mock/MockHttpStack.java
@@ -22,6 +22,7 @@
 
 import org.apache.http.HttpResponse;
 
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -29,6 +30,8 @@
 
     private HttpResponse mResponseToReturn;
 
+    private IOException mExceptionToThrow;
+
     private String mLastUrl;
 
     private Map<String, String> mLastHeaders;
@@ -51,9 +54,21 @@
         mResponseToReturn = response;
     }
 
+    /**
+     * Makes {@link #performRequest} throw the given exception instead of returning a response.
+     * Passing {@code null} restores the normal behavior.
+     */
+    public void setExceptionToThrow(IOException exception) {
+        this.mExceptionToThrow = exception;
+    }
+
     @Override
     public HttpResponse performRequest(Request<?> request, Map<String, String> additionalHeaders)
-            throws AuthFailureError {
+            throws IOException, AuthFailureError {
+        IOException pendingFailure = mExceptionToThrow;
+        if (pendingFailure != null) {
+            throw pendingFailure;
+        }
         mLastUrl = request.getUrl();
         mLastHeaders = new HashMap<String, String>();
         if (request.getHeaders() != null) {
diff --git a/src/test/java/com/android/volley/toolbox/BasicNetworkTest.java b/src/test/java/com/android/volley/toolbox/BasicNetworkTest.java
index 89718b1..c01d9b0 100644
--- a/src/test/java/com/android/volley/toolbox/BasicNetworkTest.java
+++ b/src/test/java/com/android/volley/toolbox/BasicNetworkTest.java
@@ -16,29 +16,44 @@
 
 package com.android.volley.toolbox;
 
+import com.android.volley.AuthFailureError;
 import com.android.volley.NetworkResponse;
 import com.android.volley.Request;
 import com.android.volley.Response;
+import com.android.volley.RetryPolicy;
+import com.android.volley.ServerError;
+import com.android.volley.TimeoutError;
+import com.android.volley.VolleyError;
 import com.android.volley.mock.MockHttpStack;
 
 import org.apache.http.ProtocolVersion;
+import org.apache.http.conn.ConnectTimeoutException;
 import org.apache.http.entity.StringEntity;
 import org.apache.http.message.BasicHttpResponse;
-
-import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Mock;
 import org.robolectric.RobolectricTestRunner;
 
-import static org.junit.Assert.*;
-
+import java.io.IOException;
+import java.net.SocketTimeoutException;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+import static org.mockito.MockitoAnnotations.initMocks;
+
 @RunWith(RobolectricTestRunner.class)
 public class BasicNetworkTest {
 
+    @Mock private RetryPolicy mMockRetryPolicy;
+
+    @Before public void setUp() throws Exception {
+        initMocks(this);
+    }
+
     @Test public void headersAndPostParams() throws Exception {
         MockHttpStack mockHttpStack = new MockHttpStack();
         BasicHttpResponse fakeResponse = new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1),
@@ -46,7 +61,188 @@
         fakeResponse.setEntity(new StringEntity("foobar"));
         mockHttpStack.setResponseToReturn(fakeResponse);
         BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
-        Request<String> request = new Request<String>(Request.Method.GET, "http://foo", null) {
+        Request<String> request = buildRequest();
+        httpNetwork.performRequest(request);
+        assertEquals("foo", mockHttpStack.getLastHeaders().get("requestheader"));
+        assertEquals("requestpost=foo&", new String(mockHttpStack.getLastPostBody()));
+    }
+
+    @Test public void socketTimeout() throws Exception {
+        MockHttpStack mockHttpStack = new MockHttpStack();
+        mockHttpStack.setExceptionToThrow(new SocketTimeoutException());
+        BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+        Request<String> request = buildRequest();
+        request.setRetryPolicy(mMockRetryPolicy);
+        doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+        try {
+            httpNetwork.performRequest(request);
+        } catch (VolleyError e) {
+            // expected
+        }
+        // should retry socket timeouts
+        verify(mMockRetryPolicy).retry(any(TimeoutError.class));
+    }
+
+    @Test public void connectTimeout() throws Exception {
+        MockHttpStack mockHttpStack = new MockHttpStack();
+        mockHttpStack.setExceptionToThrow(new ConnectTimeoutException());
+        BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+        Request<String> request = buildRequest();
+        request.setRetryPolicy(mMockRetryPolicy);
+        doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+        try {
+            httpNetwork.performRequest(request);
+        } catch (VolleyError e) {
+            // expected
+        }
+        // should retry connection timeouts
+        verify(mMockRetryPolicy).retry(any(TimeoutError.class));
+    }
+
+    @Test public void noConnection() throws Exception {
+        MockHttpStack mockHttpStack = new MockHttpStack();
+        mockHttpStack.setExceptionToThrow(new IOException());
+        BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+        Request<String> request = buildRequest();
+        request.setRetryPolicy(mMockRetryPolicy);
+        doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+        try {
+            httpNetwork.performRequest(request);
+        } catch (VolleyError e) {
+            // expected
+        }
+        // should not retry when there is no connection
+        verify(mMockRetryPolicy, never()).retry(any(VolleyError.class));
+    }
+
+    @Test public void unauthorized() throws Exception {
+        MockHttpStack mockHttpStack = new MockHttpStack();
+        BasicHttpResponse fakeResponse = new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1),
+                401, "Unauthorized");
+        mockHttpStack.setResponseToReturn(fakeResponse);
+        BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+        Request<String> request = buildRequest();
+        request.setRetryPolicy(mMockRetryPolicy);
+        doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+        try {
+            httpNetwork.performRequest(request);
+        } catch (VolleyError e) {
+            // expected
+        }
+        // should retry in case it's an auth failure.
+        verify(mMockRetryPolicy).retry(any(AuthFailureError.class));
+    }
+
+    @Test public void forbidden() throws Exception {
+        MockHttpStack mockHttpStack = new MockHttpStack();
+        BasicHttpResponse fakeResponse = new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1),
+                403, "Forbidden");
+        mockHttpStack.setResponseToReturn(fakeResponse);
+        BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+        Request<String> request = buildRequest();
+        request.setRetryPolicy(mMockRetryPolicy);
+        doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+        try {
+            httpNetwork.performRequest(request);
+        } catch (VolleyError e) {
+            // expected
+        }
+        // should retry in case it's an auth failure.
+        verify(mMockRetryPolicy).retry(any(AuthFailureError.class));
+    }
+
+    @Test public void redirect() throws Exception {
+        for (int i = 300; i <= 399; i++) {
+            MockHttpStack mockHttpStack = new MockHttpStack();
+            BasicHttpResponse fakeResponse =
+                    new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1), i, "");
+            mockHttpStack.setResponseToReturn(fakeResponse);
+            BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+            Request<String> request = buildRequest();
+            request.setRetryPolicy(mMockRetryPolicy);
+            doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+            try {
+                httpNetwork.performRequest(request);
+            } catch (VolleyError e) {
+                // expected
+            }
+            // should not retry 3xx responses.
+            verify(mMockRetryPolicy, never()).retry(any(VolleyError.class));
+            reset(mMockRetryPolicy);
+        }
+    }
+
+    @Test public void otherClientError() throws Exception {
+        for (int i = 400; i <= 499; i++) {
+            if (i == 401 || i == 403) {
+                // 401 and 403 are covered by unauthorized() and forbidden().
+                continue;
+            }
+            MockHttpStack mockHttpStack = new MockHttpStack();
+            BasicHttpResponse fakeResponse =
+                    new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1), i, "");
+            mockHttpStack.setResponseToReturn(fakeResponse);
+            BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+            Request<String> request = buildRequest();
+            request.setRetryPolicy(mMockRetryPolicy);
+            doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+            try {
+                httpNetwork.performRequest(request);
+            } catch (VolleyError e) {
+                // expected
+            }
+            // should not retry other 4xx errors.
+            verify(mMockRetryPolicy, never()).retry(any(VolleyError.class));
+            reset(mMockRetryPolicy);
+        }
+    }
+
+    @Test public void serverError_enableRetries() throws Exception {
+        for (int i = 500; i <= 599; i++) {
+            MockHttpStack mockHttpStack = new MockHttpStack();
+            BasicHttpResponse fakeResponse =
+                    new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1), i, "");
+            mockHttpStack.setResponseToReturn(fakeResponse);
+            BasicNetwork httpNetwork =
+                    new BasicNetwork(mockHttpStack, new ByteArrayPool(4096));
+            Request<String> request = buildRequest();
+            request.setRetryPolicy(mMockRetryPolicy);
+            request.setShouldRetryServerErrors(true);
+            doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+            try {
+                httpNetwork.performRequest(request);
+            } catch (VolleyError e) {
+                // expected
+            }
+            // should retry all 5xx errors
+            verify(mMockRetryPolicy).retry(any(ServerError.class));
+            reset(mMockRetryPolicy);
+        }
+    }
+
+    @Test public void serverError_disableRetries() throws Exception {
+        for (int i = 500; i <= 599; i++) {
+            MockHttpStack mockHttpStack = new MockHttpStack();
+            BasicHttpResponse fakeResponse =
+                    new BasicHttpResponse(new ProtocolVersion("HTTP", 1, 1), i, "");
+            mockHttpStack.setResponseToReturn(fakeResponse);
+            BasicNetwork httpNetwork = new BasicNetwork(mockHttpStack);
+            Request<String> request = buildRequest();
+            request.setRetryPolicy(mMockRetryPolicy);
+            doThrow(new VolleyError()).when(mMockRetryPolicy).retry(any(VolleyError.class));
+            try {
+                httpNetwork.performRequest(request);
+            } catch (VolleyError e) {
+                // expected
+            }
+            // should not retry any 5xx error with server-error retries turned off (the default).
+            verify(mMockRetryPolicy, never()).retry(any(VolleyError.class));
+            reset(mMockRetryPolicy);
+        }
+    }
+
+    private static Request<String> buildRequest() {
+        return new Request<String>(Request.Method.GET, "http://foo", null) {
 
             @Override
             protected Response<String> parseNetworkResponse(NetworkResponse response) {
@@ -71,8 +267,5 @@
                 return result;
             }
         };
-        httpNetwork.performRequest(request);
-        assertEquals("foo", mockHttpStack.getLastHeaders().get("requestheader"));
-        assertEquals("requestpost=foo&", new String(mockHttpStack.getLastPostBody()));
     }
 }