Support for custom request scopes overloaded onto the @RequestScoped annotation. 

Also added ability to seed a scope map for both continuing HTTP request scopes as well as custom request scopes. 

Also changed continuing HTTP requests in other request threads to fail if they happen to run in an HTTP request thread as per sberlin's recommendation. See tests for details

git-svn-id: https://google-guice.googlecode.com/svn/trunk@1254 d779f126-a31b-0410-b53b-1d3aecad763e
diff --git a/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java b/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
index 8b32c34..54eaa37 100644
--- a/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
+++ b/extensions/servlet/src/com/google/inject/servlet/ServletScopes.java
@@ -20,6 +20,9 @@
 import com.google.inject.OutOfScopeException;
 import com.google.inject.Provider;
 import com.google.inject.Scope;
+import com.google.inject.internal.util.Maps;
+import com.google.inject.internal.util.Preconditions;
+import java.util.Map;
 import java.util.concurrent.Callable;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpSession;
@@ -44,7 +47,36 @@
       final String name = key.toString();
       return new Provider<T>() {
         public T get() {
+          // Should we use the alternate request scope, if no http request
+          // is in progress?
+          if (null == GuiceFilter.localContext.get()) {
+
+            // NOTE(dhanji): We don't need to synchronize on the scope map
+            // unlike the HTTP request because we're the only ones who have
+            // a reference to it, and it is only available via a threadlocal.
+            Map<String, Object> scopeMap = requestScopeContext.get();
+            if (null != scopeMap) {
+              @SuppressWarnings("unchecked")
+              T t = (T) scopeMap.get(name);
+
+              // Accounts for @Nullable providers.
+              if (NullObject.INSTANCE == t) {
+                return null;
+              }
+
+              if (t == null) {
+                t = creator.get();
+                // Store a sentinel for provider-given null values.
+                scopeMap.put(name, t != null ? t : NullObject.INSTANCE);
+              }
+
+              return t;
+            } // else: fall into normal HTTP request scope and out of scope
+              // exception is thrown.
+          }
+
           HttpServletRequest request = GuiceFilter.getRequest();
+
           synchronized (request) {
             Object obj = request.getAttribute(name);
             if (NullObject.INSTANCE == obj) {
@@ -129,22 +161,32 @@
    * @throws OutOfScopeException if this method is called from a non-request
    *     thread, or if the request has completed.
    */
-  public static <T> Callable<T> continueRequest(final Callable<T> callable) {
+  public static <T> Callable<T> continueRequest(final Callable<T> callable,
+      final Map<Key<?>, Object> seedMap) {
+    Preconditions.checkArgument(null != seedMap,
+        "Seed map cannot be null, try passing in Collections.emptyMap() instead.");
+    
+    // Snapshot the seed map and add all the instances to our continuing HTTP request.
+    final ContinuingHttpServletRequest continuingRequest =
+        new ContinuingHttpServletRequest(GuiceFilter.getRequest());
+    for (Map.Entry<Key<?>, Object> entry : seedMap.entrySet()) {
+      continuingRequest.setAttribute(entry.getKey().toString(), entry.getValue());
+    }
+
     return new Callable<T>() {
-      private HttpServletRequest request =
-          new ContinuingHttpServletRequest(GuiceFilter.getRequest());
+      private HttpServletRequest request = continuingRequest;
 
       public T call() throws Exception {
         GuiceFilter.Context context = GuiceFilter.localContext.get();
-        if (null == context) {
-          // Only set up the request continuation if we're running in a
-          // new vanilla thread.
-          GuiceFilter.localContext.set(new GuiceFilter.Context(request, null));
-        }
+        Preconditions.checkState(null == context,
+            "Cannot continue request in the same thread as a HTTP request!");
+
+        // Only set up the request continuation if we're running in a
+        // new vanilla thread.
+        GuiceFilter.localContext.set(new GuiceFilter.Context(request, null));
         try {
           return callable.call();
         } finally {
-
           // Clear the copied context if we set one up.
           if (null == context) {
             GuiceFilter.localContext.remove();
@@ -153,4 +195,55 @@
       }
     };
   }
+
+  /**
+   * A threadlocal scope map for non-http request scopes. The {@link #REQUEST}
+   * scope falls back to this scope map if no http request is available, and
+   * requires {@link #scopeRequest} to be called as an alertnative.
+   */
+  private static final ThreadLocal<Map<String, Object>> requestScopeContext
+      = new ThreadLocal<Map<String, Object>>();
+
+  /**
+   * Scopes the given callable inside a request scope. This is not the same
+   * as the HTTP request scope, but is used if no HTTP request scope is in
+   * progress. In this way, keys can be scoped as @RequestScoped and exist
+   * in non-HTTP requests (for example: RPC requests) as well as in HTTP
+   * request threads.
+   *
+   * @param callable code to be executed which depends on the request scope.
+   *     Typically in another thread, but not necessarily so.
+   * @param seedMap the initial set of scoped instances for Guice to seed the
+   *     request scope with.
+   * @return a callable that when called will run inside the a request scope
+   *     that exposes the instances in the {@code seedMap} as scoped keys.
+   */
+  public static <T> Callable<T> scopeRequest(final Callable<T> callable,
+      Map<Key<?>, Object> seedMap) {
+    Preconditions.checkArgument(null != seedMap,
+        "Seed map cannot be null, try passing in Collections.emptyMap() instead.");
+
+    // Copy the seed values into our local scope map.
+    final Map<String, Object> scopeMap = Maps.newHashMap();
+    for (Map.Entry<Key<?>, Object> entry : seedMap.entrySet()) {
+      scopeMap.put(entry.getKey().toString(), entry.getValue());
+    }
+
+    return new Callable<T>() {
+      public T call() throws Exception {
+        Preconditions.checkState(null == GuiceFilter.localContext.get(),
+            "An HTTP request is already in progress, cannot scope a new request in this thread.");
+        Preconditions.checkState(null == requestScopeContext.get(),
+            "A request scope is already in progress, cannot scope a new request in this thread.");
+
+        requestScopeContext.set(scopeMap);
+
+        try {
+          return callable.call();
+        } finally {
+          requestScopeContext.remove();
+        }
+      }
+    };
+  }
 }
diff --git a/extensions/servlet/test/com/google/inject/servlet/ContinuingRequestIntegrationTest.java b/extensions/servlet/test/com/google/inject/servlet/ContinuingRequestIntegrationTest.java
index 4b2de40..1e1811c 100644
--- a/extensions/servlet/test/com/google/inject/servlet/ContinuingRequestIntegrationTest.java
+++ b/extensions/servlet/test/com/google/inject/servlet/ContinuingRequestIntegrationTest.java
@@ -19,9 +19,11 @@
 import com.google.inject.Guice;
 import com.google.inject.Inject;
 import com.google.inject.Injector;
+import com.google.inject.Key;
 import com.google.inject.Provider;
 import com.google.inject.Singleton;
 import com.google.inject.internal.util.ImmutableList;
+import com.google.inject.internal.util.ImmutableMap;
 import java.io.IOException;
 import java.util.List;
 import java.util.concurrent.AbstractExecutorService;
@@ -30,6 +32,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import javax.servlet.FilterChain;
 import javax.servlet.FilterConfig;
 import javax.servlet.ServletContext;
@@ -38,9 +41,7 @@
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import junit.framework.TestCase;
-import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.createMock;
-import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
@@ -52,7 +53,8 @@
   private static final String PARAM_VALUE = "there";
   private static final String PARAM_NAME = "hi";
 
-  private static final AbstractExecutorService SAME_THREAD_EXECUTOR = new AbstractExecutorService() {
+  private final AtomicBoolean failed = new AtomicBoolean(false);
+  private final AbstractExecutorService sameThreadExecutor = new AbstractExecutorService() {
     public void shutdown() {
     }
 
@@ -79,8 +81,11 @@
     @Override public <T> Future<T> submit(Callable<T> task) {
       try {
         task.call();
+        fail();
       } catch (Exception e) {
-        throw new RuntimeException(e);
+        // Expected.
+        assertTrue(e instanceof IllegalStateException);
+        failed.set(true);
       }
 
       return null;
@@ -88,12 +93,17 @@
   };
 
   private ExecutorService executor;
+  private Injector injector;
+
+  @Override protected void tearDown() throws Exception {
+    injector.getInstance(GuiceFilter.class).destroy();
+  }
 
   public final void testRequestContinuesInOtherThread()
       throws ServletException, IOException, InterruptedException {
     executor = Executors.newSingleThreadExecutor();
 
-    Injector injector = Guice.createInjector(new ServletModule() {
+    injector = Guice.createInjector(new ServletModule() {
       @Override protected void configureServlets() {
         serve("/*").with(ContinuingServlet.class);
 
@@ -127,14 +137,16 @@
     verify(request, filterConfig, filterChain);
   }
 
-  public final void testRequestContinuesInSameThread()
+  public final void testRequestContinuationDiesInHttpRequestThread()
       throws ServletException, IOException, InterruptedException {
-    executor = SAME_THREAD_EXECUTOR;
-    Injector injector = Guice.createInjector(new ServletModule() {
+    executor = sameThreadExecutor;
+    injector = Guice.createInjector(new ServletModule() {
       @Override protected void configureServlets() {
         serve("/*").with(ContinuingServlet.class);
 
         bind(ExecutorService.class).toInstance(executor);
+
+        bind(SomeObject.class);
       }
     });
 
@@ -145,23 +157,10 @@
 
     HttpServletRequest request = createMock(HttpServletRequest.class);
 
-    // this time it will try to get it from the scope, because its same-thread.
-    // This is part of Isaac's patch that enabled request-scoping of the request.
-    expect(request.getAttribute("Key[type=javax.servlet.http.HttpServletResponse, annotation=[none]]"))
-        .andReturn(null);
-    request.setAttribute(eq("Key[type=javax.servlet.http.HttpServletResponse, annotation=[none]]"),
-        anyObject());
-    expect(request.getAttribute("Key[type=javax.servlet.http.HttpServletRequest, annotation=[none]]"))
-        .andReturn(null);
-    request.setAttribute(eq("Key[type=javax.servlet.http.HttpServletRequest, annotation=[none]]"),
-        anyObject());
-
     expect(request.getServletPath()).andReturn("/");
     expect(request.getMethod()).andReturn("GET");
-
     FilterChain filterChain = createMock(FilterChain.class);
-    expect(request.getParameter(PARAM_NAME)).andReturn(PARAM_VALUE);
-
+    
     replay(request, filterConfig, filterChain);
 
     guiceFilter.init(filterConfig);
@@ -171,19 +170,31 @@
     executor.shutdown();
     executor.awaitTermination(10, TimeUnit.SECONDS);
 
-    assertEquals(PARAM_VALUE, injector.getInstance(OffRequestCallable.class).value);
+    assertTrue(failed.get());
+    assertFalse(PARAM_VALUE.equals(injector.getInstance(OffRequestCallable.class).value));
 
     verify(request, filterConfig, filterChain);
   }
 
+  @RequestScoped
+  public static class SomeObject {
+  }
+
   @Singleton
   public static class ContinuingServlet extends HttpServlet {
     @Inject OffRequestCallable callable;
     @Inject ExecutorService executorService;
 
+    private SomeObject someObject;
+
     @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp)
         throws ServletException, IOException {
-      Callable<String> task = ServletScopes.continueRequest(callable);
+      assertNull(someObject);
+
+      // Seed with someobject.
+      someObject = new SomeObject();
+      Callable<String> task = ServletScopes.continueRequest(callable,
+          ImmutableMap.<Key<?>, Object>of(Key.get(SomeObject.class), someObject));
 
       executorService.submit(task);
     }
@@ -193,12 +204,16 @@
   public static class OffRequestCallable implements Callable<String> {
     @Inject Provider<HttpServletRequest> request;
     @Inject Provider<HttpServletResponse> response;
+    @Inject Provider<SomeObject> someObject;
 
     public String value;
 
     public String call() throws Exception {
       assertNull(response.get());
 
+      // Inside this request, we should always get the same instance.
+      assertSame(someObject.get(), someObject.get());
+
       return value = request.get().getParameter(PARAM_NAME);
     }
   }
diff --git a/extensions/servlet/test/com/google/inject/servlet/ScopeRequestIntegrationTest.java b/extensions/servlet/test/com/google/inject/servlet/ScopeRequestIntegrationTest.java
new file mode 100644
index 0000000..3f67ba1
--- /dev/null
+++ b/extensions/servlet/test/com/google/inject/servlet/ScopeRequestIntegrationTest.java
@@ -0,0 +1,111 @@
+/**
+ * Copyright (C) 2010 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.inject.servlet;
+
+import com.google.inject.Guice;
+import com.google.inject.Inject;
+import com.google.inject.Injector;
+import com.google.inject.Key;
+import com.google.inject.Provider;
+import com.google.inject.Singleton;
+import com.google.inject.internal.util.ImmutableMap;
+import com.google.inject.name.Named;
+import com.google.inject.name.Names;
+import java.io.IOException;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import javax.servlet.ServletException;
+import junit.framework.TestCase;
+
+/**
+ * Tests continuation of requests
+ */
+public class ScopeRequestIntegrationTest extends TestCase {
+  private static final String A_VALUE = "thereaoskdao";
+  private static final String A_DIFFERENT_VALUE = "hiaoskd";
+
+  private static final String SHOULDNEVERBESEEN = "Shouldneverbeseen!";
+
+  public final void testNonHttpRequestScopedCallable()
+      throws ServletException, IOException, InterruptedException, ExecutionException {
+    ExecutorService executor = Executors.newSingleThreadExecutor();
+
+    // We use servlet module here because we want to test that @RequestScoped
+    // behaves properly with the non-HTTP request scope logic.
+    Injector injector = Guice.createInjector(new ServletModule() {
+      @Override protected void configureServlets() {
+        bindConstant().annotatedWith(Names.named(SomeObject.INVALID)).to(SHOULDNEVERBESEEN);
+        bind(SomeObject.class).in(RequestScoped.class);
+      }
+    });
+
+    SomeObject someObject = new SomeObject(A_VALUE);
+    OffRequestCallable offRequestCallable = injector.getInstance(OffRequestCallable.class);
+    executor.submit(ServletScopes.scopeRequest(offRequestCallable,
+        ImmutableMap.<Key<?>, Object>of(Key.get(SomeObject.class), someObject))).get();
+
+    assertSame(injector.getInstance(OffRequestCallable.class), offRequestCallable);
+
+    // Make sure the value was passed on.
+    assertEquals(someObject.value, offRequestCallable.value);
+    assertFalse(SHOULDNEVERBESEEN.equals(someObject.value));
+
+    // Now create a new request and assert that the scopes don't cross.
+    someObject = new SomeObject(A_DIFFERENT_VALUE);
+    executor.submit(ServletScopes.scopeRequest(offRequestCallable,
+        ImmutableMap.<Key<?>, Object>of(Key.get(SomeObject.class), someObject))).get();
+
+    assertSame(injector.getInstance(OffRequestCallable.class), offRequestCallable);
+
+    // Make sure the value was passed on.
+    assertEquals(someObject.value, offRequestCallable.value);
+    assertFalse(SHOULDNEVERBESEEN.equals(someObject.value));
+    executor.shutdown();
+    executor.awaitTermination(2, TimeUnit.SECONDS);
+  }
+
+  @RequestScoped
+  public static class SomeObject {
+    private static final String INVALID = "invalid";
+
+    @Inject
+    public SomeObject(@Named(INVALID) String value) {
+      this.value = value;
+    }
+    private final String value;
+  }
+
+  @Singleton
+  public static class OffRequestCallable implements Callable<String> {
+    @Inject Provider<SomeObject> someObject;
+
+    public String value;
+
+    public String call() throws Exception {
+      // Inside this request, we should always get the same instance.
+      assertSame(someObject.get(), someObject.get());
+
+      value = someObject.get().value;
+      assertFalse(SHOULDNEVERBESEEN.equals(value));
+
+      return value;
+    }
+  }
+}