blob: 1d55ec1bb82ecfc4f84b1246619605af5313d6ee [file] [log] [blame]
/**
* Copyright (C) 2011 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.inject.servlet;
import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest;
import static com.google.inject.servlet.ServletTestUtils.newNoOpFilterChain;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.expectLastCall;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Singleton;
import junit.framework.TestCase;
import org.easymock.EasyMock;
import org.easymock.IMocksControl;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
 * Tests that the filter stage of the pipeline dispatches correctly to
 * guice-managed filters.
 *
 * <p>WARNING(dhanji): Non-parallelizable test =( The lifecycle counters bumped by
 * {@link TestFilter} are static, and {@link #setUp} calls {@code GuiceFilter.reset()},
 * so tests in this class must not run concurrently with each other.
 *
 * @author dhanji@gmail.com (Dhanji R. Prasanna)
 */
public class FilterDispatchIntegrationTest extends TestCase {
  // Lifecycle counters incremented by TestFilter; reset before every test in setUp().
  private static int inits, doFilters, destroys;

  private IMocksControl control;

  @Override
  public final void setUp() {
    inits = 0;
    doFilters = 0;
    destroys = 0;
    control = EasyMock.createControl();
    GuiceFilter.reset();
  }

  /**
   * Dispatches {@code /index.html} through a pipeline whose servlet forwards to
   * {@link TestServlet#FORWARD_TO}; checks that both URIs reached the servlet and
   * that the filter lifecycle methods fired the expected number of times.
   */
  public final void testDispatchRequestToManagedPipeline() throws ServletException, IOException {
    final Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filter("/*").through(TestFilter.class);
        filter("*.html").through(TestFilter.class);
        filter("/*").through(Key.get(TestFilter.class));

        // These filters should never fire
        filter("/index/*").through(Key.get(TestFilter.class));
        filter("*.jsp").through(Key.get(TestFilter.class));

        // Bind a servlet
        serve("*.html").with(TestServlet.class);
      }
    });

    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);

    // create ourselves a mock request with test URI
    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);

    expect(requestMock.getRequestURI())
        .andReturn("/index.html")
        .anyTimes();
    expect(requestMock.getContextPath())
        .andReturn("")
        .anyTimes();

    // The forward performed by TestServlet is expected to set and remove this marker
    // attribute exactly once each.
    requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
    requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);

    // Allow the forward to query and reset the response buffer any number of times.
    HttpServletResponse responseMock = control.createMock(HttpServletResponse.class);
    expect(responseMock.isCommitted())
        .andReturn(false)
        .anyTimes();
    responseMock.resetBuffer();
    expectLastCall().anyTimes();

    // No expectations recorded: the servlet handles the request, so the container's
    // chain must never be invoked.
    FilterChain filterChain = control.createMock(FilterChain.class);

    //dispatch request
    control.replay();
    pipeline.dispatch(requestMock, responseMock, filterChain);
    pipeline.destroyPipeline();
    control.verify();

    TestServlet servlet = injector.getInstance(TestServlet.class);
    assertEquals(2, servlet.processedUris.size());
    assertTrue(servlet.processedUris.contains("/index.html"));
    assertTrue(servlet.processedUris.contains(TestServlet.FORWARD_TO));

    // The two "/*" mappings and the "*.html" mapping match; TestFilter is a
    // @Singleton and the test expects exactly one init and one destroy.
    assertLifecycleCounts(1, 3, 1);
  }

  /**
   * Dispatches a URI that matches none of the configured filter patterns and checks
   * that the request falls straight through to the container's filter chain.
   */
  public final void testDispatchThatNoFiltersFire() throws ServletException, IOException {
    final Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filter("/public/*").through(TestFilter.class);
        filter("*.html").through(TestFilter.class);
        filter("*.xml").through(Key.get(TestFilter.class));

        // These filters should never fire
        filter("/index/*").through(Key.get(TestFilter.class));
        filter("*.jsp").through(Key.get(TestFilter.class));
      }
    });

    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);

    //create ourselves a mock request with test URI
    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);

    expect(requestMock.getRequestURI())
        .andReturn("/index.xhtml")
        .anyTimes();
    expect(requestMock.getContextPath())
        .andReturn("")
        .anyTimes();

    //dispatch request
    FilterChain filterChain = control.createMock(FilterChain.class);
    // The unfiltered request must be handed to the container's chain exactly once.
    filterChain.doFilter(requestMock, null);
    control.replay();
    pipeline.dispatch(requestMock, null, filterChain);
    pipeline.destroyPipeline();
    control.verify();

    assertLifecycleCounts(1, 0, 1);
  }

  /**
   * Checks that regex-mapped filters fire only for patterns that match the whole
   * request URI ({@code "\\w"} matches a single word character, not "/index").
   */
  public final void testDispatchFilterPipelineWithRegexMatching() throws ServletException,
      IOException {

    final Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filterRegex("/[A-Za-z]*").through(TestFilter.class);
        filterRegex("/index").through(TestFilter.class);

        //these filters should never fire
        filterRegex("\\w").through(Key.get(TestFilter.class));
      }
    });

    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);

    //create ourselves a mock request with test URI
    HttpServletRequest requestMock = control.createMock(HttpServletRequest.class);

    expect(requestMock.getRequestURI())
        .andReturn("/index")
        .anyTimes();
    expect(requestMock.getContextPath())
        .andReturn("")
        .anyTimes();

    // dispatch request
    FilterChain filterChain = control.createMock(FilterChain.class);
    filterChain.doFilter(requestMock, null);
    control.replay();
    pipeline.dispatch(requestMock, null, filterChain);
    pipeline.destroyPipeline();
    control.verify();

    assertLifecycleCounts(1, 2, 1);
  }

  /**
   * Asserts how many times the {@link TestFilter} lifecycle methods ran since
   * {@link #setUp}. Every failure message reports all three counters.
   *
   * @param expectedInits expected number of {@link TestFilter#init} calls
   * @param expectedDoFilters expected number of {@link TestFilter#doFilter} calls
   * @param expectedDestroys expected number of {@link TestFilter#destroy} calls
   */
  private static void assertLifecycleCounts(int expectedInits, int expectedDoFilters,
      int expectedDestroys) {
    String message = "lifecycle states did not fire correct number of times-- inits: " + inits
        + "; dos: " + doFilters + "; destroys: " + destroys;
    assertEquals(message, expectedInits, inits);
    assertEquals(message, expectedDoFilters, doFilters);
    assertEquals(message, expectedDestroys, destroys);
  }

  /**
   * Singleton filter that records each lifecycle call in the enclosing test's static
   * counters and always continues the chain.
   */
  @Singleton
  public static class TestFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) {
      inits++;
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
        FilterChain filterChain) throws IOException, ServletException {
      doFilters++;
      filterChain.doFilter(servletRequest, servletResponse);
    }

    @Override
    public void destroy() {
      destroys++;
    }
  }

  /**
   * Singleton servlet that records every request URI it sees. It forwards requests for
   * {@link #FORWARD_FROM} to {@link #FORWARD_TO} via the request's dispatcher.
   */
  @Singleton
  public static class TestServlet extends HttpServlet {
    public static final String FORWARD_FROM = "/index.html";
    public static final String FORWARD_TO = "/forwarded.html";
    // URIs seen by service(), in arrival order.
    public List<String> processedUris = new ArrayList<String>();

    @Override
    protected void service(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse)
        throws ServletException, IOException {
      String requestUri = httpServletRequest.getRequestURI();
      processedUris.add(requestUri);

      // If the client is requesting /index.html then we forward to /forwarded.html
      if (FORWARD_FROM.equals(requestUri)) {
        httpServletRequest.getRequestDispatcher(FORWARD_TO)
            .forward(httpServletRequest, httpServletResponse);
      }
    }

    /** Casts the generic request/response to their HTTP types and delegates. */
    @Override
    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
        throws ServletException, IOException {
      service((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse);
    }
  }

  /**
   * Checks that a filter registered in the outer module runs before one registered by
   * a module installed inside it.
   */
  public void testFilterOrder() throws Exception {
    AtomicInteger counter = new AtomicInteger();
    final CountFilter f1 = new CountFilter(counter);
    final CountFilter f2 = new CountFilter(counter);

    Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filter("/").through(f1);
        install(new ServletModule() {
          @Override
          protected void configureServlets() {
            filter("/").through(f2);
          }
        });
      }
    });

    HttpServletRequest request = newFakeHttpServletRequest();
    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);
    pipeline.dispatch(request, null, newNoOpFilterChain());
    assertEquals(0, f1.calledAt);
    assertEquals(1, f2.calledAt);
  }

  /**
   * A filter that records when it was called by incrementing a shared counter. It
   * fails the test if it is invoked more than once.
   */
  private static class CountFilter implements Filter {
    private final AtomicInteger counter;
    // Counter value observed on the (single) call to doFilter; -1 until then.
    private int calledAt = -1;

    public CountFilter(AtomicInteger counter) {
      this.counter = counter;
    }

    @Override
    public void destroy() {
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws ServletException, IOException {
      if (calledAt != -1) {
        fail("not expecting to be called twice");
      }
      calledAt = counter.getAndIncrement();
      chain.doFilter(request, response);
    }

    @Override
    public void init(FilterConfig filterConfig) {}
  }

  /**
   * Checks that an exception thrown by a filter propagates without stack frames from
   * the pipeline's internal dispatch classes.
   */
  public final void testFilterExceptionPrunesStack() throws Exception {
    Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filter("/").through(TestFilter.class);
        filter("/nothing").through(TestFilter.class);
        filter("/").through(ThrowingFilter.class);
      }
    });

    HttpServletRequest request = newFakeHttpServletRequest();
    FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);
    try {
      pipeline.dispatch(request, null, null);
      fail("expected exception");
    } catch (ServletException ex) {
      assertInternalFramesPruned(ex);
    }
  }

  /**
   * Checks that an exception thrown by a servlet propagates without stack frames from
   * the pipeline's internal dispatch classes.
   */
  public final void testServletExceptionPrunesStack() throws Exception {
    Injector injector = Guice.createInjector(new ServletModule() {
      @Override
      protected void configureServlets() {
        filter("/").through(TestFilter.class);
        filter("/nothing").through(TestFilter.class);
        serve("/").with(ThrowingServlet.class);
      }
    });

    HttpServletRequest request = newFakeHttpServletRequest();
    FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
    pipeline.initPipeline(null);
    try {
      pipeline.dispatch(request, null, null);
      fail("expected exception");
    } catch (ServletException ex) {
      assertInternalFramesPruned(ex);
    }
  }

  /**
   * Asserts that no frame in {@code ex}'s stack trace belongs to
   * {@link FilterChainInvocation} or {@link FilterDefinition}.
   *
   * @param ex the exception that escaped {@code FilterPipeline.dispatch}
   */
  private static void assertInternalFramesPruned(ServletException ex) {
    String invocationClass = FilterChainInvocation.class.getName();
    String definitionClass = FilterDefinition.class.getName();
    for (StackTraceElement element : ex.getStackTrace()) {
      String className = element.getClassName();
      assertTrue("was: " + element,
          !className.equals(invocationClass) && !className.equals(definitionClass));
    }
  }

  /** Singleton servlet whose {@code service} always throws a {@link ServletException}. */
  @Singleton
  private static class ThrowingServlet extends HttpServlet {

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp)
        throws ServletException {
      throw new ServletException("failure!");
    }
  }

  /** Singleton filter whose {@code doFilter} always throws a {@link ServletException}. */
  @Singleton
  private static class ThrowingFilter implements Filter {
    @Override
    public void destroy() {
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws ServletException {
      throw new ServletException("we failed!");
    }

    @Override
    public void init(FilterConfig filterConfig) {
    }
  }
}