Isaac's patch to fix scoping for null values.

git-svn-id: https://google-guice.googlecode.com/svn/trunk@1061 d779f126-a31b-0410-b53b-1d3aecad763e
diff --git a/servlet/src/com/google/inject/servlet/ServletScopes.java b/servlet/src/com/google/inject/servlet/ServletScopes.java
index c907946..b3da2e4 100644
--- a/servlet/src/com/google/inject/servlet/ServletScopes.java
+++ b/servlet/src/com/google/inject/servlet/ServletScopes.java
@@ -19,6 +19,7 @@
 import com.google.inject.Key;
 import com.google.inject.Provider;
 import com.google.inject.Scope;
+
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpSession;
 
@@ -31,6 +32,9 @@
 
   private ServletScopes() {}
 
+  /** A sentinel attribute value representing null. */
+  enum NullObject { INSTANCE }
+
   /**
    * HTTP servlet request scope.
    */
@@ -41,11 +45,15 @@
         public T get() {
           HttpServletRequest request = GuiceFilter.getRequest();
           synchronized (request) {
+            Object obj = request.getAttribute(name);
+            if (NullObject.INSTANCE == obj) {
+              return null;
+            }
             @SuppressWarnings("unchecked")
-            T t = (T) request.getAttribute(name);
+            T t = (T) obj;
             if (t == null) {
               t = creator.get();
-              request.setAttribute(name, t);
+              request.setAttribute(name, (t != null) ? t : NullObject.INSTANCE);
             }
             return t;
           }
@@ -72,11 +80,15 @@
         public T get() {
           HttpSession session = GuiceFilter.getRequest().getSession();
           synchronized (session) {
+            Object obj = session.getAttribute(name);
+            if (NullObject.INSTANCE == obj) {
+              return null;
+            }
             @SuppressWarnings("unchecked")
-            T t = (T) session.getAttribute(name);
+            T t = (T) obj;
             if (t == null) {
               t = creator.get();
-              session.setAttribute(name, t);
+              session.setAttribute(name, (t != null) ? t : NullObject.INSTANCE);
             }
             return t;
           }
diff --git a/servlet/test/com/google/inject/servlet/ServletTest.java b/servlet/test/com/google/inject/servlet/ServletTest.java
index a3d863f..cb6ebfa 100644
--- a/servlet/test/com/google/inject/servlet/ServletTest.java
+++ b/servlet/test/com/google/inject/servlet/ServletTest.java
@@ -16,23 +16,35 @@
 
 package com.google.inject.servlet;
 
-
 import com.google.inject.AbstractModule;
+import static com.google.inject.Asserts.reserialize;
+import com.google.inject.BindingAnnotation;
 import com.google.inject.CreationException;
 import com.google.inject.Guice;
 import com.google.inject.Injector;
 import com.google.inject.Key;
-
-import junit.framework.TestCase;
-
+import com.google.inject.internal.Maps;
+import static com.google.inject.servlet.ServletScopes.NullObject;
+import com.google.inject.util.Providers;
 import java.io.IOException;
+import java.io.Serializable;
+import static java.lang.annotation.ElementType.FIELD;
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.PARAMETER;
+import java.lang.annotation.Retention;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+import java.lang.annotation.Target;
+import java.lang.reflect.InvocationHandler;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+import java.util.Map;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpSession;
-
+import junit.framework.TestCase;
 import static org.easymock.EasyMock.createMock;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
@@ -44,6 +56,10 @@
  * @author crazybob@google.com (Bob Lee)
  */
 public class ServletTest extends TestCase {
+  private static final Key<InRequest> IN_REQUEST_KEY = Key.get(InRequest.class);
+  private static final Key<InRequest> IN_REQUEST_NULL_KEY = Key.get(InRequest.class, Null.class);
+  private static final Key<InSession> IN_SESSION_KEY = Key.get(InSession.class);
+  private static final Key<InSession> IN_SESSION_NULL_KEY = Key.get(InSession.class, Null.class);
 
   @Override
   public void setUp() {
@@ -59,9 +75,13 @@
 
     final HttpServletRequest request = createMock(HttpServletRequest.class);
 
-    String name = Key.get(InRequest.class).toString();
-    expect(request.getAttribute(name)).andReturn(null);
-    request.setAttribute(eq(name), isA(InRequest.class));
+    String inRequestKey = IN_REQUEST_KEY.toString();
+    expect(request.getAttribute(inRequestKey)).andReturn(null);
+    request.setAttribute(eq(inRequestKey), isA(InRequest.class));
+    
+    String inRequestNullKey = IN_REQUEST_NULL_KEY.toString();
+    expect(request.getAttribute(inRequestNullKey)).andReturn(null);
+    request.setAttribute(eq(inRequestNullKey), eq(NullObject.INSTANCE));
 
     final boolean[] invoked = new boolean[1];
     FilterChain filterChain = new FilterChain() {
@@ -70,6 +90,7 @@
         invoked[0] = true;
 //        assertSame(request, servletRequest);
         assertNotNull(injector.getInstance(InRequest.class));
+        assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
       }
     };
 
@@ -90,8 +111,11 @@
     final HttpServletRequest request = createMock(HttpServletRequest.class);
 
     final InRequest inRequest = new InRequest();
-    String name = Key.get(InRequest.class).toString();
-    expect(request.getAttribute(name)).andReturn(inRequest).times(2);
+    String inRequestKey = IN_REQUEST_KEY.toString();
+    expect(request.getAttribute(inRequestKey)).andReturn(inRequest).times(2);
+    
+    String inRequestNullKey = IN_REQUEST_NULL_KEY.toString();
+    expect(request.getAttribute(inRequestNullKey)).andReturn(NullObject.INSTANCE).times(2);
 
     final boolean[] invoked = new boolean[1];
     FilterChain filterChain = new FilterChain() {
@@ -101,6 +125,9 @@
         
         assertSame(inRequest, injector.getInstance(InRequest.class));
         assertSame(inRequest, injector.getInstance(InRequest.class));
+
+        assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
+        assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
       }
     };
 
@@ -121,11 +148,15 @@
     final HttpServletRequest request = createMock(HttpServletRequest.class);
     final HttpSession session = createMock(HttpSession.class);
 
-    String name = Key.get(InSession.class).toString();
+    String inSessionKey = IN_SESSION_KEY.toString();
+    String inSessionNullKey = IN_SESSION_NULL_KEY.toString();
 
-    expect(request.getSession()).andReturn(session);
-    expect(session.getAttribute(name)).andReturn(null);
-    session.setAttribute(eq(name), isA(InSession.class));
+    expect(request.getSession()).andReturn(session).times(2);
+    expect(session.getAttribute(inSessionKey)).andReturn(null);
+    session.setAttribute(eq(inSessionKey), isA(InSession.class));
+
+    expect(session.getAttribute(inSessionNullKey)).andReturn(null);
+    session.setAttribute(eq(inSessionNullKey), eq(NullObject.INSTANCE));
 
     final boolean[] invoked = new boolean[1];
     FilterChain filterChain = new FilterChain() {
@@ -134,6 +165,7 @@
         invoked[0] = true;
 //        assertSame(request, servletRequest);
         assertNotNull(injector.getInstance(InSession.class));
+        assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
       }
     };
 
@@ -154,11 +186,14 @@
     final HttpServletRequest request = createMock(HttpServletRequest.class);
     final HttpSession session = createMock(HttpSession.class);
 
-    String name = Key.get(InSession.class).toString();
+    String inSessionKey = IN_SESSION_KEY.toString();
+    String inSessionNullKey = IN_SESSION_NULL_KEY.toString();
 
     final InSession inSession = new InSession();
-    expect(request.getSession()).andReturn(session).times(2);
-    expect(session.getAttribute(name)).andReturn(inSession).times(2);
+    expect(request.getSession()).andReturn(session).times(4);
+    expect(session.getAttribute(inSessionKey)).andReturn(inSession).times(2);
+    
+    expect(session.getAttribute(inSessionNullKey)).andReturn(NullObject.INSTANCE).times(2);
 
     final boolean[] invoked = new boolean[1];
     FilterChain filterChain = new FilterChain() {
@@ -169,6 +204,9 @@
 
         assertSame(inSession, injector.getInstance(InSession.class));
         assertSame(inSession, injector.getInstance(InSession.class));
+
+        assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
+        assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
       }
     };
 
@@ -180,6 +218,67 @@
     assertTrue(invoked[0]);
   }
 
+  public void testHttpSessionIsSerializable()
+      throws IOException, ClassNotFoundException, ServletException {
+    final Injector injector = createInjector();
+
+    GuiceFilter filter = new GuiceFilter();
+
+    final HttpServletRequest request = createMock(HttpServletRequest.class);
+    final HttpSession session = newFakeHttpSession();
+
+    String inSessionKey = IN_SESSION_KEY.toString();
+    String inSessionNullKey = IN_SESSION_NULL_KEY.toString();
+
+    expect(request.getSession()).andReturn(session).times(2);
+
+    final boolean[] invoked = new boolean[1];
+    FilterChain filterChain = new FilterChain() {
+      public void doFilter(ServletRequest servletRequest,
+          ServletResponse servletResponse) {
+        invoked[0] = true;
+        assertNotNull(injector.getInstance(InSession.class));
+        assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
+      }
+    };
+
+    replay(request);
+
+    filter.doFilter(request, null, filterChain);
+
+    verify(request);
+    assertTrue(invoked[0]);
+
+    HttpSession deserializedSession = reserialize(session);
+
+    assertTrue(deserializedSession.getAttribute(inSessionKey) instanceof InSession);
+    assertEquals(NullObject.INSTANCE, deserializedSession.getAttribute(inSessionNullKey));
+  }
+
+  private static class FakeHttpSessionHandler implements InvocationHandler, Serializable {
+    final Map<String, Object> attributes = Maps.newHashMap();
+
+    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
+      String name = method.getName();
+      if ("setAttribute".equals(name)) {
+        attributes.put((String) args[0], args[1]);
+        return null;
+      } else if ("getAttribute".equals(name)) {
+        return attributes.get(args[0]);
+      } else {
+        throw new UnsupportedOperationException();
+      }
+    }
+  }
+
+  /**
+   * Returns a fake, serializable HttpSession which stores attributes in a HashMap.
+   */
+  private HttpSession newFakeHttpSession() {
+    return (HttpSession) Proxy.newProxyInstance(HttpSession.class.getClassLoader(),
+        new Class[] { HttpSession.class }, new FakeHttpSessionHandler());
+  }
+
   private Injector createInjector() throws CreationException {
 
     return Guice.createInjector(new AbstractModule() {
@@ -188,14 +287,19 @@
       protected void configure() {
         install(new ServletModule());
         bind(InSession.class);
+        bind(IN_SESSION_NULL_KEY).toProvider(Providers.<InSession>of(null)).in(SessionScoped.class);
         bind(InRequest.class);
+        bind(IN_REQUEST_NULL_KEY).toProvider(Providers.<InRequest>of(null)).in(RequestScoped.class);
       }
     });
   }
 
   @SessionScoped
-  static class InSession {}
+  static class InSession implements Serializable {}
 
   @RequestScoped
   static class InRequest {}
+
+  @BindingAnnotation @Retention(RUNTIME) @Target({PARAMETER, METHOD, FIELD})
+  @interface Null {}
 }
diff --git a/src/com/google/inject/Scopes.java b/src/com/google/inject/Scopes.java
index bb4316f..9e03334 100644
--- a/src/com/google/inject/Scopes.java
+++ b/src/com/google/inject/Scopes.java
@@ -19,6 +19,7 @@
 import com.google.inject.internal.InjectorBuilder;
 import com.google.inject.internal.LinkedBindingImpl;
 import com.google.inject.spi.BindingScopingVisitor;
+
 import java.lang.annotation.Annotation;
 
 /**
@@ -30,14 +31,20 @@
 
   private Scopes() {}
 
+  /** A sentinel value representing null. */
+  private static final Object NULL = new Object();
+
   /**
    * One instance per {@link Injector}. Also see {@code @}{@link Singleton}.
    */
   public static final Scope SINGLETON = new Scope() {
     public <T> Provider<T> scope(Key<T> key, final Provider<T> creator) {
       return new Provider<T>() {
-
-        private volatile T instance;
+        /*
+         * The lazily initialized singleton instance. Once set, this will either have type T or will
+         * be equal to NULL.
+         */
+        private volatile Object instance;
 
         // DCL on a volatile is safe as of Java 5, which we obviously require.
         @SuppressWarnings("DoubleCheckedLocking")
@@ -51,11 +58,16 @@
              */
             synchronized (InjectorBuilder.class) {
               if (instance == null) {
-                instance = creator.get();
+                T nullableInstance = creator.get();
+                instance = (nullableInstance != null) ? nullableInstance : NULL;
               }
             }
           }
-          return instance;
+          Object localInstance = instance;
+          // This is safe because instance has type T or is equal to NULL
+          @SuppressWarnings("unchecked")
+          T returnedInstance = (localInstance != NULL) ? (T) localInstance : null;
+          return returnedInstance;
         }
 
         public String toString() {