Andrew McLaughlin's patch for making forward() pass the original request URI. Fixes a bug with JSP servlet etc., not being able to know the target of the original request.

git-svn-id: https://google-guice.googlecode.com/svn/trunk@1015 d779f126-a31b-0410-b53b-1d3aecad763e
diff --git a/servlet/src/com/google/inject/servlet/GuiceFilter.java b/servlet/src/com/google/inject/servlet/GuiceFilter.java
index 0a840c6..348681d 100644
--- a/servlet/src/com/google/inject/servlet/GuiceFilter.java
+++ b/servlet/src/com/google/inject/servlet/GuiceFilter.java
@@ -77,7 +77,7 @@
 
     // Multiple injectors with Servlet pipelines?!
     // We don't throw an exception in DEVELOPMENT stage, to allow for legacy
-    // tests that don't have a tearDown that calls GuiceFilter#reset().
+    // tests that don't have a tearDown that calls GuiceFilter#destroy().
     if (GuiceFilter.pipeline instanceof ManagedFilterPipeline) {
       if (Stage.PRODUCTION.equals(stage)) {
         throw new RuntimeException(MULTIPLE_INJECTORS_ERROR);
diff --git a/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java b/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
index 00faf7c..320229e 100755
--- a/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
+++ b/servlet/src/com/google/inject/servlet/ManagedServletPipeline.java
@@ -35,6 +35,8 @@
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServlet;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
 
 /**
  * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for
@@ -112,13 +114,12 @@
    * the given path or null if no mapping was found.
    */
   RequestDispatcher getRequestDispatcher(String path) {
+    final String newRequestUri = path;
     for (final ServletDefinition servletDefinition : servletDefinitions) {
       if (servletDefinition.shouldServe(path)) {
         return new RequestDispatcher() {
-
           public void forward(ServletRequest servletRequest, ServletResponse servletResponse)
               throws ServletException, IOException {
-
             Preconditions.checkState(!servletResponse.isCommitted(),
                 "Response has been committed--you can only call forward before"
                 + " committing the response (hint: don't flush buffers)");
@@ -126,8 +127,23 @@
             // clear buffer before forwarding
             servletResponse.resetBuffer();
 
+            ServletRequest requestToProcess;
+            if (servletRequest instanceof HttpServletRequest) {
+               requestToProcess =
+                   new HttpServletRequestWrapper((HttpServletRequest) servletRequest) {
+                     public String getRequestURI() {
+                       return newRequestUri;
+                     }
+                   };
+            } else {
+              // This should never happen, but instead of throwing an exception
+              // we will allow a happy case pass thru for maximum tolerance to
+              // legacy (and internal) code.
+              requestToProcess = servletRequest;
+            }
+
             // now dispatch to the servlet
-            servletDefinition.doService(servletRequest, servletResponse);
+            servletDefinition.doService(requestToProcess, servletResponse);
           }
 
           public void include(ServletRequest servletRequest, ServletResponse servletResponse)
diff --git a/servlet/test/com/google/inject/servlet/FilterDispatchIntegrationTest.java b/servlet/test/com/google/inject/servlet/FilterDispatchIntegrationTest.java
index 02d8508..ac6649d 100644
--- a/servlet/test/com/google/inject/servlet/FilterDispatchIntegrationTest.java
+++ b/servlet/test/com/google/inject/servlet/FilterDispatchIntegrationTest.java
@@ -4,22 +4,23 @@
 import com.google.inject.Injector;
 import com.google.inject.Key;
 import com.google.inject.Singleton;
-
-import static org.easymock.EasyMock.createMock;
-import static org.easymock.EasyMock.expect;
-import static org.easymock.EasyMock.replay;
-import static org.easymock.EasyMock.verify;
-
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
 import javax.servlet.FilterConfig;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 import junit.framework.TestCase;
-
+import org.easymock.EasyMock;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import org.easymock.IMocksControl;
 
 /**
  *
@@ -33,12 +34,14 @@
 public class FilterDispatchIntegrationTest extends TestCase {
     private static int inits, doFilters, destroys;
 
+  private IMocksControl control;
+
   @Override
   public final void setUp() {
     inits = 0;
     doFilters = 0;
     destroys = 0;
-
+    control = EasyMock.createControl();
     GuiceFilter.reset();
   }
 
@@ -55,25 +58,44 @@
         // These filters should never fire
         filter("/index/*").through(Key.get(TestFilter.class));
         filter("*.jsp").through(Key.get(TestFilter.class));
+
+        // Bind a servlet
+        serve("*.html").with(TestServlet.class);
       }
     });
 
     final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
     pipeline.initPipeline(null);
 
-    //create ourselves a mock request with test URI
-    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
+    // create ourselves a mock request with test URI
+    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
 
     expect(requestMock.getServletPath())
             .andReturn("/index.html")
             .anyTimes();
+    expect(requestMock.getRequestURI())
+            .andReturn("/index.html")
+            .anyTimes();
+
+    HttpServletResponse responseMock = control.createMock(HttpServletResponse.class);
+    expect(responseMock.isCommitted())
+        .andReturn(false)
+        .anyTimes();
+    responseMock.resetBuffer();
+    expectLastCall().anyTimes();
+
+    FilterChain filterChain = control.createMock(FilterChain.class);
 
     //dispatch request
-    replay(requestMock);
-    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
+    control.replay();
+    pipeline.dispatch(requestMock, responseMock, filterChain);
     pipeline.destroyPipeline();
+    control.verify();
 
-    verify(requestMock);
+    TestServlet servlet = injector.getInstance(TestServlet.class);
+    assertEquals(2, servlet.processedUris.size());
+    assertTrue(servlet.processedUris.contains("/index.html"));
+    assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO));
 
     assert inits == 5 && doFilters == 3 && destroys == 5 : "lifecycle states did not"
           + " fire correct number of times-- inits: " + inits + "; dos: " + doFilters
@@ -99,18 +121,19 @@
     pipeline.initPipeline(null);
 
     //create ourselves a mock request with test URI
-    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
+    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
 
     expect(requestMock.getServletPath())
             .andReturn("/index.xhtml")
             .anyTimes();
 
     //dispatch request
-    replay(requestMock);
-    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
+    FilterChain filterChain = control.createMock(FilterChain.class);
+    filterChain.doFilter(requestMock, null);
+    control.replay();
+    pipeline.dispatch(requestMock, null, filterChain);
     pipeline.destroyPipeline();
-
-    verify(requestMock);
+    control.verify();
 
     assert inits == 5 && doFilters == 0 && destroys == 5 : "lifecycle states did not "
           + "fire correct number of times-- inits: " + inits + "; dos: " + doFilters
@@ -126,7 +149,6 @@
       protected void configureServlets() {
         filterRegex("/[A-Za-z]*").through(TestFilter.class);
         filterRegex("/index").through(TestFilter.class);
-
         //these filters should never fire
         filterRegex("\\w").through(Key.get(TestFilter.class));
       }
@@ -136,18 +158,19 @@
     pipeline.initPipeline(null);
 
     //create ourselves a mock request with test URI
-    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
+    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);
 
     expect(requestMock.getServletPath())
             .andReturn("/index")
             .anyTimes();
 
-    //dispatch request
-    replay(requestMock);
-    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
+    // dispatch request
+    FilterChain filterChain = control.createMock(FilterChain.class);
+    filterChain.doFilter(requestMock, null);
+    control.replay();
+    pipeline.dispatch(requestMock, null, filterChain);
     pipeline.destroyPipeline();
-
-    verify(requestMock);
+    control.verify();
 
     assert inits == 3 && doFilters == 2 && destroys == 3 : "lifecycle states did not fire "
         + "correct number of times-- inits: " + inits + "; dos: " + doFilters
@@ -170,4 +193,28 @@
       destroys++;
     }
   }
+
+  @Singleton
+  public static class TestServlet extends HttpServlet {
+    public static final String FORWARD_FROM = "/index.html";
+    public static final String FORWARD_TO = "/forwarded.html";
+    public List<String> processedUris = new ArrayList<String>();
+
+    protected void service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
+        throws ServletException, IOException {
+      String requestUri = httpServletRequest.getRequestURI();
+      processedUris.add(requestUri);
+      
+      // If the client is requesting /index.html then we forward to /forwarded.html
+      if (FORWARD_FROM.equals(requestUri)) {
+        httpServletRequest.getRequestDispatcher(FORWARD_TO)
+            .forward(httpServletRequest, httpServletResponse);
+      }
+    }
+
+    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
+        throws ServletException, IOException {
+      service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
+    }
+  }
 }
diff --git a/servlet/test/com/google/inject/servlet/ServletPipelineRequestDispatcherTest.java b/servlet/test/com/google/inject/servlet/ServletPipelineRequestDispatcherTest.java
index db8c77c..bbf8c21 100644
--- a/servlet/test/com/google/inject/servlet/ServletPipelineRequestDispatcherTest.java
+++ b/servlet/test/com/google/inject/servlet/ServletPipelineRequestDispatcherTest.java
@@ -24,6 +24,7 @@
 import com.google.inject.internal.Maps;
 import com.google.inject.internal.Sets;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
@@ -130,11 +131,11 @@
     mockResponse.resetBuffer();
     expectLastCall().once();
 
-    final boolean[] run = new boolean[1];
+    final List<String> paths = new ArrayList<String>();
     final HttpServlet mockServlet = new HttpServlet() {
       protected void service(HttpServletRequest request, HttpServletResponse httpServletResponse)
           throws ServletException, IOException {
-        run[0] = true;
+        paths.add(request.getRequestURI());
 
         final Object o = request.getAttribute(A_KEY);
         assertEquals("Wrong attrib returned - " + o, A_VALUE, o);
@@ -170,7 +171,7 @@
     assertNotNull(dispatcher);
     dispatcher.forward(mockRequest, mockResponse);
 
-    assertTrue("Include did not dispatch to our servlet!", run[0]);
+    assertTrue("Include did not dispatch to our servlet!", paths.contains(pattern));
 
     verify(injector, mockRequest, mockResponse, mockBinding);
   }