Replace the Request/Response Context after each ServletModule-registered
Filter.  This fixes problems where wrapped request/response classes weren't passed to subsequent filters or servlets in the chain.

Revision created by MOE tool push_codebase.
MOE_MIGRATION=3340


git-svn-id: https://google-guice.googlecode.com/svn/trunk@1585 d779f126-a31b-0410-b53b-1d3aecad763e
diff --git a/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java b/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java
index 4669875..262d263 100755
--- a/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java
+++ b/extensions/servlet/src/com/google/inject/servlet/FilterChainInvocation.java
@@ -21,6 +21,8 @@
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 /**
  * A Filter chain impl which basically passes itself to the "current" filter and iterates the chain
@@ -54,18 +56,27 @@
       throws IOException, ServletException {
     index++;
 
-    //dispatch down the chain while there are more filters
-    if (index < filterDefinitions.length) {
-      filterDefinitions[index].doFilter(servletRequest, servletResponse, this);
-    } else {
+    GuiceFilter.Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest request = (HttpServletRequest) servletRequest;
+    HttpServletResponse response = (HttpServletResponse) servletResponse;
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+    GuiceFilter.localContext.set(new GuiceFilter.Context(originalRequest, request, response));
+    try {
+      //dispatch down the chain while there are more filters
+      if (index < filterDefinitions.length) {
+        filterDefinitions[index].doFilter(servletRequest, servletResponse, this);
+      } else {
+        //we've reached the end of the filterchain, let's try to dispatch to a servlet
+        final boolean serviced = servletPipeline.service(servletRequest, servletResponse);
 
-      //we've reached the end of the filterchain, let's try to dispatch to a servlet
-      final boolean serviced = servletPipeline.service(servletRequest, servletResponse);
-
-      //dispatch to the normal filter chain only if one of our servlets did not match
-      if (!serviced) {
-        proceedingChain.doFilter(servletRequest, servletResponse);
+        //dispatch to the normal filter chain only if one of our servlets did not match
+        if (!serviced) {
+          proceedingChain.doFilter(servletRequest, servletResponse);
+        }
       }
+    } finally {
+      GuiceFilter.localContext.set(previous);
     }
   }
 }
diff --git a/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java b/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java
index 188adfa..4b602b5 100644
--- a/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java
+++ b/extensions/servlet/src/com/google/inject/servlet/GuiceFilter.java
@@ -96,29 +96,34 @@
   //VisibleForTesting
   static void reset() {
     pipeline = new DefaultFilterPipeline();
+    localContext.remove();
   }
 
   public void doFilter(ServletRequest servletRequest,
       ServletResponse servletResponse, FilterChain filterChain)
       throws IOException, ServletException {
 
-    Context previous = localContext.get();
-
     // Prefer the injected pipeline, but fall back on the static one for web.xml users.
     FilterPipeline filterPipeline = null != injectedPipeline ? injectedPipeline : pipeline;
 
+    Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest request = (HttpServletRequest) servletRequest;
+    HttpServletResponse response = (HttpServletResponse) servletResponse;
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+    localContext.set(new Context(originalRequest, request, response));
     try {
-      localContext.set(new Context((HttpServletRequest) servletRequest,
-          (HttpServletResponse) servletResponse));
-
       //dispatch across the servlet pipeline, ensuring web.xml's filterchain is honored
       filterPipeline.dispatch(servletRequest, servletResponse, filterChain);
-
     } finally {
       localContext.set(previous);
     }
   }
 
+  static HttpServletRequest getOriginalRequest() {
+    return getContext().getOriginalRequest();
+  }
+
   static HttpServletRequest getRequest() {
     return getContext().getRequest();
   }
@@ -131,7 +136,7 @@
     return servletContext.get();
   }
 
-  static Context getContext() {
+  private static Context getContext() {
     Context context = localContext.get();
     if (context == null) {
       throw new OutOfScopeException("Cannot access scoped object. Either we"
@@ -143,15 +148,21 @@
   }
 
   static class Context {
-
+    final HttpServletRequest originalRequest;
     final HttpServletRequest request;
     final HttpServletResponse response;
 
-    Context(HttpServletRequest request, HttpServletResponse response) {
+    Context(HttpServletRequest originalRequest, HttpServletRequest request,
+        HttpServletResponse response) {
+      this.originalRequest = originalRequest;
       this.request = request;
       this.response = response;
     }
 
+    HttpServletRequest getOriginalRequest() {
+      return originalRequest;
+    }
+
     HttpServletRequest getRequest() {
       return request;
     }
diff --git a/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java b/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java
index 1d3ae60..62da7c7 100755
--- a/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java
+++ b/extensions/servlet/src/com/google/inject/servlet/ManagedFilterPipeline.java
@@ -137,8 +137,6 @@
   private ServletRequest withDispatcher(ServletRequest servletRequest,
       final ManagedServletPipeline servletPipeline) {
 
-    HttpServletRequest request = (HttpServletRequest) servletRequest;
-
     // don't wrap the request if there are no servlets mapped. This prevents us from inserting our
     // wrapper unless it's actually going to be used. This is necessary for compatibility for apps
     // that downcast their HttpServletRequests to a concrete implementation.
@@ -146,6 +144,7 @@
       return servletRequest;
     }
 
+    HttpServletRequest request = (HttpServletRequest) servletRequest;
     //noinspection OverlyComplexAnonymousInnerClass
     return new HttpServletRequestWrapper(request) {
 
diff --git a/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java b/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
index 08a26bf..e3e35df 100755
--- a/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
+++ b/extensions/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
@@ -139,21 +139,20 @@
               requestToProcess = servletRequest;
             }
 
-            servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
-
             // now dispatch to the servlet
-            try {
-              servletDefinition.doService(requestToProcess, servletResponse);
-            } finally {
-              servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
-            }
+            doServiceImpl(servletDefinition, requestToProcess, servletResponse);
           }
 
           public void include(ServletRequest servletRequest, ServletResponse servletResponse)
               throws ServletException, IOException {
+            // route to the target servlet
+            doServiceImpl(servletDefinition, servletRequest, servletResponse);
+          }
+
+          private void doServiceImpl(ServletDefinition servletDefinition, ServletRequest servletRequest,
+              ServletResponse servletResponse) throws ServletException, IOException {
             servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
 
-            // route to the target servlet
             try {
               servletDefinition.doService(servletRequest, servletResponse);
             } finally {
diff --git a/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java b/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java
index 995561d..00ac328 100755
--- a/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java
+++ b/extensions/servlet/src/com/google/inject/servlet/ServletDefinition.java
@@ -41,6 +41,7 @@
 import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
+import javax.servlet.http.HttpServletResponse;
 
 /**
  * An internal representation of a servlet definition mapped to a particular URI pattern. Also
@@ -195,19 +196,20 @@
 
     HttpServletRequest request = new HttpServletRequestWrapper(
         (HttpServletRequest) servletRequest) {
+      private boolean pathComputed;
       private String path;
-      private boolean pathComputed = false;
-      //must use a boolean on the memo field, because null is a legal value (TODO no, it's not)
 
-      private boolean pathInfoComputed = false;
+      private boolean pathInfoComputed;
       private String pathInfo;
 
       @Override
       public String getPathInfo() {
         if (!isPathInfoComputed()) {
           int servletPathLength = getServletPath().length();
-          pathInfo = getRequestURI().substring(getContextPath().length()).replaceAll("[/]{2,}", "/");
-          pathInfo = pathInfo.length() > servletPathLength ? pathInfo.substring(servletPathLength) : null;
+          pathInfo = getRequestURI().substring(getContextPath().length())
+              .replaceAll("[/]{2,}", "/");
+          pathInfo = pathInfo.length() > servletPathLength
+              ? pathInfo.substring(servletPathLength) : null;
 
           // Corner case: when servlet path and request path match exactly (without trailing '/'),
           // then pathinfo is null
@@ -221,8 +223,10 @@
         return pathInfo;
       }
 
-      // NOTE(dhanji): These two are a bit of a hack to help ensure that request dipatcher-sent
+      // NOTE(dhanji): These two are a bit of a hack to help ensure that request dispatcher-sent
       // requests don't use the same path info that was memoized for the original request.
+      // NOTE(iqshum): I don't think this is possible, since the dispatcher-sent request would
+      // perform its own wrapping.
       private boolean isPathInfoComputed() {
         return pathInfoComputed
             && !(null != servletRequest.getAttribute(REQUEST_DISPATCHER_REQUEST));
@@ -261,7 +265,20 @@
       }
     };
 
-    httpServlet.get().service(request, servletResponse);
+    doServiceImpl(request, (HttpServletResponse) servletResponse);
+  }
+
+  private void doServiceImpl(HttpServletRequest request, HttpServletResponse response)
+      throws ServletException, IOException {
+    GuiceFilter.Context previous = GuiceFilter.localContext.get();
+    HttpServletRequest originalRequest
+        = (previous != null) ? previous.getOriginalRequest() : request;
+    GuiceFilter.localContext.set(new GuiceFilter.Context(originalRequest, request, response));
+    try {
+      httpServlet.get().service(request, response);
+    } finally {
+      GuiceFilter.localContext.set(previous);
+    }
   }
 
   String getKey() {
diff --git a/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java b/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
index 27a6766..bd82dcf 100644
--- a/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
+++ b/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
@@ -17,6 +17,7 @@
 package com.google.inject.servlet;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
 import com.google.inject.Key;
 import com.google.inject.OutOfScopeException;
@@ -27,6 +28,7 @@
 import java.util.concurrent.Callable;
 
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpSession;
 
 /**
@@ -38,6 +40,12 @@
 
   private ServletScopes() {}
 
+  /** Keys bound in request-scope which are handled directly by GuiceFilter. */
+  private static final ImmutableSet<Key<?>> REQUEST_CONTEXT_KEYS = ImmutableSet.of(
+      Key.get(HttpServletRequest.class),
+      Key.get(HttpServletResponse.class),
+      new Key<Map<String, String[]>>(RequestParameters.class) {});
+
   /** A sentinel attribute value representing null. */
   enum NullObject { INSTANCE }
 
@@ -45,7 +53,7 @@
    * HTTP servlet request scope.
    */
   public static final Scope REQUEST = new Scope() {
-    public <T> Provider<T> scope(Key<T> key, final Provider<T> creator) {
+    public <T> Provider<T> scope(final Key<T> key, final Provider<T> creator) {
       final String name = key.toString();
       return new Provider<T>() {
         public T get() {
@@ -77,8 +85,17 @@
               // exception is thrown.
           }
 
-          HttpServletRequest request = GuiceFilter.getRequest();
-
+          // Always synchronize and get/set attributes on the underlying request
+          // object since Filters may wrap the request and change the value of
+          // {@code GuiceFilter.getRequest()}.
+          //
+          // This _correctly_ throws up if the thread is out of scope.
+          HttpServletRequest request = GuiceFilter.getOriginalRequest();
+          if (REQUEST_CONTEXT_KEYS.contains(key)) {
+            // Don't store these keys as attributes, since they are handled by
+            // GuiceFilter itself.
+            return creator.get();
+          }
           synchronized (request) {
             Object obj = request.getAttribute(name);
             if (NullObject.INSTANCE == obj) {
@@ -182,7 +199,7 @@
     }
 
     return new Callable<T>() {
-      private HttpServletRequest request = continuingRequest;
+      private final HttpServletRequest request = continuingRequest;
 
       public T call() throws Exception {
         GuiceFilter.Context context = GuiceFilter.localContext.get();
@@ -191,7 +208,7 @@
 
         // Only set up the request continuation if we're running in a
         // new vanilla thread.
-        GuiceFilter.localContext.set(new GuiceFilter.Context(request, null));
+        GuiceFilter.localContext.set(new GuiceFilter.Context(request, request, null));
         try {
           return callable.call();
         } finally {
diff --git a/extensions/servlet/test/com/google/inject/servlet/ServletTest.java b/extensions/servlet/test/com/google/inject/servlet/ServletTest.java
index 9532fb0..a0ac426 100644
--- a/extensions/servlet/test/com/google/inject/servlet/ServletTest.java
+++ b/extensions/servlet/test/com/google/inject/servlet/ServletTest.java
@@ -24,13 +24,17 @@
 import static java.lang.annotation.RetentionPolicy.RUNTIME;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.inject.AbstractModule;
 import com.google.inject.BindingAnnotation;
 import com.google.inject.CreationException;
 import com.google.inject.Guice;
+import com.google.inject.Inject;
 import com.google.inject.Injector;
 import com.google.inject.Key;
+import com.google.inject.Module;
+import com.google.inject.Provider;
 import com.google.inject.util.Providers;
 
 import junit.framework.TestCase;
@@ -44,13 +48,17 @@
 import java.lang.reflect.Proxy;
 import java.util.Map;
 
+import javax.servlet.Filter;
 import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpServletResponseWrapper;
 import javax.servlet.http.HttpSession;
 
 /**
@@ -77,7 +85,6 @@
     final Injector injector = createInjector();
     final HttpServletRequest request = newFakeHttpServletRequest();
     final HttpServletResponse response = newFakeHttpServletResponse();
-    final Map<String, String[]> params = Maps.newHashMap();
 
     final boolean[] invoked = new boolean[1];
     GuiceFilter filter = new GuiceFilter();
@@ -101,6 +108,148 @@
     assertTrue(invoked[0]);
   }
 
+  public void testRequestAndResponseBindings_wrappingFilter() throws Exception {
+    final HttpServletRequest request = newFakeHttpServletRequest();
+    final ImmutableMap<String, String[]> wrappedParamMap
+        = ImmutableMap.of("wrap", new String[]{"a", "b"});
+    final HttpServletRequestWrapper requestWrapper = new HttpServletRequestWrapper(request) {
+      @Override public Map getParameterMap() {
+        return wrappedParamMap;
+      }
+
+      @Override public Object getAttribute(String attr) {
+        // Ensure that attributes are stored on the original request object.
+        throw new UnsupportedOperationException();
+      }
+    };
+    final HttpServletResponse response = newFakeHttpServletResponse();
+    final HttpServletResponseWrapper responseWrapper = new HttpServletResponseWrapper(response);
+
+    final boolean[] filterInvoked = new boolean[1];
+    final Injector injector = createInjector(new ServletModule() {
+      @Override protected void configureServlets() {
+        filter("/*").through(new Filter() {
+          @Inject Provider<ServletRequest> servletReqProvider;
+          @Inject Provider<HttpServletRequest> reqProvider;
+          @Inject Provider<ServletResponse> servletRespProvider;
+          @Inject Provider<HttpServletResponse> respProvider;
+
+          public void init(FilterConfig filterConfig) {}
+
+          public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
+              throws IOException, ServletException {
+            filterInvoked[0] = true;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+
+            chain.doFilter(requestWrapper, responseWrapper);
+
+            assertSame(req, reqProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+
+          public void destroy() {}
+        });
+      }
+    });
+
+    GuiceFilter filter = new GuiceFilter();
+    final boolean[] chainInvoked = new boolean[1];
+    FilterChain filterChain = new FilterChain() {
+      public void doFilter(ServletRequest servletRequest,
+          ServletResponse servletResponse) {
+        chainInvoked[0] = true;
+        assertSame(requestWrapper, servletRequest);
+        assertSame(requestWrapper, injector.getInstance(ServletRequest.class));
+        assertSame(requestWrapper, injector.getInstance(HTTP_REQ_KEY));
+
+        assertSame(responseWrapper, servletResponse);
+        assertSame(responseWrapper, injector.getInstance(ServletResponse.class));
+        assertSame(responseWrapper, injector.getInstance(HTTP_RESP_KEY));
+
+        assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
+
+        InRequest inRequest = injector.getInstance(InRequest.class);
+        assertSame(inRequest, injector.getInstance(InRequest.class));
+      }
+    };
+    filter.doFilter(request, response, filterChain);
+
+    assertTrue(chainInvoked[0]);
+    assertTrue(filterInvoked[0]);
+  }
+
+  public void testRequestAndResponseBindings_matchesPassedParameters() throws Exception {
+    final int[] filterInvoked = new int[1];
+    final boolean[] servletInvoked = new boolean[1];
+    final Injector injector = createInjector(new ServletModule() {
+      @Override protected void configureServlets() {
+        final HttpServletRequest[] previousReq = new HttpServletRequest[1];
+        final HttpServletResponse[] previousResp = new HttpServletResponse[1];
+
+        final Provider<ServletRequest> servletReqProvider = getProvider(ServletRequest.class);
+        final Provider<HttpServletRequest> reqProvider = getProvider(HttpServletRequest.class);
+        final Provider<ServletResponse> servletRespProvider = getProvider(ServletResponse.class);
+        final Provider<HttpServletResponse> respProvider = getProvider(HttpServletResponse.class);
+
+        Filter filter = new Filter() {
+          public void init(FilterConfig filterConfig) {}
+
+          public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
+              throws IOException, ServletException {
+            filterInvoked[0]++;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+            if (previousReq[0] != null) {
+              assertEquals(req, previousReq[0]);
+            }
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+            if (previousResp[0] != null) {
+              assertEquals(resp, previousResp[0]);
+            }
+
+            chain.doFilter(
+                previousReq[0] = new HttpServletRequestWrapper((HttpServletRequest) req),
+                previousResp[0] = new HttpServletResponseWrapper((HttpServletResponse) resp));
+
+            assertSame(req, reqProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+
+          public void destroy() {}
+        };
+
+        filter("/*").through(filter);
+        filter("/*").through(filter);  // filter twice to test wrapping in filters
+        serve("/*").with(new HttpServlet() {
+          @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
+            servletInvoked[0] = true;
+            assertSame(req, servletReqProvider.get());
+            assertSame(req, reqProvider.get());
+
+            assertSame(resp, servletRespProvider.get());
+            assertSame(resp, respProvider.get());
+          }
+        });
+      }
+    });
+
+    GuiceFilter filter = new GuiceFilter();
+    filter.doFilter(newFakeHttpServletRequest(), newFakeHttpServletResponse(), new FilterChain() {
+      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
+        throw new IllegalStateException("Shouldn't get here");
+      }
+    });
+
+    assertEquals(2, filterInvoked[0]);
+    assertTrue(servletInvoked[0]);
+  }
+
   public void testNewRequestObject()
       throws CreationException, IOException, ServletException {
     final Injector injector = createInjector();
@@ -240,6 +389,10 @@
       final Map<String, Object> attributes = Maps.newHashMap(); 
       final HttpSession session = newFakeHttpSession();
 
+      @Override public String getMethod() {
+        return "GET";
+      }
+
       @Override public Object getAttribute(String name) {
         return attributes.get(name);
       }
@@ -300,8 +453,8 @@
         new Class[] { HttpSession.class }, new FakeHttpSessionHandler());
   }
 
-  private Injector createInjector() throws CreationException {
-    return Guice.createInjector(new AbstractModule() {
+  private Injector createInjector(Module... modules) throws CreationException {
+    return Guice.createInjector(Lists.<Module>asList(new AbstractModule() {
       @Override
       protected void configure() {
         install(new ServletModule());
@@ -310,7 +463,7 @@
         bind(InRequest.class);
         bind(IN_REQUEST_NULL_KEY).toProvider(Providers.<InRequest>of(null)).in(RequestScoped.class);
       }
-    });
+    }, modules));
   }
 
   @SessionScoped