8267086: ArrayIndexOutOfBoundsException in java.security.KeyFactory.generatePublic

Reviewed-by: mbaesken
Backport-of: 2e375ae9ed459527393f9dd13d15d1031ad6095f
diff --git a/src/java.base/share/classes/sun/security/util/DerIndefLenConverter.java b/src/java.base/share/classes/sun/security/util/DerIndefLenConverter.java
index 56e1c63..b33adc9 100644
--- a/src/java.base/share/classes/sun/security/util/DerIndefLenConverter.java
+++ b/src/java.base/share/classes/sun/security/util/DerIndefLenConverter.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1998, 2019, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1998, 2021, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -31,9 +31,12 @@
 import java.util.Arrays;
 
 /**
- * A package private utility class to convert indefinite length DER
+ * A package private utility class to convert indefinite length BER
  * encoded byte arrays to definite length DER encoded byte arrays.
- *
+ * <p>
+ * Note: This class only substitutes indefinite length octets with definite
+ * length octets. It does not update the contents even if they are not DER.
+ * <p>
  * This assumes that the basic data structure is "tag, length, value"
  * triplet. In the case where the length is "indefinite", terminating
  * end-of-contents bytes are expected.
@@ -42,26 +45,30 @@
  */
 class DerIndefLenConverter {
 
-    private static final int TAG_MASK            = 0x1f; // bits 5-1
-    private static final int FORM_MASK           = 0x20; // bits 6
-    private static final int CLASS_MASK          = 0xC0; // bits 8 and 7
-
     private static final int LEN_LONG            = 0x80; // bit 8 set
     private static final int LEN_MASK            = 0x7f; // bits 7 - 1
-    private static final int SKIP_EOC_BYTES      = 2;
 
     private byte[] data, newData;
     private int newDataPos, dataPos, dataSize, index;
     private int unresolved = 0;
 
+    // A list to store each indefinite length occurrence. Whenever an indef
+    // length is seen, the position after the 0x80 byte is appended to the
+    // list as an integer. Whenever its matching EOC is seen, we know the
+    // actual length and the position value is substituted with the calculated
+    // length octets. At the end, the new DER encoding is a concatenation of
+    // all existing tags, existing definite length octets, existing contents,
+    // and the newly created definite length octets in this list.
     private ArrayList<Object> ndefsList = new ArrayList<Object>();
 
+    // Length of extra bytes needed to convert indefinite encoding to definite.
+    // For each resolved indefinite length encoding, the starting 0x80 byte
+    // and the ending 00 00 bytes will be removed and new definite length
+    // octets will be added. This value might be positive or negative.
     private int numOfTotalLenBytes = 0;
 
-    private boolean isEOC(int tag) {
-        return (((tag & TAG_MASK) == 0x00) &&  // EOC
-                ((tag & FORM_MASK) == 0x00) && // primitive
-                ((tag & CLASS_MASK) == 0x00)); // universal
+    private static boolean isEOC(byte[] data, int pos) {
+        return data[pos] == 0 && data[pos + 1] == 0;
     }
 
     // if bit 8 is set then it implies either indefinite length or long form
@@ -88,11 +95,14 @@
     }
 
     /**
-     * Parse the tag and if it is an end-of-contents tag then
-     * add the current position to the <code>eocList</code> vector.
+     * Consumes the tag at {@code dataPos}.
+     * <p>
+     * If it is EOC then replace the matching start position (i.e. the previous
+     * {@code dataPos} where an indefinite length was found by {@link #parseLength})
+     * in {@code ndefsList} with the length octets for this section.
      */
     private void parseTag() throws IOException {
-        if (isEOC(data[dataPos]) && (data[dataPos + 1] == 0)) {
+        if (isEOC(data, dataPos)) {
             int numOfEncapsulatedLenBytes = 0;
             Object elem = null;
             int index;
@@ -103,6 +113,9 @@
                 if (elem instanceof Integer) {
                     break;
                 } else {
+                    // For each existing converted part, 3 bytes (80 at the
+                    // beginning and 00 00 at the end) are removed and new
+                    // length octets are added.
                     numOfEncapsulatedLenBytes += ((byte[])elem).length - 3;
                 }
             }
@@ -114,6 +127,7 @@
                              numOfEncapsulatedLenBytes;
             byte[] sectionLenBytes = getLengthBytes(sectionLen);
             ndefsList.set(index, sectionLenBytes);
+            assert unresolved > 0;
             unresolved--;
 
             // Add the number of bytes required to represent this section
@@ -130,34 +144,41 @@
      * then skip the tag and its 1 byte length of zero.
      */
     private void writeTag() {
-        if (dataPos == dataSize)
+        if (dataPos == dataSize) {
             return;
-        int tag = data[dataPos++];
-        if (isEOC(tag) && (data[dataPos] == 0)) {
-            dataPos++;  // skip length
+        }
+        assert dataPos + 1 < dataSize;
+        if (isEOC(data, dataPos)) {
+            dataPos += 2;  // skip tag and length
             writeTag();
-        } else
-            newData[newDataPos++] = (byte)tag;
+        } else {
+            newData[newDataPos++] = data[dataPos++];
+        }
     }
 
     /**
-     * Parse the length and if it is an indefinite length then add
-     * the current position to the <code>ndefsList</code> vector.
+     * Parse the length octets starting at {@code dataPos}. After this method
+     * returns, {@code dataPos} is positioned after the length octets unless
+     * -1 is returned.
      *
-     * @return the length of definite length data next, or -1 if there is
-     *         not enough bytes to determine it
+     * @return a) the length of the definite length data that follows,
+     *         b) -1, if definite length data follows but the length
+     *            octets are incomplete, so the actual length is unknown, or
+     *         c) 0, if it is an indefinite length. In this case the current
+     *            position is also appended to {@code ndefsList}.
      * @throws IOException if invalid data is read
      */
     private int parseLength() throws IOException {
-        int curLen = 0;
-        if (dataPos == dataSize)
-            return curLen;
+        if (dataPos == dataSize) {
+            return 0;
+        }
         int lenByte = data[dataPos++] & 0xff;
         if (isIndefinite(lenByte)) {
             ndefsList.add(dataPos);
             unresolved++;
-            return curLen;
+            return 0;
         }
+        int curLen = 0;
         if (isLongForm(lenByte)) {
             lenByte &= LEN_MASK;
             if (lenByte > 4) {
@@ -179,14 +200,17 @@
     }
 
     /**
-     * Write the length and if it is an indefinite length
-     * then calculate the definite length from the positions
-     * of the indefinite length and its matching EOC terminator.
-     * Then, write the value.
+     * Write the length and value.
+     * <p>
+     * If it is a definite length, just re-write the length and copy the value.
+     * If it is an indefinite length, copy the precalculated definite length
+     * octets from {@code ndefsList}. No value is copied here because the
+     * contents are sub-encodings of a constructed encoding.
      */
     private void writeLengthAndValue() throws IOException {
-        if (dataPos == dataSize)
-           return;
+        if (dataPos == dataSize) {
+            return;
+        }
         int curLen = 0;
         int lenByte = data[dataPos++] & 0xff;
         if (isIndefinite(lenByte)) {
@@ -194,21 +218,21 @@
             System.arraycopy(lenBytes, 0, newData, newDataPos,
                              lenBytes.length);
             newDataPos += lenBytes.length;
-            return;
-        }
-        if (isLongForm(lenByte)) {
-            lenByte &= LEN_MASK;
-            for (int i = 0; i < lenByte; i++) {
-                curLen = (curLen << 8) + (data[dataPos++] & 0xff);
-            }
-            if (curLen < 0) {
-                throw new IOException("Invalid length bytes");
-            }
         } else {
-            curLen = (lenByte & LEN_MASK);
+            if (isLongForm(lenByte)) {
+                lenByte &= LEN_MASK;
+                for (int i = 0; i < lenByte; i++) {
+                    curLen = (curLen << 8) + (data[dataPos++] & 0xff);
+                }
+                if (curLen < 0) {
+                    throw new IOException("Invalid length bytes");
+                }
+            } else {
+                curLen = (lenByte & LEN_MASK);
+            }
+            writeLength(curLen);
+            writeValue(curLen);
         }
-        writeLength(curLen);
-        writeValue(curLen);
     }
 
     private void writeLength(int curLen) {
@@ -297,18 +321,12 @@
     }
 
     /**
-     * Parse the value;
-     */
-    private void parseValue(int curLen) {
-        dataPos += curLen;
-    }
-
-    /**
      * Write the value;
      */
     private void writeValue(int curLen) {
-        for (int i=0; i < curLen; i++)
-            newData[newDataPos++] = data[dataPos++];
+        System.arraycopy(data, dataPos, newData, newDataPos, curLen);
+        dataPos += curLen;
+        newDataPos += curLen;
     }
 
     /**
@@ -323,10 +341,8 @@
      */
     byte[] convertBytes(byte[] indefData) throws IOException {
         data = indefData;
-        dataPos=0; index=0;
+        dataPos = 0;
         dataSize = data.length;
-        int len=0;
-        int unused = 0;
 
         // parse and set up the vectors of all the indefinite-lengths
         while (dataPos < dataSize) {
@@ -335,14 +351,17 @@
                 return null;
             }
             parseTag();
-            len = parseLength();
+            int len = parseLength();
             if (len < 0) {
                 return null;
             }
-            parseValue(len);
+            dataPos += len;
+            if (dataPos < 0) {
+                // overflow
+                throw new IOException("Data overflow");
+            }
             if (unresolved == 0) {
-                unused = dataSize - dataPos;
-                dataSize = dataPos;
+                assert !ndefsList.isEmpty() && ndefsList.get(0) instanceof byte[];
                 break;
             }
         }
@@ -351,14 +370,18 @@
             return null;
         }
 
+        int unused = dataSize - dataPos;
+        assert unused >= 0;
+        dataSize = dataPos;
+
         newData = new byte[dataSize + numOfTotalLenBytes + unused];
-        dataPos=0; newDataPos=0; index=0;
+        dataPos = 0; newDataPos = 0; index = 0;
 
         // write out the new byte array replacing all the indefinite-lengths
         // and EOCs
         while (dataPos < dataSize) {
-           writeTag();
-           writeLengthAndValue();
+            writeTag();
+            writeLengthAndValue();
         }
         System.arraycopy(indefData, dataSize,
                          newData, dataSize + numOfTotalLenBytes, unused);
@@ -396,7 +419,7 @@
             if (result == null) {
                 int next = in.read(); // This could block, but we need more
                 if (next == -1) {
-                    throw new IOException("not all indef len BER resolved");
+                    throw new IOException("not enough data to resolve indef len BER");
                 }
                 int more = in.available();
                 // expand array to include next and more