// blob: 61d595d0ae8a5133497e59379241523ef5431fb9 [file] [log] [blame]
/*
* Copyright (C) 2009 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.conscrypt;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
/**
 * Supports SSL session caches.
 *
 * <p>Sessions are held in an in-memory map keyed by session ID and kept in insertion order, so
 * the oldest entry is the first one evicted once the configured maximum size is exceeded. Every
 * access to that map is guarded by synchronizing on the map itself. The class also converts
 * sessions to and from a byte representation (see {@link #toBytes} and {@link #toSession}).
 */
abstract class AbstractSessionContext implements SSLSessionContext {
    /**
     * Maximum lifetime of a session (in seconds) after which it's considered invalid and should
     * not be used for new connections.
     */
    private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60;

    /** Maximum number of cached sessions; a value that is not positive means "no limit". */
    private volatile int maximumSize;

    /** Session timeout in seconds, as last set through {@link #setSessionTimeout}. */
    private volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS;

    /** Native context handle; obtained here and released in {@link #finalize}. */
    final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();

    /** Identifies OpenSSL sessions. */
    private static final int OPEN_SSL = 1;

    /** Identifies OpenSSL sessions with OCSP stapled data. */
    private static final int OPEN_SSL_WITH_OCSP = 2;

    /** Identifies OpenSSL sessions with TLS SCT data. */
    private static final int OPEN_SSL_WITH_TLS_SCT = 3;

    /**
     * Cache of sessions keyed by session ID, iterated from oldest to newest insertion. Callers
     * must hold the monitor of this map while touching it.
     */
    @SuppressWarnings("serial")
    private final Map<ByteArray, SSLSession> sessions =
            new LinkedHashMap<ByteArray, SSLSession>() {
                @Override
                protected boolean removeEldestEntry(Map.Entry<ByteArray, SSLSession> eldest) {
                    // Read the volatile limit once so both comparisons see the same value.
                    int limit = maximumSize;
                    if (limit > 0 && size() > limit) {
                        // The map is modified directly so that the subclass hook can be
                        // invoked afterwards; in that case this method has to return false.
                        remove(eldest.getKey());
                        sessionRemoved(eldest.getValue());
                    }
                    return false;
                }
            };

    /**
     * Constructs a new session context.
     *
     * @param maximumSize of cache
     */
    AbstractSessionContext(int maximumSize) {
        this.maximumSize = maximumSize;
    }

    /**
     * Returns an iterator over a snapshot of the cached sessions, ordered from oldest to newest.
     * The snapshot is taken under the cache lock, so the iterator can be consumed without it.
     */
    private Iterator<SSLSession> sessionIterator() {
        synchronized (sessions) {
            SSLSession[] snapshot = sessions.values().toArray(new SSLSession[sessions.size()]);
            return Arrays.asList(snapshot).iterator();
        }
    }

    /**
     * Returns the IDs of the sessions that were cached when this method was called and that
     * still report themselves as valid at the time they are enumerated.
     */
    @Override
    public final Enumeration<byte[]> getIds() {
        final Iterator<SSLSession> i = sessionIterator();
        return new Enumeration<byte[]>() {
            /** Next valid session found by look-ahead, or null if none is buffered. */
            private SSLSession next;

            @Override
            public boolean hasMoreElements() {
                if (next != null) {
                    return true;
                }
                // Skip over sessions that have become invalid since the snapshot was taken.
                while (i.hasNext()) {
                    SSLSession session = i.next();
                    if (session.isValid()) {
                        next = session;
                        return true;
                    }
                }
                next = null;
                return false;
            }

            @Override
            public byte[] nextElement() {
                if (hasMoreElements()) {
                    byte[] id = next.getId();
                    next = null;
                    return id;
                }
                throw new NoSuchElementException();
            }
        };
    }

    /** Returns the configured maximum number of cached sessions. */
    @Override
    public final int getSessionCacheSize() {
        return maximumSize;
    }

    /** Returns the configured session timeout in seconds. */
    @Override
    public final int getSessionTimeout() {
        return timeout;
    }

    /**
     * Makes sure cache size is <= maximumSize by evicting the oldest sessions first. A maximum
     * that is not positive means the cache is unlimited, in which case nothing is evicted; this
     * matches the eviction rule applied when a session is inserted.
     */
    private void trimToSize() {
        synchronized (sessions) {
            // Read the volatile limit once so the check and the subtraction agree.
            int limit = maximumSize;
            if (limit <= 0) {
                return;
            }
            int removals = sessions.size() - limit;
            if (removals <= 0) {
                return;
            }
            Iterator<SSLSession> i = sessions.values().iterator();
            while (removals-- > 0) {
                SSLSession session = i.next();
                i.remove();
                sessionRemoved(session);
            }
        }
    }

    /**
     * Sets the session timeout and drops every cached session that no longer reports itself as
     * valid afterwards.
     *
     * @param seconds new timeout in seconds
     * @throws IllegalArgumentException if {@code seconds} is negative
     */
    @Override
    public void setSessionTimeout(int seconds)
            throws IllegalArgumentException {
        if (seconds < 0) {
            throw new IllegalArgumentException("seconds < 0");
        }
        timeout = seconds;
        synchronized (sessions) {
            Iterator<SSLSession> i = sessions.values().iterator();
            while (i.hasNext()) {
                SSLSession session = i.next();
                // SSLSession's know their context and consult the
                // timeout as part of their validity condition.
                if (!session.isValid()) {
                    i.remove();
                    sessionRemoved(session);
                }
            }
        }
    }

    /**
     * Called when a session is removed. Used by ClientSessionContext
     * to update its host-and-port based cache.
     */
    protected abstract void sessionRemoved(SSLSession session);

    /**
     * Sets the maximum number of cached sessions and evicts the oldest sessions if the cache
     * currently holds more than the new limit. A size of zero means there is no limit, so no
     * session is evicted in that case.
     *
     * @param size new maximum number of sessions, or zero for no limit
     * @throws IllegalArgumentException if {@code size} is negative
     */
    @Override
    public final void setSessionCacheSize(int size)
            throws IllegalArgumentException {
        if (size < 0) {
            throw new IllegalArgumentException("size < 0");
        }
        maximumSize = size;
        // Always re-check: the cache may exceed the new limit not only after a shrink but also
        // when switching from "unlimited" (zero) to a positive limit. trimToSize() is a no-op
        // when the cache already fits or when the new limit is zero.
        trimToSize();
    }

    /**
     * Converts the given session to bytes.
     *
     * <p>The layout is a sequence of big-endian ints and raw byte runs: the session type ID, the
     * length-prefixed encoded session, the certificate count followed by each length-prefixed
     * encoded peer certificate, the OCSP response count followed by each length-prefixed
     * response, and finally the length-prefixed TLS SCT data (length 0 if there is none).
     *
     * @param session session to serialize
     * @return session data as bytes or null if the session can't be converted
     */
    byte[] toBytes(SSLSession session) {
        // TODO: Support SSLSessionImpl, too.
        if (!(session instanceof OpenSSLSessionImpl)) {
            return null;
        }
        OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            DataOutputStream daos = new DataOutputStream(baos);
            daos.writeInt(OPEN_SSL_WITH_TLS_SCT); // session type ID

            // Session data.
            byte[] data = sslSession.getEncoded();
            daos.writeInt(data.length);
            daos.write(data);

            // Certificates.
            Certificate[] certs = session.getPeerCertificates();
            daos.writeInt(certs.length);
            for (Certificate cert : certs) {
                data = cert.getEncoded();
                daos.writeInt(data.length);
                daos.write(data);
            }

            // OCSP responses.
            List<byte[]> ocspResponses = sslSession.getStatusResponses();
            daos.writeInt(ocspResponses.size());
            for (byte[] ocspResponse : ocspResponses) {
                daos.writeInt(ocspResponse.length);
                daos.write(ocspResponse);
            }

            // TLS SCT data; a zero length marks its absence.
            byte[] tlsSctData = sslSession.getTlsSctData();
            if (tlsSctData != null) {
                daos.writeInt(tlsSctData.length);
                daos.write(tlsSctData);
            } else {
                daos.writeInt(0);
            }
            // TODO: local certificates?
            return baos.toByteArray();
        } catch (IOException e) {
            System.err.println("Failed to convert saved SSL Session: " + e.getMessage());
            return null;
        } catch (CertificateEncodingException e) {
            log(e);
            return null;
        }
    }

    /**
     * Verifies that a length read from the buffer is non-negative and does not exceed the number
     * of bytes left in the buffer.
     *
     * @throws IOException if {@code length} is negative or larger than {@code buf.remaining()}
     */
    private static void checkRemaining(ByteBuffer buf, int length) throws IOException {
        if (length < 0) {
            throw new IOException("Length is negative: " + length);
        }
        if (length > buf.remaining()) {
            throw new IOException(
                    "Length of blob is longer than available: " + length + " > " + buf.remaining());
        }
    }

    /**
     * Creates a session from the given bytes.
     *
     * <p>Accepts all three type IDs: the OCSP section is read for types
     * {@code OPEN_SSL_WITH_OCSP} and above, the TLS SCT section only for
     * {@code OPEN_SSL_WITH_TLS_SCT}. Data with trailing bytes is rejected.
     *
     * @param data serialized session
     * @param host host passed on to the created session
     * @param port port passed on to the created session
     * @return a session or null if the session can't be converted
     */
    OpenSSLSessionImpl toSession(byte[] data, String host, int port) {
        ByteBuffer buf = ByteBuffer.wrap(data);
        try {
            int type = buf.getInt();
            if (type != OPEN_SSL && type != OPEN_SSL_WITH_OCSP && type != OPEN_SSL_WITH_TLS_SCT) {
                throw new IOException("Unexpected type ID: " + type);
            }

            // Session data.
            int length = buf.getInt();
            checkRemaining(buf, length);
            byte[] sessionData = new byte[length];
            buf.get(sessionData);

            // Certificates. The count is sanity-checked against the remaining bytes before an
            // array of that size is allocated.
            int count = buf.getInt();
            checkRemaining(buf, count);
            X509Certificate[] certs = new X509Certificate[count];
            for (int i = 0; i < count; i++) {
                length = buf.getInt();
                checkRemaining(buf, length);
                byte[] certData = new byte[length];
                buf.get(certData);
                try {
                    certs[i] = OpenSSLX509Certificate.fromX509Der(certData);
                } catch (Exception e) {
                    // Keep the original failure attached as the cause.
                    throw new IOException("Can not read certificate " + i + "/" + count, e);
                }
            }

            byte[] ocspData = null;
            if (type >= OPEN_SSL_WITH_OCSP) {
                // We only support one OCSP response now, but in the future
                // we may support RFC 6961 which has multiple.
                int countOcspResponses = buf.getInt();
                checkRemaining(buf, countOcspResponses);
                if (countOcspResponses >= 1) {
                    int ocspLength = buf.getInt();
                    checkRemaining(buf, ocspLength);
                    ocspData = new byte[ocspLength];
                    buf.get(ocspData);

                    // Skip the rest of the responses.
                    for (int i = 1; i < countOcspResponses; i++) {
                        ocspLength = buf.getInt();
                        checkRemaining(buf, ocspLength);
                        buf.position(buf.position() + ocspLength);
                    }
                }
            }

            byte[] tlsSctData = null;
            if (type == OPEN_SSL_WITH_TLS_SCT) {
                int tlsSctDataLength = buf.getInt();
                checkRemaining(buf, tlsSctDataLength);
                if (tlsSctDataLength > 0) {
                    tlsSctData = new byte[tlsSctDataLength];
                    buf.get(tlsSctData);
                }
            }

            if (buf.remaining() != 0) {
                log(new AssertionError("Read entire session, but data still remains; rejecting"));
                return null;
            }
            return new OpenSSLSessionImpl(sessionData, host, port, certs, ocspData, tlsSctData,
                    this);
        } catch (IOException e) {
            log(e);
            return null;
        } catch (BufferUnderflowException e) {
            // Thrown by the ByteBuffer reads when the data ends before an expected field.
            log(e);
            return null;
        }
    }

    /**
     * Passes the session through {@code Platform.wrapSSLSession} if it is an
     * {@code AbstractOpenSSLSession}; any other session is returned unchanged.
     */
    SSLSession wrapSSLSessionIfNeeded(SSLSession session) {
        if (session instanceof AbstractOpenSSLSession) {
            return Platform.wrapSSLSession((AbstractOpenSSLSession) session);
        } else {
            return session;
        }
    }

    /**
     * Looks up a cached session by its ID.
     *
     * @param sessionId ID of the session to look up
     * @return the (possibly wrapped) session, or null if no session with that ID is cached or
     *         the cached session is no longer valid
     * @throws NullPointerException if {@code sessionId} is null
     */
    @Override
    public SSLSession getSession(byte[] sessionId) {
        if (sessionId == null) {
            throw new NullPointerException("sessionId == null");
        }
        ByteArray key = new ByteArray(sessionId);
        SSLSession session;
        synchronized (sessions) {
            session = sessions.get(key);
        }
        if (session != null && session.isValid()) {
            return wrapSSLSessionIfNeeded(session);
        }
        return null;
    }

    /**
     * Adds the session to the cache under its session ID. Sessions with an empty ID are not
     * cached.
     */
    void putSession(SSLSession session) {
        byte[] id = session.getId();
        if (id.length == 0) {
            return;
        }
        ByteArray key = new ByteArray(id);
        synchronized (sessions) {
            sessions.put(key, session);
        }
    }

    /**
     * Prints a one-line description of the failure to standard output, using the throwable's
     * message or, if it has none, its class name.
     */
    private static void log(Throwable t) {
        System.out.println("Error inflating SSL session: "
                + (t.getMessage() != null ? t.getMessage() : t.getClass().getName()));
    }

    /** Hands the native context pointer to {@code NativeCrypto.SSL_CTX_free}. */
    @Override
    protected void finalize() throws Throwable {
        try {
            NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
        } finally {
            super.finalize();
        }
    }
}