8168714: Tighten ECDSA validation

Added additional checks to DER parsing code

Reviewed-by: vinnie, ahgross
diff --git a/jdk/src/java.base/share/classes/sun/security/provider/DSA.java b/jdk/src/java.base/share/classes/sun/security/provider/DSA.java
index c1cc11b..5763e65 100644
--- a/jdk/src/java.base/share/classes/sun/security/provider/DSA.java
+++ b/jdk/src/java.base/share/classes/sun/security/provider/DSA.java
@@ -322,19 +322,20 @@
         } else {
             // first decode the signature.
             try {
-                DerInputStream in = new DerInputStream(signature, offset,
-                                                       length);
+                // Enforce strict DER checking for signatures
+                DerInputStream in =
+                    new DerInputStream(signature, offset, length, false);
                 DerValue[] values = in.getSequence(2);
 
+                // check number of components in the read sequence
+                // and trailing data
+                if ((values.length != 2) || (in.available() != 0)) {
+                    throw new IOException("Invalid encoding for signature");
+                }
                 r = values[0].getBigInteger();
                 s = values[1].getBigInteger();
-
-                // Check for trailing signature data
-                if (in.available() != 0) {
-                    throw new IOException("Incorrect signature length");
-                }
             } catch (IOException e) {
-                throw new SignatureException("invalid encoding for signature");
+                throw new SignatureException("Invalid encoding for signature", e);
             }
         }
 
diff --git a/jdk/src/java.base/share/classes/sun/security/rsa/RSASignature.java b/jdk/src/java.base/share/classes/sun/security/rsa/RSASignature.java
index 43e605a..a426f6c 100644
--- a/jdk/src/java.base/share/classes/sun/security/rsa/RSASignature.java
+++ b/jdk/src/java.base/share/classes/sun/security/rsa/RSASignature.java
@@ -226,9 +226,10 @@
      * Decode the signature data. Verify that the object identifier matches
      * and return the message digest.
      */
-    public static byte[] decodeSignature(ObjectIdentifier oid, byte[] signature)
+    public static byte[] decodeSignature(ObjectIdentifier oid, byte[] sig)
             throws IOException {
-        DerInputStream in = new DerInputStream(signature);
+        // Enforce strict DER checking for signatures
+        DerInputStream in = new DerInputStream(sig, 0, sig.length, false);
         DerValue[] values = in.getSequence(2);
         if ((values.length != 2) || (in.available() != 0)) {
             throw new IOException("SEQUENCE length error");
diff --git a/jdk/src/java.base/share/classes/sun/security/util/DerInputBuffer.java b/jdk/src/java.base/share/classes/sun/security/util/DerInputBuffer.java
index 26bd198..1e56002 100644
--- a/jdk/src/java.base/share/classes/sun/security/util/DerInputBuffer.java
+++ b/jdk/src/java.base/share/classes/sun/security/util/DerInputBuffer.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -147,6 +147,11 @@
         System.arraycopy(buf, pos, bytes, 0, len);
         skip(len);
 
+        // check to make sure no extra leading 0s for DER
+        if (len >= 2 && (bytes[0] == 0) && (bytes[1] >= 0)) {
+            throw new IOException("Invalid encoding: redundant leading 0s");
+        }
+
         if (makePositive) {
             return new BigInteger(1, bytes);
         } else {
diff --git a/jdk/src/java.base/share/classes/sun/security/util/DerInputStream.java b/jdk/src/java.base/share/classes/sun/security/util/DerInputStream.java
index 264bfd0..94c9085 100644
--- a/jdk/src/java.base/share/classes/sun/security/util/DerInputStream.java
+++ b/jdk/src/java.base/share/classes/sun/security/util/DerInputStream.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -77,7 +77,7 @@
      * @param data the buffer from which to create the string (CONSUMED)
      */
     public DerInputStream(byte[] data) throws IOException {
-        init(data, 0, data.length);
+        init(data, 0, data.length, true);
     }
 
     /**
@@ -92,23 +92,48 @@
      *          starting at "offset"
      */
     public DerInputStream(byte[] data, int offset, int len) throws IOException {
-        init(data, offset, len);
+        init(data, offset, len, true);
+    }
+
+    /**
+     * Create a DER input stream from part of a data buffer with
+     * additional arg to indicate whether to allow constructed
+     * indefinite-length encoding.
+     * The buffer is not copied, it is shared.  Accordingly, the
+     * buffer should be treated as read-only.
+     *
+     * @param data the buffer from which to create the string (CONSUMED)
+     * @param offset the first index of <em>data</em> which will
+     *          be read as DER input in the new stream
+     * @param len how long a chunk of the buffer to use,
+     *          starting at "offset"
+     * @param allowIndefiniteLength whether to allow constructed
+     *          indefinite-length encoding
+     */
+    public DerInputStream(byte[] data, int offset, int len,
+        boolean allowIndefiniteLength) throws IOException {
+        init(data, offset, len, allowIndefiniteLength);
     }
 
     /*
      * private helper routine
      */
-    private void init(byte[] data, int offset, int len) throws IOException {
+    private void init(byte[] data, int offset, int len,
+        boolean allowIndefiniteLength) throws IOException {
         if ((offset+2 > data.length) || (offset+len > data.length)) {
             throw new IOException("Encoding bytes too short");
         }
         // check for indefinite length encoding
         if (DerIndefLenConverter.isIndefinite(data[offset+1])) {
-            byte[] inData = new byte[len];
-            System.arraycopy(data, offset, inData, 0, len);
+            if (!allowIndefiniteLength) {
+                throw new IOException("Indefinite length BER encoding found");
+            } else {
+                byte[] inData = new byte[len];
+                System.arraycopy(data, offset, inData, 0, len);
 
-            DerIndefLenConverter derIn = new DerIndefLenConverter();
-            buffer = new DerInputBuffer(derIn.convert(inData));
+                DerIndefLenConverter derIn = new DerIndefLenConverter();
+                buffer = new DerInputBuffer(derIn.convert(inData));
+            }
         } else
             buffer = new DerInputBuffer(data, offset, len);
         buffer.mark(Integer.MAX_VALUE);
@@ -239,15 +264,19 @@
          * representation.
          */
         length--;
-        int validBits = length*8 - buffer.read();
+        int excessBits = buffer.read();
+        if (excessBits < 0) {
+            throw new IOException("Unused bits of bit string invalid");
+        }
+        int validBits = length*8 - excessBits;
         if (validBits < 0) {
-            throw new IOException("valid bits of bit string invalid");
+            throw new IOException("Valid bits of bit string invalid");
         }
 
         byte[] repn = new byte[length];
 
         if ((length != 0) && (buffer.read(repn) != length)) {
-            throw new IOException("short read of DER bit string");
+            throw new IOException("Short read of DER bit string");
         }
 
         return new BitArray(validBits, repn);
@@ -263,7 +292,7 @@
         int length = getDefiniteLength(buffer);
         byte[] retval = new byte[length];
         if ((length != 0) && (buffer.read(retval) != length))
-            throw new IOException("short read of DER octet string");
+            throw new IOException("Short read of DER octet string");
 
         return retval;
     }
@@ -273,7 +302,7 @@
      */
     public void getBytes(byte[] val) throws IOException {
         if ((val.length != 0) && (buffer.read(val) != val.length)) {
-            throw new IOException("short read of DER octet string");
+            throw new IOException("Short read of DER octet string");
         }
     }
 
@@ -357,7 +386,7 @@
         DerInputStream  newstr;
 
         byte lenByte = (byte)buffer.read();
-        int len = getLength((lenByte & 0xff), buffer);
+        int len = getLength(lenByte, buffer);
 
         if (len == -1) {
            // indefinite length encoding found
@@ -403,7 +432,7 @@
         } while (newstr.available() > 0);
 
         if (newstr.available() != 0)
-            throw new IOException("extra data at end of vector");
+            throw new IOException("Extra data at end of vector");
 
         /*
          * Now stick them into the array we're returning.
@@ -494,7 +523,7 @@
         int length = getDefiniteLength(buffer);
         byte[] retval = new byte[length];
         if ((length != 0) && (buffer.read(retval) != length))
-            throw new IOException("short read of DER " +
+            throw new IOException("Short read of DER " +
                                   stringName + " string");
 
         return new String(retval, enc);
@@ -555,7 +584,11 @@
      */
     static int getLength(int lenByte, InputStream in) throws IOException {
         int value, tmp;
+        if (lenByte == -1) {
+            throw new IOException("Short read of DER length");
+        }
 
+        String mdName = "DerInputStream.getLength(): ";
         tmp = lenByte;
         if ((tmp & 0x080) == 0x00) { // short form, 1 byte datum
             value = tmp;
@@ -569,17 +602,23 @@
             if (tmp == 0)
                 return -1;
             if (tmp < 0 || tmp > 4)
-                throw new IOException("DerInputStream.getLength(): lengthTag="
-                    + tmp + ", "
+                throw new IOException(mdName + "lengthTag=" + tmp + ", "
                     + ((tmp < 0) ? "incorrect DER encoding." : "too big."));
 
-            for (value = 0; tmp > 0; tmp --) {
+            value = 0x0ff & in.read();
+            tmp--;
+            if (value == 0) {
+                // DER requires length value be encoded in minimum number of bytes
+                throw new IOException(mdName + "Redundant length bytes found");
+            }
+            while (tmp-- > 0) {
                 value <<= 8;
                 value += 0x0ff & in.read();
             }
             if (value < 0) {
-                throw new IOException("DerInputStream.getLength(): "
-                        + "Invalid length bytes");
+                throw new IOException(mdName + "Invalid length bytes");
+            } else if (value <= 127) {
+                throw new IOException(mdName + "Should use short form for length");
             }
         }
         return value;
diff --git a/jdk/src/java.base/share/classes/sun/security/util/DerValue.java b/jdk/src/java.base/share/classes/sun/security/util/DerValue.java
index 936e2ff..b47064a 100644
--- a/jdk/src/java.base/share/classes/sun/security/util/DerValue.java
+++ b/jdk/src/java.base/share/classes/sun/security/util/DerValue.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2014, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -248,7 +248,7 @@
 
         tag = (byte)in.read();
         byte lenByte = (byte)in.read();
-        length = DerInputStream.getLength((lenByte & 0xff), in);
+        length = DerInputStream.getLength(lenByte, in);
         if (length == -1) {  // indefinite length encoding found
             DerInputBuffer inbuf = in.dup();
             int readLen = inbuf.available();
@@ -361,7 +361,7 @@
 
         tag = (byte)in.read();
         byte lenByte = (byte)in.read();
-        length = DerInputStream.getLength((lenByte & 0xff), in);
+        length = DerInputStream.getLength(lenByte, in);
         if (length == -1) { // indefinite length encoding found
             int readLen = in.available();
             int offset = 2;     // for tag and length bytes
diff --git a/jdk/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDSASignature.java b/jdk/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDSASignature.java
index c88c191..3b74338 100644
--- a/jdk/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDSASignature.java
+++ b/jdk/src/jdk.crypto.ec/share/classes/sun/security/ec/ECDSASignature.java
@@ -455,18 +455,22 @@
     }
 
     // Convert the DER encoding of R and S into a concatenation of R and S
-    private byte[] decodeSignature(byte[] signature) throws SignatureException {
+    private byte[] decodeSignature(byte[] sig) throws SignatureException {
 
         try {
-            DerInputStream in = new DerInputStream(signature);
+            // Enforce strict DER checking for signatures
+            DerInputStream in = new DerInputStream(sig, 0, sig.length, false);
             DerValue[] values = in.getSequence(2);
+
+            // check number of components in the read sequence
+            // and trailing data
+            if ((values.length != 2) || (in.available() != 0)) {
+                throw new IOException("Invalid encoding for signature");
+            }
+
             BigInteger r = values[0].getPositiveBigInteger();
             BigInteger s = values[1].getPositiveBigInteger();
 
-            // Check for trailing signature data
-            if (in.available() != 0) {
-                throw new IOException("Incorrect signature length");
-            }
             // trim leading zeroes
             byte[] rBytes = trimZeroes(r.toByteArray());
             byte[] sBytes = trimZeroes(s.toByteArray());
@@ -480,7 +484,7 @@
             return result;
 
         } catch (Exception e) {
-            throw new SignatureException("Could not decode signature", e);
+            throw new SignatureException("Invalid encoding for signature", e);
         }
     }
 
diff --git a/jdk/src/jdk.crypto.token/share/classes/sun/security/pkcs11/P11Signature.java b/jdk/src/jdk.crypto.token/share/classes/sun/security/pkcs11/P11Signature.java
index 43dfcd6..0260a55 100644
--- a/jdk/src/jdk.crypto.token/share/classes/sun/security/pkcs11/P11Signature.java
+++ b/jdk/src/jdk.crypto.token/share/classes/sun/security/pkcs11/P11Signature.java
@@ -740,17 +740,21 @@
         }
     }
 
-    private static byte[] asn1ToDSA(byte[] signature) throws SignatureException {
+    private static byte[] asn1ToDSA(byte[] sig) throws SignatureException {
         try {
-            DerInputStream in = new DerInputStream(signature);
+            // Enforce strict DER checking for signatures
+            DerInputStream in = new DerInputStream(sig, 0, sig.length, false);
             DerValue[] values = in.getSequence(2);
+
+            // check number of components in the read sequence
+            // and trailing data
+            if ((values.length != 2) || (in.available() != 0)) {
+                throw new IOException("Invalid encoding for signature");
+            }
+
             BigInteger r = values[0].getPositiveBigInteger();
             BigInteger s = values[1].getPositiveBigInteger();
 
-            // Check for trailing signature data
-            if (in.available() != 0) {
-                throw new IOException("Incorrect signature length");
-            }
             byte[] br = toByteArray(r, 20);
             byte[] bs = toByteArray(s, 20);
             if ((br == null) || (bs == null)) {
@@ -760,21 +764,25 @@
         } catch (SignatureException e) {
             throw e;
         } catch (Exception e) {
-            throw new SignatureException("invalid encoding for signature", e);
+            throw new SignatureException("Invalid encoding for signature", e);
         }
     }
 
-    private byte[] asn1ToECDSA(byte[] signature) throws SignatureException {
+    private byte[] asn1ToECDSA(byte[] sig) throws SignatureException {
         try {
-            DerInputStream in = new DerInputStream(signature);
+            // Enforce strict DER checking for signatures
+            DerInputStream in = new DerInputStream(sig, 0, sig.length, false);
             DerValue[] values = in.getSequence(2);
+
+            // check number of components in the read sequence
+            // and trailing data
+            if ((values.length != 2) || (in.available() != 0)) {
+                throw new IOException("Invalid encoding for signature");
+            }
+
             BigInteger r = values[0].getPositiveBigInteger();
             BigInteger s = values[1].getPositiveBigInteger();
 
-            // Check for trailing signature data
-            if (in.available() != 0) {
-                throw new IOException("Incorrect signature length");
-            }
             // trim leading zeroes
             byte[] br = KeyUtil.trimZeroes(r.toByteArray());
             byte[] bs = KeyUtil.trimZeroes(s.toByteArray());
@@ -785,7 +793,7 @@
             System.arraycopy(bs, 0, res, res.length - bs.length, bs.length);
             return res;
         } catch (Exception e) {
-            throw new SignatureException("invalid encoding for signature", e);
+            throw new SignatureException("Invalid encoding for signature", e);
         }
     }