blob: bf20e9bfb94836b1fc2a3310689553cbd9e706ed [file] [log] [blame]
package org.bouncycastle.crypto.tls.test;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Hashtable;
import java.util.Vector;
import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.DERBitString;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.x509.Certificate;
import org.bouncycastle.crypto.tls.AlertDescription;
import org.bouncycastle.crypto.tls.AlertLevel;
import org.bouncycastle.crypto.tls.CertificateRequest;
import org.bouncycastle.crypto.tls.ClientCertificateType;
import org.bouncycastle.crypto.tls.ConnectionEnd;
import org.bouncycastle.crypto.tls.DefaultTlsClient;
import org.bouncycastle.crypto.tls.ProtocolVersion;
import org.bouncycastle.crypto.tls.SignatureAlgorithm;
import org.bouncycastle.crypto.tls.SignatureAndHashAlgorithm;
import org.bouncycastle.crypto.tls.TlsAuthentication;
import org.bouncycastle.crypto.tls.TlsCredentials;
import org.bouncycastle.crypto.tls.TlsFatalAlert;
import org.bouncycastle.crypto.tls.TlsSignerCredentials;
import org.bouncycastle.crypto.tls.TlsUtils;
import org.bouncycastle.util.Arrays;
/**
 * {@link DefaultTlsClient} used by the TLS test suite.
 * <p>
 * Its behaviour (offered/minimum versions, signature_algorithms extension, fallback flag,
 * client authentication mode) is driven by a {@link TlsTestConfig}. It also records the first
 * fatal alert seen on the connection, whether raised locally or received from the server.
 */
class TlsTestClientImpl
    extends DefaultTlsClient
{
    /** Test resources holding the server certificates this client accepts as chain[0]. */
    private static final String[] TRUSTED_SERVER_CERT_RESOURCES = new String[]{
        "x509-server.pem", "x509-server-dsa.pem", "x509-server-ecdsa.pem" };

    protected final TlsTestConfig config;

    /** ConnectionEnd of the side that produced the first fatal alert, or -1 if none yet. */
    protected int firstFatalAlertConnectionEnd = -1;

    /** AlertDescription of the first fatal alert, or -1 if none yet. */
    protected short firstFatalAlertDescription = -1;

    TlsTestClientImpl(TlsTestConfig config)
    {
        this.config = config;
    }

    /**
     * @return {@code ConnectionEnd.client} or {@code ConnectionEnd.server} for whichever side
     *         produced the first fatal alert, or -1 if no fatal alert has been seen
     */
    int getFirstFatalAlertConnectionEnd()
    {
        return firstFatalAlertConnectionEnd;
    }

    /**
     * @return the description of the first fatal alert, or -1 if no fatal alert has been seen
     */
    short getFirstFatalAlertDescription()
    {
        return firstFatalAlertDescription;
    }

    /**
     * Returns {@code config.clientOfferVersion} when set, otherwise the default client version.
     */
    public ProtocolVersion getClientVersion()
    {
        if (config.clientOfferVersion != null)
        {
            return config.clientOfferVersion;
        }
        return super.getClientVersion();
    }

    /**
     * Returns {@code config.clientMinimumVersion} when set, otherwise the default minimum version.
     */
    public ProtocolVersion getMinimumVersion()
    {
        if (config.clientMinimumVersion != null)
        {
            return config.clientMinimumVersion;
        }
        return super.getMinimumVersion();
    }

    /**
     * Returns the default client extensions.
     * <p>
     * When {@code config.clientSendSignatureAlgorithms} is false, this removes the
     * signature_algorithms extension and clears {@code supportedSignatureAlgorithms}.
     *
     * @return the client extensions table (may be null if the superclass returns null)
     * @throws IOException if the superclass fails to build the extensions
     */
    public Hashtable getClientExtensions() throws IOException
    {
        Hashtable clientExtensions = super.getClientExtensions();
        if (clientExtensions != null && !config.clientSendSignatureAlgorithms)
        {
            clientExtensions.remove(TlsUtils.EXT_signature_algorithms);
            // NOTE(review): presumably cleared so later handshake steps behave as if the
            // extension was never offered - confirm against AbstractTlsClient.
            this.supportedSignatureAlgorithms = null;
        }
        return clientExtensions;
    }

    /** Returns {@code config.clientFallback}. */
    public boolean isFallback()
    {
        return config.clientFallback;
    }

    /**
     * Records the first fatal alert raised by this client and, in DEBUG mode, logs the alert.
     * Fatal alerts go to stderr, others to stdout.
     */
    public void notifyAlertRaised(short alertLevel, short alertDescription, String message, Throwable cause)
    {
        if (alertLevel == AlertLevel.fatal && firstFatalAlertConnectionEnd == -1)
        {
            firstFatalAlertConnectionEnd = ConnectionEnd.client;
            firstFatalAlertDescription = alertDescription;
        }

        if (TlsTestConfig.DEBUG)
        {
            PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out;
            out.println("TLS client raised alert: " + AlertLevel.getText(alertLevel)
                + ", " + AlertDescription.getText(alertDescription));
            if (message != null)
            {
                out.println("> " + message);
            }
            if (cause != null)
            {
                cause.printStackTrace(out);
            }
        }
    }

    /**
     * Records the first fatal alert received from the server and, in DEBUG mode, logs the alert.
     */
    public void notifyAlertReceived(short alertLevel, short alertDescription)
    {
        if (alertLevel == AlertLevel.fatal && firstFatalAlertConnectionEnd == -1)
        {
            firstFatalAlertConnectionEnd = ConnectionEnd.server;
            firstFatalAlertDescription = alertDescription;
        }

        if (TlsTestConfig.DEBUG)
        {
            PrintStream out = (alertLevel == AlertLevel.fatal) ? System.err : System.out;
            out.println("TLS client received alert: " + AlertLevel.getText(alertLevel)
                + ", " + AlertDescription.getText(alertDescription));
        }
    }

    /**
     * Delegates to the superclass and, in DEBUG mode, prints the negotiated version.
     */
    public void notifyServerVersion(ProtocolVersion serverVersion) throws IOException
    {
        super.notifyServerVersion(serverVersion);

        if (TlsTestConfig.DEBUG)
        {
            System.out.println("TLS client negotiated " + serverVersion);
        }
    }

    /**
     * Builds the authentication callbacks.
     * <p>
     * The server certificate must match one of the known test server certificates. Client
     * credentials are chosen (and optionally corrupted) according to {@code config.clientAuth}.
     */
    public TlsAuthentication getAuthentication()
        throws IOException
    {
        return new TlsAuthentication()
        {
            public void notifyServerCertificate(org.bouncycastle.crypto.tls.Certificate serverCertificate)
                throws IOException
            {
                // Reject a missing or empty certificate before dereferencing it; otherwise a
                // null certificate would surface as a NullPointerException, not an alert.
                if (serverCertificate == null || serverCertificate.isEmpty())
                {
                    throw new TlsFatalAlert(AlertDescription.bad_certificate);
                }

                Certificate[] chain = serverCertificate.getCertificateList();

                // TODO Cache test resources?
                if (!isTrustedServerCertificate(chain[0]))
                {
                    throw new TlsFatalAlert(AlertDescription.bad_certificate);
                }

                if (TlsTestConfig.DEBUG)
                {
                    System.out.println("TLS client received server certificate chain of length " + chain.length);
                    for (int i = 0; i != chain.length; i++)
                    {
                        Certificate entry = chain[i];
                        // TODO Create fingerprint based on certificate signature algorithm digest
                        System.out.println("    fingerprint:SHA-256 " + TlsTestUtils.fingerprint(entry) + " ("
                            + entry.getSubject() + ")");
                    }
                }
            }

            public TlsCredentials getClientCredentials(CertificateRequest certificateRequest)
                throws IOException
            {
                // The server should never request a certificate when the test disabled requests.
                if (config.serverCertReq == TlsTestConfig.SERVER_CERT_REQ_NONE)
                {
                    throw new IllegalStateException();
                }
                if (config.clientAuth == TlsTestConfig.CLIENT_AUTH_NONE)
                {
                    return null;
                }

                // Only RSA client credentials are available, so decline if rsa_sign is not accepted.
                short[] certificateTypes = certificateRequest.getCertificateTypes();
                if (certificateTypes == null || !Arrays.contains(certificateTypes, ClientCertificateType.rsa_sign))
                {
                    return null;
                }

                // When the server sent a signature-algorithm list and the test pins one,
                // restrict the choice to that single algorithm.
                Vector supportedSigAlgs = certificateRequest.getSupportedSignatureAlgorithms();
                if (supportedSigAlgs != null && config.clientAuthSigAlg != null)
                {
                    supportedSigAlgs = new Vector(1);
                    supportedSigAlgs.addElement(config.clientAuthSigAlg);
                }

                final TlsSignerCredentials signerCredentials = TlsTestUtils.loadSignerCredentials(context,
                    supportedSigAlgs, SignatureAlgorithm.rsa, "x509-client.pem", "x509-client-key.pem");

                if (config.clientAuth == TlsTestConfig.CLIENT_AUTH_VALID)
                {
                    return signerCredentials;
                }

                // Wrap the real credentials so the signature or the certificate can be corrupted.
                return new TlsSignerCredentials()
                {
                    public byte[] generateCertificateSignature(byte[] hash) throws IOException
                    {
                        byte[] sig = signerCredentials.generateCertificateSignature(hash);

                        if (config.clientAuth == TlsTestConfig.CLIENT_AUTH_INVALID_VERIFY)
                        {
                            sig = corruptBit(sig);
                        }

                        return sig;
                    }

                    public org.bouncycastle.crypto.tls.Certificate getCertificate()
                    {
                        org.bouncycastle.crypto.tls.Certificate cert = signerCredentials.getCertificate();

                        if (config.clientAuth == TlsTestConfig.CLIENT_AUTH_INVALID_CERT)
                        {
                            cert = corruptCertificate(cert);
                        }

                        return cert;
                    }

                    public SignatureAndHashAlgorithm getSignatureAndHashAlgorithm()
                    {
                        return signerCredentials.getSignatureAndHashAlgorithm();
                    }
                };
            }
        };
    }

    /**
     * Returns true if {@code cert} equals one of the known test server certificates.
     * Resources are checked in order and loading stops at the first match.
     *
     * @param cert the leaf certificate presented by the server
     * @throws IOException if a certificate resource cannot be loaded
     */
    private boolean isTrustedServerCertificate(Certificate cert) throws IOException
    {
        for (int i = 0; i < TRUSTED_SERVER_CERT_RESOURCES.length; ++i)
        {
            if (cert.equals(TlsTestUtils.loadCertificateResource(TRUSTED_SERVER_CERT_RESOURCES[i])))
            {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a copy of {@code cert} whose first certificate has a corrupted signature.
     * The list is copied before modification so the caller's array is never touched.
     */
    protected org.bouncycastle.crypto.tls.Certificate corruptCertificate(org.bouncycastle.crypto.tls.Certificate cert)
    {
        Certificate[] certList = (Certificate[])cert.getCertificateList().clone();
        certList[0] = corruptCertificateSignature(certList[0]);
        return new org.bouncycastle.crypto.tls.Certificate(certList);
    }

    /**
     * Rebuilds {@code cert} with the same TBSCertificate and signature algorithm but a signature
     * that has one bit flipped.
     */
    protected Certificate corruptCertificateSignature(Certificate cert)
    {
        ASN1EncodableVector v = new ASN1EncodableVector();
        v.add(cert.getTBSCertificate());
        v.add(cert.getSignatureAlgorithm());
        v.add(corruptSignature(cert.getSignature()));

        return Certificate.getInstance(new DERSequence(v));
    }

    /** Returns a new bit string equal to {@code bs}'s octets with one random bit flipped. */
    protected DERBitString corruptSignature(DERBitString bs)
    {
        return new DERBitString(corruptBit(bs.getOctets()));
    }

    /**
     * Returns a copy of {@code bs} with exactly one randomly chosen bit flipped.
     * The input array is not modified.
     *
     * @throws IllegalArgumentException if {@code bs} is empty (there is no bit to flip)
     */
    protected byte[] corruptBit(byte[] bs)
    {
        if (bs.length == 0)
        {
            throw new IllegalArgumentException("cannot corrupt an empty byte array");
        }

        bs = Arrays.clone(bs);

        // Flip a random bit
        int bit = context.getSecureRandom().nextInt(bs.length << 3);
        bs[bit >>> 3] ^= (1 << (bit & 7));

        return bs;
    }
}