Fix bug in TreeMap. All methods that returns Entrys must return immutable Entrys, except entrySet(), which returns a set of mutable Entrys. Currently, the Entrys from entrySet() are immutable. Adds some new unit tests to verify this behavior.

Change-Id: I1149185412cd60d8e9c888179c43f5bef5057a69
diff --git a/luni/src/main/java/java/util/TreeMap.java b/luni/src/main/java/java/util/TreeMap.java
index a2362b9..11fe4df 100644
--- a/luni/src/main/java/java/util/TreeMap.java
+++ b/luni/src/main/java/java/util/TreeMap.java
@@ -584,10 +584,10 @@
      */
 
     public Entry<K, V> firstEntry() {
-        return root == null ? null : root.first();
+        return new SimpleImmutableEntry<K, V>(root == null ? null : root.first());
     }
 
-    public Entry<K, V> pollFirstEntry() {
+    private Entry<K, V> internalPollFirstEntry() {
         if (root == null) {
             return null;
         }
@@ -596,6 +596,10 @@
         return result;
     }
 
+    public Entry<K, V> pollFirstEntry() {
+        return new SimpleImmutableEntry<K, V>(internalPollFirstEntry());
+    }
+
     public K firstKey() {
         if (root == null) {
             throw new NoSuchElementException();
@@ -604,10 +608,10 @@
     }
 
     public Entry<K, V> lastEntry() {
-        return root == null ? null : root.last();
+        return new SimpleImmutableEntry<K, V>(root == null ? null : root.last());
     }
 
-    public Entry<K, V> pollLastEntry() {
+    private Entry<K, V> internalPollLastEntry() {
         if (root == null) {
             return null;
         }
@@ -616,6 +620,10 @@
         return result;
     }
 
+    public Entry<K, V> pollLastEntry() {
+        return new SimpleImmutableEntry<K, V>(internalPollLastEntry());
+    }
+
     public K lastKey() {
         if (root == null) {
             throw new NoSuchElementException();
@@ -624,7 +632,7 @@
     }
 
     public Entry<K, V> lowerEntry(K key) {
-        return find(key, LOWER);
+        return new SimpleImmutableEntry<K, V>(find(key, LOWER));
     }
 
     public K lowerKey(K key) {
@@ -633,7 +641,7 @@
     }
 
     public Entry<K, V> floorEntry(K key) {
-        return find(key, FLOOR);
+        return new SimpleImmutableEntry<K, V>(find(key, FLOOR));
     }
 
     public K floorKey(K key) {
@@ -642,7 +650,7 @@
     }
 
     public Entry<K, V> ceilingEntry(K key) {
-        return find(key, CEILING);
+        return new SimpleImmutableEntry<K, V>(find(key, CEILING));
     }
 
     public K ceilingKey(K key) {
@@ -651,7 +659,7 @@
     }
 
     public Entry<K, V> higherEntry(K key) {
-        return find(key, HIGHER);
+        return new SimpleImmutableEntry<K, V>(find(key, HIGHER));
     }
 
     public K higherKey(K key) {
@@ -757,7 +765,9 @@
         }
 
         public V setValue(V value) {
-            throw new UnsupportedOperationException(); // per the spec
+            V oldValue = this.value;
+            this.value = value;
+            return oldValue;
         }
 
         @Override public boolean equals(Object o) {
@@ -908,7 +918,7 @@
         }
 
         @Override public Iterator<Entry<K, V>> iterator() {
-            return new MapIterator<Entry<K, V>>((Node<K, V>) firstEntry()) {
+            return new MapIterator<Entry<K, V>>(root == null ? null : root.first()) {
                 public Entry<K, V> next() {
                     return stepForward();
                 }
@@ -943,7 +953,7 @@
         }
 
         @Override public Iterator<K> iterator() {
-            return new MapIterator<K>((Node<K, V>) firstEntry()) {
+            return new MapIterator<K>(root == null ? null : root.first()) {
                 public K next() {
                     return stepForward().key;
                 }
@@ -951,7 +961,7 @@
         }
 
         public Iterator<K> descendingIterator() {
-            return new MapIterator<K>((Node<K, V>) lastEntry()) {
+            return new MapIterator<K>(root == null ? null : root.last()) {
                 public K next() {
                     return stepBackward().key;
                 }
@@ -1003,12 +1013,12 @@
         }
 
         public K pollFirst() {
-            Entry<K, V> entry = pollFirstEntry();
+            Entry<K, V> entry = internalPollFirstEntry();
             return entry != null ? entry.getKey() : null;
         }
 
         public K pollLast() {
-            Entry<K, V> entry = pollLastEntry();
+            Entry<K, V> entry = internalPollLastEntry();
             return entry != null ? entry.getKey() : null;
         }
 
@@ -1116,7 +1126,7 @@
         }
 
         @Override public boolean isEmpty() {
-            return firstEntry() == null;
+            return endpoint(true) == null;
         }
 
         @Override public V get(Object key) {
@@ -1177,7 +1187,7 @@
         /**
          * Returns the entry if it is in bounds, or null if it is out of bounds.
          */
-        private Entry<K, V> bound(Entry<K, V> node, Bound fromBound, Bound toBound) {
+        private Node<K, V> bound(Node<K, V> node, Bound fromBound, Bound toBound) {
             return node != null && isInBounds(node.getKey(), fromBound, toBound) ? node : null;
         }
 
@@ -1186,19 +1196,19 @@
          */
 
         public Entry<K, V> firstEntry() {
-            return endpoint(true);
+            return new SimpleImmutableEntry<K, V>(endpoint(true));
         }
 
         public Entry<K, V> pollFirstEntry() {
-            Node<K, V> result = (Node<K, V>) firstEntry();
+            Node<K, V> result = endpoint(true);
             if (result != null) {
                 removeInternal(result);
             }
-            return result;
+            return new SimpleImmutableEntry<K, V>(result);
         }
 
         public K firstKey() {
-            Entry<K, V> entry = firstEntry();
+            Entry<K, V> entry = endpoint(true);
             if (entry == null) {
                 throw new NoSuchElementException();
             }
@@ -1206,19 +1216,19 @@
         }
 
         public Entry<K, V> lastEntry() {
-            return endpoint(false);
+            return new SimpleImmutableEntry<K, V>(endpoint(false));
         }
 
         public Entry<K, V> pollLastEntry() {
-            Node<K, V> result = (Node<K, V>) lastEntry();
+            Node<K, V> result = endpoint(false);
             if (result != null) {
                 removeInternal(result);
             }
-            return result;
+            return new SimpleImmutableEntry<K, V>(result);
         }
 
         public K lastKey() {
-            Entry<K, V> entry = lastEntry();
+            Entry<K, V> entry = endpoint(false);
             if (entry == null) {
                 throw new NoSuchElementException();
             }
@@ -1228,38 +1238,38 @@
         /**
          * @param first true for the first element, false for the last.
          */
-        private Entry<K, V> endpoint(boolean first) {
-            Entry<K, V> entry;
+        private Node<K, V> endpoint(boolean first) {
+            Node<K, V> node;
             if (ascending == first) {
                 switch (fromBound) {
                     case NO_BOUND:
-                        entry = TreeMap.this.firstEntry();
+                        node = root == null ? null : root.first();
                         break;
                     case INCLUSIVE:
-                        entry = TreeMap.this.ceilingEntry(from);
+                        node = find(from, CEILING);
                         break;
                     case EXCLUSIVE:
-                        entry = TreeMap.this.higherEntry(from);
+                        node = find(from, HIGHER);
                         break;
                     default:
                         throw new AssertionError();
                 }
-                return bound(entry, NO_BOUND, toBound);
+                return bound(node, NO_BOUND, toBound);
             } else {
                 switch (toBound) {
                     case NO_BOUND:
-                        entry = TreeMap.this.lastEntry();
+                        node = root == null ? null : root.last();
                         break;
                     case INCLUSIVE:
-                        entry = TreeMap.this.floorEntry(to);
+                        node = find(to, FLOOR);
                         break;
                     case EXCLUSIVE:
-                        entry = TreeMap.this.lowerEntry(to);
+                        node = find(to, LOWER);
                         break;
                     default:
                         throw new AssertionError();
                 }
-                return bound(entry, fromBound, NO_BOUND);
+                return bound(node, fromBound, NO_BOUND);
             }
         }
 
@@ -1317,7 +1327,7 @@
         }
 
         public Entry<K, V> lowerEntry(K key) {
-            return findBounded(key, LOWER);
+            return new SimpleImmutableEntry<K, V>(findBounded(key, LOWER));
         }
 
         public K lowerKey(K key) {
@@ -1326,7 +1336,7 @@
         }
 
         public Entry<K, V> floorEntry(K key) {
-            return findBounded(key, FLOOR);
+            return new SimpleImmutableEntry<K, V>(findBounded(key, FLOOR));
         }
 
         public K floorKey(K key) {
@@ -1335,7 +1345,7 @@
         }
 
         public Entry<K, V> ceilingEntry(K key) {
-            return findBounded(key, CEILING);
+            return new SimpleImmutableEntry<K, V>(findBounded(key, CEILING));
         }
 
         public K ceilingKey(K key) {
@@ -1344,7 +1354,7 @@
         }
 
         public Entry<K, V> higherEntry(K key) {
-            return findBounded(key, HIGHER);
+            return new SimpleImmutableEntry<K, V>(findBounded(key, HIGHER));
         }
 
         public K higherKey(K key) {
@@ -1496,7 +1506,7 @@
             }
 
             @Override public Iterator<Entry<K, V>> iterator() {
-                return new BoundedIterator<Entry<K, V>>((Node<K, V>) firstEntry()) {
+                return new BoundedIterator<Entry<K, V>>(endpoint(true)) {
                     public Entry<K, V> next() {
                         return ascending ? stepForward() : stepBackward();
                     }
@@ -1530,7 +1540,7 @@
             }
 
             @Override public Iterator<K> iterator() {
-                return new BoundedIterator<K>((Node<K, V>) firstEntry()) {
+                return new BoundedIterator<K>(endpoint(true)) {
                     public K next() {
                         return (ascending ? stepForward() : stepBackward()).key;
                     }
@@ -1538,7 +1548,7 @@
             }
 
             public Iterator<K> descendingIterator() {
-                return new BoundedIterator<K>((Node<K, V>) lastEntry()) {
+                return new BoundedIterator<K>(endpoint(false)) {
                     public K next() {
                         return (ascending ? stepBackward() : stepForward()).key;
                     }
diff --git a/luni/src/test/java/java/util/TreeMapTest.java b/luni/src/test/java/java/util/TreeMapTest.java
index 1235dfa..0202c07 100644
--- a/luni/src/test/java/java/util/TreeMapTest.java
+++ b/luni/src/test/java/java/util/TreeMapTest.java
@@ -22,6 +22,110 @@
 
 public class TreeMapTest extends TestCase {
 
+    /**
+     * Test that the entrySet() method produces correctly mutable Entrys.
+     */
+    public void testEntrySetSetValue() {
+        TreeMap<String, String> map = new TreeMap<String, String>();
+        map.put("A", "a");
+        map.put("B", "b");
+        map.put("C", "c");
+
+        Iterator<Entry<String,String>> iterator = map.entrySet().iterator();
+        Entry<String, String> entryA = iterator.next();
+        assertEquals("a", entryA.setValue("x"));
+        assertEquals("x", entryA.getValue());
+        assertEquals("x", map.get("A"));
+        Entry<String, String> entryB = iterator.next();
+        assertEquals("b", entryB.setValue("y"));
+        Entry<String, String> entryC = iterator.next();
+        assertEquals("c", entryC.setValue("z"));
+        assertEquals("y", entryB.getValue());
+        assertEquals("y", map.get("B"));
+        assertEquals("z", entryC.getValue());
+        assertEquals("z", map.get("C"));
+    }
+
+    /**
+     * Test that the entrySet() method of a submap produces correctly mutable Entrys that
+     * propagate changes to the original map.
+     */
+    public void testSubMapEntrySetSetValue() {
+        TreeMap<String, String> map = new TreeMap<String, String>();
+        map.put("A", "a");
+        map.put("B", "b");
+        map.put("C", "c");
+        map.put("D", "d");
+        NavigableMap<String, String> subMap = map.subMap("A", true, "C", true);
+
+        Iterator<Entry<String,String>> iterator = subMap.entrySet().iterator();
+        Entry<String, String> entryA = iterator.next();
+        assertEquals("a", entryA.setValue("x"));
+        assertEquals("x", entryA.getValue());
+        assertEquals("x", subMap.get("A"));
+        assertEquals("x", map.get("A"));
+        Entry<String, String> entryB = iterator.next();
+        assertEquals("b", entryB.setValue("y"));
+        Entry<String, String> entryC = iterator.next();
+        assertEquals("c", entryC.setValue("z"));
+        assertEquals("y", entryB.getValue());
+        assertEquals("y", subMap.get("B"));
+        assertEquals("y", map.get("B"));
+        assertEquals("z", entryC.getValue());
+        assertEquals("z", subMap.get("C"));
+        assertEquals("z", map.get("C"));
+    }
+
+    /**
+     * Test that an Entry given by any method except entrySet() is immutable.
+     */
+    public void testExceptionsOnSetValue() {
+        TreeMap<String, String> map = new TreeMap<String, String>();
+        map.put("A", "a");
+        map.put("B", "b");
+        map.put("C", "c");
+
+        assertAllEntryMethodsReturnImmutableEntries(map);
+    }
+
+    /**
+     * Test that an Entry given by any method except entrySet() of a submap is immutable.
+     */
+    public void testExceptionsOnSubMapSetValue() {
+        TreeMap<String, String> map = new TreeMap<String, String>();
+        map.put("A", "a");
+        map.put("B", "b");
+        map.put("C", "c");
+        map.put("D", "d");
+
+        assertAllEntryMethodsReturnImmutableEntries(map.subMap("A", true, "C", true));
+    }
+
+    /**
+     * Asserts that each NavigableMap method that returns an Entry (except entrySet()) returns an
+     * immutable one. Assumes that the map contains at least entries with keys "A", "B" and "C".
+     */
+    private void assertAllEntryMethodsReturnImmutableEntries(NavigableMap<String, String> map) {
+        assertImmutable(map.ceilingEntry("B"));
+        assertImmutable(map.firstEntry());
+        assertImmutable(map.floorEntry("D"));
+        assertImmutable(map.higherEntry("A"));
+        assertImmutable(map.lastEntry());
+        assertImmutable(map.lowerEntry("C"));
+        assertImmutable(map.pollFirstEntry());
+        assertImmutable(map.pollLastEntry());
+    }
+
+    private void assertImmutable(Entry<String, String> entry) {
+        String initialValue = entry.getValue();
+        try {
+            entry.setValue("x");
+            fail();
+        } catch (UnsupportedOperationException e) {
+        }
+        assertEquals(initialValue, entry.getValue());
+    }
+
     public void testConcurrentModificationDetection() {
         Map<String, String> map = new TreeMap<String, String>();
         map.put("A", "a");