Fixing a serialization problem as reported by Ben Yu: TypeLiteral is serializable, which causes warnings in anonymous inner type literals.

git-svn-id: https://google-guice.googlecode.com/svn/trunk@619 d779f126-a31b-0410-b53b-1d3aecad763e
diff --git a/src/com/google/inject/Key.java b/src/com/google/inject/Key.java
index ab2cdf8..d20d055 100644
--- a/src/com/google/inject/Key.java
+++ b/src/com/google/inject/Key.java
@@ -21,6 +21,8 @@
 import com.google.inject.internal.Annotations;
 import com.google.inject.internal.MoreTypes;
 import com.google.inject.internal.ToStringBuilder;
+import java.io.InvalidObjectException;
+import java.io.ObjectInputStream;
 import java.io.Serializable;
 import java.lang.annotation.Annotation;
 import java.lang.reflect.Type;
@@ -503,15 +505,34 @@
   }
 
   /**
-   * Returns the canonical form of this key for serialization. The returned
-   * instance is always a {@code Key}, never a subclass. This prevents problems
-   * caused by serializing anonymous types.
+   * Serializes the key without its type literal, annotation strategy, or
+   * hash code.
    */
-  protected final Object writeReplace() {
-    return getClass() == Key.class
-        ? this
-        : new Key<T>(typeLiteral, annotationStrategy);
+  private static class SerializedForm implements Serializable {
+    final Type type;
+    final Class<? extends Annotation> annotationType;
+    final Annotation annotationInstance;
+
+    SerializedForm(Key<?> key) {
+      this.type = key.getTypeLiteral().getType();
+      this.annotationType = key.getAnnotationType();
+      this.annotationInstance = key.getAnnotation();
+    }
+
+    final Object readResolve() {
+      return annotationInstance != null ? Key.get(type, annotationInstance)
+          : annotationType != null ? Key.get(type, annotationType)
+          : Key.get(type);
+    }
+
+    private static final long serialVersionUID = 0;
   }
 
-  private static final long serialVersionUID = 0;
+  protected void readObject(ObjectInputStream stream) throws InvalidObjectException {
+    throw new InvalidObjectException("Use SerializedForm");
+  }
+
+  protected final Object writeReplace() {
+    return new SerializedForm(this);
+  }
 }
diff --git a/src/com/google/inject/TypeLiteral.java b/src/com/google/inject/TypeLiteral.java
index 05bfe28..a954561 100644
--- a/src/com/google/inject/TypeLiteral.java
+++ b/src/com/google/inject/TypeLiteral.java
@@ -16,12 +16,10 @@
 
 package com.google.inject;
 
+import static com.google.common.base.Preconditions.checkNotNull;
 import com.google.inject.internal.MoreTypes;
 import static com.google.inject.internal.MoreTypes.canonicalize;
 import com.google.inject.util.Types;
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import java.io.Serializable;
 import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
 
@@ -42,7 +40,7 @@
  *
  * @author crazybob@google.com (Bob Lee)
  */
-public class TypeLiteral<T> implements Serializable {
+public class TypeLiteral<T> {
 
   final Class<? super T> rawType;
   final Type type;
@@ -143,17 +141,4 @@
   public static <T> TypeLiteral<T> get(Class<T> type) {
     return new TypeLiteral<T>(type);
   }
-
-  /**
-   * Returns the canonical form of this type literal for serialization. The
-   * returned instance is always a {@code TypeLiteral}, never a subclass. This
-   * prevents problems caused by serializing anonymous types.
-   */
-  protected final Object writeReplace() {
-    return getClass() == TypeLiteral.class
-        ? this
-        : get(type);
-  }
-
-  private static final long serialVersionUID = 0;
 }
diff --git a/src/com/google/inject/spi/ModuleWriter.java b/src/com/google/inject/spi/ModuleWriter.java
index 3d43d16..28341c5 100644
--- a/src/com/google/inject/spi/ModuleWriter.java
+++ b/src/com/google/inject/spi/ModuleWriter.java
@@ -105,39 +105,39 @@
     }
   }
 
-  public void writeMessage(final Binder binder, final Message element) {
+  protected void writeMessage(final Binder binder, final Message element) {
     binder.addError(element);
   }
 
-  public void writeBindInterceptor(final Binder binder, final InterceptorBinding element) {
+  protected void writeBindInterceptor(final Binder binder, final InterceptorBinding element) {
     List<MethodInterceptor> interceptors = element.getInterceptors();
     binder.withSource(element.getSource()).bindInterceptor(
         element.getClassMatcher(), element.getMethodMatcher(),
         interceptors.toArray(new MethodInterceptor[interceptors.size()]));
   }
 
-  public void writeBindScope(final Binder binder, final ScopeBinding element) {
+  protected void writeBindScope(final Binder binder, final ScopeBinding element) {
     binder.withSource(element.getSource()).bindScope(
         element.getAnnotationType(), element.getScope());
   }
 
-  public void writeRequestInjection(final Binder binder,
+  protected void writeRequestInjection(final Binder binder,
       final InjectionRequest command) {
     binder.withSource(command.getSource()).requestInjection(command.getInstance());
   }
 
-  public void writeRequestStaticInjection(final Binder binder,
+  protected void writeRequestStaticInjection(final Binder binder,
       final StaticInjectionRequest element) {
     Class<?> type = element.getType();
     binder.withSource(element.getSource()).requestStaticInjection(type);
   }
 
-  public void writeConvertToTypes(final Binder binder, final TypeConverterBinding element) {
+  protected void writeConvertToTypes(final Binder binder, final TypeConverterBinding element) {
     binder.withSource(element.getSource())
         .convertToTypes(element.getTypeMatcher(), element.getTypeConverter());
   }
 
-  public <T> void writeBind(final Binder binder, final Binding<T> element) {
+  protected <T> void writeBind(final Binder binder, final Binding<T> element) {
     LinkedBindingBuilder<T> lbb = binder.withSource(element.getSource()).bind(element.getKey());
 
     ScopedBindingBuilder sbb = applyTarget(element, lbb);
@@ -147,7 +147,7 @@
   /**
    * Execute this target against the linked binding builder.
    */
-  public <T> ScopedBindingBuilder applyTarget(Binding<T> binding,
+  protected <T> ScopedBindingBuilder applyTarget(Binding<T> binding,
       final LinkedBindingBuilder<T> linkedBindingBuilder) {
     return binding.acceptTargetVisitor(new BindingTargetVisitor<T, ScopedBindingBuilder>() {
       public ScopedBindingBuilder visitInstance(T instance, Set<InjectionPoint> injectionPoints) {
@@ -187,7 +187,7 @@
     });
   }
 
-  public void applyScoping(Binding<?> binding, final ScopedBindingBuilder scopedBindingBuilder) {
+  protected void applyScoping(Binding<?> binding, final ScopedBindingBuilder scopedBindingBuilder) {
     binding.acceptScopingVisitor(new BindingScopingVisitor<Void>() {
       public Void visitEagerSingleton() {
         scopedBindingBuilder.asEagerSingleton();
@@ -211,7 +211,7 @@
     });
   }
 
-  public <T> void writeGetProvider(final Binder binder, final ProviderLookup<T> element) {
+  protected <T> void writeGetProvider(final Binder binder, final ProviderLookup<T> element) {
     Provider<T> provider = binder.withSource(element.getSource()).getProvider(element.getKey());
     element.initDelegate(provider);
   }
diff --git a/test/com/google/inject/Asserts.java b/test/com/google/inject/Asserts.java
index cb2ef58..8984a1e 100644
--- a/test/com/google/inject/Asserts.java
+++ b/test/com/google/inject/Asserts.java
@@ -24,6 +24,9 @@
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import junit.framework.Assert;
+import static junit.framework.Assert.assertEquals;
+import static junit.framework.Assert.assertNotNull;
+import static junit.framework.Assert.assertTrue;
 
 /**
  * @author jessewilson@google.com (Jesse Wilson)
@@ -37,11 +40,11 @@
    * for testing the equals method itself.
    */
   public static void assertEqualsBothWays(Object expected, Object actual) {
-    Assert.assertNotNull(expected);
-    Assert.assertNotNull(actual);
-    Assert.assertTrue("expected.equals(actual)", expected.equals(actual));
-    Assert.assertTrue("actual.equals(expected)", actual.equals(expected));
-    Assert.assertEquals("hashCode", expected.hashCode(), actual.hashCode());
+    assertNotNull(expected);
+    assertNotNull(actual);
+    assertTrue("expected.equals(actual)", expected.equals(actual));
+    assertTrue("actual.equals(expected)", actual.equals(expected));
+    assertEquals("hashCode", expected.hashCode(), actual.hashCode());
   }
 
   /**
@@ -49,15 +52,15 @@
    */
   public static void assertContains(String text, String... substrings) {
     int startingFrom = 0;
-    for (int i = 0; i < substrings.length; i++) {
-      int index = text.indexOf(substrings[i], startingFrom);
-      Assert.assertTrue(String.format("Expected \"%s\" to contain substring \"%s\"",
-          text, substrings[i]), index >= startingFrom);
-      startingFrom = index + substrings[i].length();
+    for (String substring : substrings) {
+      int index = text.indexOf(substring, startingFrom);
+      assertTrue(String.format("Expected \"%s\" to contain substring \"%s\"", text, substring),
+          index >= startingFrom);
+      startingFrom = index + substring.length();
     }
 
     String lastSubstring = substrings[substrings.length - 1];
-    Assert.assertTrue(String.format("Expected \"%s\" to contain substring \"%s\" only once),",
+    assertTrue(String.format("Expected \"%s\" to contain substring \"%s\" only once),",
         text, lastSubstring), text.indexOf(lastSubstring, startingFrom) == -1);
   }
 
@@ -66,42 +69,33 @@
    */
   public static void assertEqualWhenReserialized(Object object)
       throws IOException {
-    Assert.assertTrue("Expected serialVersionUID", hasSerialVersionUid(object));
     Object reserialized = reserialize(object);
-    Assert.assertEquals(object, reserialized);
-    Assert.assertEquals(object.hashCode(), reserialized.hashCode());
+    assertEquals(object, reserialized);
+    assertEquals(object.hashCode(), reserialized.hashCode());
   }
 
   /**
    * Fails unless {@code object} has the same toString value when reserialized.
    */
   public static void assertSimilarWhenReserialized(Object object) throws IOException {
-    Assert.assertTrue("Expected serialVersionUID", hasSerialVersionUid(object));
     Object reserialized = reserialize(object);
-    Assert.assertEquals(object.toString(), reserialized.toString());
+    assertEquals(object.toString(), reserialized.toString());
   }
 
-  static boolean hasSerialVersionUid(Object object) {
-    try {
-      return null != object.getClass().getDeclaredField("serialVersionUID");
-    } catch (NoSuchFieldException e) {
-      return false;
-    }
-  }
-
-  static Object reserialize(Object object) throws IOException {
+  public static <E> E reserialize(E original) throws IOException {
     try {
       ByteArrayOutputStream out = new ByteArrayOutputStream();
-      new ObjectOutputStream(out).writeObject(object);
+      new ObjectOutputStream(out).writeObject(original);
       ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
-      return new ObjectInputStream(in).readObject();
+      @SuppressWarnings("unchecked") // the reserialized type is assignable
+      E reserialized = (E) new ObjectInputStream(in).readObject();
+      return reserialized;
     } catch (ClassNotFoundException e) {
       throw new RuntimeException(e);
     }
   }
 
   public static void assertNotSerializable(Object object) throws IOException {
-    Assert.assertFalse("Unexpected serialVersionUID", hasSerialVersionUid(object));
     try {
       reserialize(object);
       Assert.fail();
diff --git a/test/com/google/inject/KeyTest.java b/test/com/google/inject/KeyTest.java
index 9ccec28..3fab894 100644
--- a/test/com/google/inject/KeyTest.java
+++ b/test/com/google/inject/KeyTest.java
@@ -129,7 +129,7 @@
     assertEqualWhenReserialized(Key.get(B[].class));
     assertEqualWhenReserialized(Key.get(new TypeLiteral<Map<List<B>, B>>() {}));
     assertEqualWhenReserialized(Key.get(new TypeLiteral<List<B[]>>() {}));
-    assertEqualWhenReserialized(new Key<List<B[]>>() {});
+    assertEquals(new Key<List<B[]>>() {}, Asserts.reserialize(new Key<List<B[]>>() {}));
     assertEqualWhenReserialized(Key.get(Types.listOf(Types.subtypeOf(CharSequence.class))));
   }
 
diff --git a/test/com/google/inject/TypeLiteralTest.java b/test/com/google/inject/TypeLiteralTest.java
index 903a3a2..2d4237b 100644
--- a/test/com/google/inject/TypeLiteralTest.java
+++ b/test/com/google/inject/TypeLiteralTest.java
@@ -16,8 +16,8 @@
 
 package com.google.inject;
 
-import static com.google.inject.Asserts.assertEqualWhenReserialized;
 import static com.google.inject.Asserts.assertEqualsBothWays;
+import static com.google.inject.Asserts.assertNotSerializable;
 import com.google.inject.util.Types;
 import java.io.IOException;
 import java.util.List;
@@ -61,8 +61,8 @@
     assertEqualsBothWays(a, b);
     assertEquals("java.util.List<? extends java.lang.CharSequence>", a.toString());
     assertEquals("java.util.List<? extends java.lang.CharSequence>", b.toString());
-    assertEqualWhenReserialized(a);
-    assertEqualWhenReserialized(b);
+    assertNotSerializable(a);
+    assertNotSerializable(b);
   }
 
   public void testMissingTypeParameter() {
@@ -127,6 +127,6 @@
   }
 
   public void testSerialization() throws IOException {
-    assertEqualWhenReserialized(new TypeLiteral<List<String>>() {});
+    assertNotSerializable(new TypeLiteral<List<String>>() {});
   }
 }