blob: d0ed3e9436c05f61a07367df8dec137cd6701061 [file] [log] [blame]
package org.bouncycastle.pqc.crypto.xmss;
import java.io.IOException;
import java.security.SecureRandom;
import java.text.ParseException;
/**
* XMSS.
*
*/
/**
 * XMSS (eXtended Merkle Signature Scheme) key generation, signing and
 * verification.
 * <p>
 * An instance holds a mutable private key (index, secret seeds, public seed,
 * root and BDS authentication-path state) together with the matching public
 * key. {@link #sign(byte[])} uses the current index and then advances it by
 * one. The class performs no synchronization, and
 * {@link #verifySignature(byte[], byte[], byte[])} temporarily changes the
 * instance's index and public seed while it runs.
 * </p>
 */
public class XMSS {
    /**
     * XMSS parameters.
     */
    private final XMSSParameters params;
    /**
     * WOTS+ instance (taken from the parameters).
     */
    private final WOTSPlus wotsPlus;
    /**
     * PRNG used to generate secret seeds and the public seed.
     */
    private final SecureRandom prng;
    /**
     * Keyed hash functions (PRF, H, H_msg) taken from the WOTS+ instance.
     */
    private final KeyedHashFunctions khf;
    /**
     * XMSS private key.
     */
    private XMSSPrivateKeyParameters privateKey;
    /**
     * XMSS public key.
     */
    private XMSSPublicKeyParameters publicKey;

    /**
     * Creates an XMSS instance whose keys are built from builder defaults.
     * Call {@link #generateKeys()} or {@link #importState(byte[], byte[])}
     * before signing.
     *
     * @param params
     *            XMSSParameters.
     * @throws NullPointerException
     *             if {@code params} is null.
     * @throws IllegalStateException
     *             if the initial key parameters cannot be built (not expected).
     */
    public XMSS(XMSSParameters params) {
        if (params == null) {
            throw new NullPointerException("params == null");
        }
        this.params = params;
        wotsPlus = params.getWOTSPlus();
        prng = params.getPRNG();
        khf = wotsPlus.getKhf();
        try {
            /* BDS receives this partially constructed instance; all final fields are already set */
            privateKey = new XMSSPrivateKeyParameters.Builder(params).withBDSState(new BDS(this)).build();
            publicKey = new XMSSPublicKeyParameters.Builder(params).build();
        } catch (ParseException e) {
            throw unexpected(e);
        } catch (ClassNotFoundException e) {
            throw unexpected(e);
        } catch (IOException e) {
            throw unexpected(e);
        }
    }

    /**
     * Wraps a checked exception that the parameter builders declare but that
     * should not occur for the internally constructed inputs used here.
     * Rethrowing (rather than logging and continuing) ensures the instance
     * never silently keeps stale state, e.g. an index that was not advanced.
     *
     * @param cause
     *            the unexpected exception.
     * @return an unchecked exception carrying {@code cause}.
     */
    private static IllegalStateException unexpected(Exception cause) {
        return new IllegalStateException("unexpected failure while building XMSS parameters", cause);
    }

    /**
     * Generates a new XMSS private key / public key pair: fresh random seeds,
     * then the tree root computed via the BDS state.
     *
     * @throws IllegalStateException
     *             if the key parameters cannot be built (not expected).
     */
    public void generateKeys() {
        /* generate private key (seeds) */
        privateKey = generatePrivateKey();
        /* compute the root of the tree from the BDS state */
        XMSSNode root = getBDSState().initialize((OTSHashAddress) new OTSHashAddress.Builder().build());
        rebuildKeys(privateKey.getIndex(), privateKey.getPublicSeed(), root.getValue());
    }

    /**
     * Generates an XMSS private key with random secret key seed, secret PRF
     * key and public seed, each {@code params.getDigestSize()} bytes long. The
     * current BDS state is carried over.
     *
     * @return XMSS private key.
     */
    private XMSSPrivateKeyParameters generatePrivateKey() {
        int n = params.getDigestSize();
        byte[] secretKeySeed = new byte[n];
        prng.nextBytes(secretKeySeed);
        byte[] secretKeyPRF = new byte[n];
        prng.nextBytes(secretKeyPRF);
        byte[] publicSeed = new byte[n];
        prng.nextBytes(publicSeed);
        try {
            return new XMSSPrivateKeyParameters.Builder(params).withSecretKeySeed(secretKeySeed)
                    .withSecretKeyPRF(secretKeyPRF).withPublicSeed(publicSeed)
                    .withBDSState(privateKey.getBDSState()).build();
        } catch (ParseException e) {
            throw unexpected(e);
        } catch (ClassNotFoundException e) {
            throw unexpected(e);
        } catch (IOException e) {
            throw unexpected(e);
        }
    }

    /**
     * Rebuilds both keys: the private key keeps its secret seeds and BDS state
     * but takes the given index, public seed and root; the public key is built
     * from the given root and public seed.
     */
    private void rebuildKeys(int index, byte[] publicSeed, byte[] root) {
        try {
            privateKey = new XMSSPrivateKeyParameters.Builder(params).withIndex(index)
                    .withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF())
                    .withPublicSeed(publicSeed).withRoot(root).withBDSState(privateKey.getBDSState()).build();
            publicKey = new XMSSPublicKeyParameters.Builder(params).withRoot(root).withPublicSeed(publicSeed)
                    .build();
        } catch (ParseException e) {
            throw unexpected(e);
        } catch (ClassNotFoundException e) {
            throw unexpected(e);
        } catch (IOException e) {
            throw unexpected(e);
        }
    }

    /**
     * Imports an XMSS private key / public key pair. Both keys are parsed
     * first; the instance is only modified when their roots and public seeds
     * match.
     *
     * @param privateKey
     *            XMSS private key.
     * @param publicKey
     *            XMSS public key.
     * @throws ParseException
     *             if a key cannot be parsed.
     * @throws ClassNotFoundException
     *             propagated from parsing the private key.
     * @throws IOException
     *             propagated from parsing the private key.
     * @throws IllegalStateException
     *             if root or public seed of the two keys differ.
     */
    public void importState(byte[] privateKey, byte[] publicKey)
            throws ParseException, ClassNotFoundException, IOException {
        if (privateKey == null) {
            throw new NullPointerException("privateKey == null");
        }
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        /* parse keys without touching the current state */
        XMSSPrivateKeyParameters tmpPrivateKey = new XMSSPrivateKeyParameters.Builder(params)
                .withPrivateKey(privateKey, this).build();
        XMSSPublicKeyParameters tmpPublicKey = new XMSSPublicKeyParameters.Builder(params).withPublicKey(publicKey)
                .build();
        if (!XMSSUtil.compareByteArray(tmpPrivateKey.getRoot(), tmpPublicKey.getRoot())) {
            throw new IllegalStateException("root of private key and public key do not match");
        }
        if (!XMSSUtil.compareByteArray(tmpPrivateKey.getPublicSeed(), tmpPublicKey.getPublicSeed())) {
            throw new IllegalStateException("public seed of private key and public key do not match");
        }
        /* import */
        this.privateKey = tmpPrivateKey;
        this.publicKey = tmpPublicKey;
        wotsPlus.importKeys(new byte[params.getDigestSize()], this.privateKey.getPublicSeed());
    }

    /**
     * Signs a message with the current index, then advances the BDS
     * authentication path and the index.
     *
     * @param message
     *            Message to sign.
     * @return XMSS signature on digest of message.
     * @throws IllegalStateException
     *             if no keys have been generated/imported, or if the signature
     *             cannot be built (not expected).
     * @throws IllegalArgumentException
     *             if the current index is outside the tree.
     */
    public byte[] sign(byte[] message) {
        if (message == null) {
            throw new NullPointerException("message == null");
        }
        if (getBDSState().getAuthenticationPath().isEmpty()) {
            throw new IllegalStateException("not initialized");
        }
        int index = privateKey.getIndex();
        if (!XMSSUtil.isIndexValid(getParams().getHeight(), index)) {
            throw new IllegalArgumentException("index out of bounds");
        }
        /* create (randomized keyed) messageDigest of message */
        byte[] random = khf.PRF(privateKey.getSecretKeyPRF(), XMSSUtil.toBytesBigEndian(index, 32));
        byte[] concatenated = XMSSUtil.concat(random, privateKey.getRoot(),
                XMSSUtil.toBytesBigEndian(index, params.getDigestSize()));
        byte[] messageDigest = khf.HMsg(concatenated, message);
        /* create signature for messageDigest */
        OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withOTSAddress(index).build();
        WOTSPlusSignature wotsPlusSignature = wotsSign(messageDigest, otsHashAddress);
        XMSSSignature signature;
        try {
            signature = (XMSSSignature) new XMSSSignature.Builder(params).withIndex(index).withRandom(random)
                    .withWOTSPlusSignature(wotsPlusSignature).withAuthPath(getBDSState().getAuthenticationPath())
                    .build();
        } catch (ParseException ex) {
            throw unexpected(ex);
        }
        /* prepare authentication path for next leaf (none after the last leaf) */
        int treeHeight = params.getHeight();
        if (index < ((1 << treeHeight) - 1)) {
            getBDSState().nextAuthenticationPath((OTSHashAddress) new OTSHashAddress.Builder().build());
        }
        /* update index */
        setIndex(index + 1);
        return signature.toByteArray();
    }

    /**
     * Verifies an XMSS signature against the given public key. The instance's
     * index and public seed are temporarily replaced and are always restored,
     * even if verification throws.
     *
     * @param message
     *            Message.
     * @param signature
     *            XMSS signature.
     * @param publicKey
     *            XMSS public key.
     * @return true if signature is valid false else.
     * @throws ParseException
     *             if the signature or public key cannot be parsed.
     */
    public boolean verifySignature(byte[] message, byte[] signature, byte[] publicKey) throws ParseException {
        if (message == null) {
            throw new NullPointerException("message == null");
        }
        if (signature == null) {
            throw new NullPointerException("signature == null");
        }
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        /* parse signature and public key */
        XMSSSignature sig = new XMSSSignature.Builder(params).withSignature(signature).build();
        XMSSPublicKeyParameters pubKey = new XMSSPublicKeyParameters.Builder(params).withPublicKey(publicKey).build();
        /* save state */
        int savedIndex = privateKey.getIndex();
        byte[] savedPublicSeed = privateKey.getPublicSeed();
        int index = sig.getIndex();
        XMSSNode rootNodeFromSignature;
        try {
            /* set index / public seed; setPublicSeed also re-imports the WOTS+ public seed */
            setIndex(index);
            setPublicSeed(pubKey.getPublicSeed());
            /* create message digest */
            byte[] concatenated = XMSSUtil.concat(sig.getRandom(), pubKey.getRoot(),
                    XMSSUtil.toBytesBigEndian(index, params.getDigestSize()));
            byte[] messageDigest = khf.HMsg(concatenated, message);
            /* get root from signature */
            OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withOTSAddress(index)
                    .build();
            rootNodeFromSignature = getRootNodeFromSignature(messageDigest, sig, otsHashAddress);
        } finally {
            /* reset state */
            setIndex(savedIndex);
            setPublicSeed(savedPublicSeed);
        }
        return XMSSUtil.compareByteArray(rootNodeFromSignature.getValue(), pubKey.getRoot());
    }

    /**
     * Export XMSS private key.
     *
     * @return XMSS private key.
     */
    public byte[] exportPrivateKey() {
        return privateKey.toByteArray();
    }

    /**
     * Export XMSS public key.
     *
     * @return XMSS public key.
     */
    public byte[] exportPublicKey() {
        return publicKey.toByteArray();
    }

    /**
     * Randomized hash of two sibling nodes. A key and two bitmasks are derived
     * with the PRF from the public seed and the address (keyAndMask = 0, 1, 2);
     * the masked children are then hashed with the key.
     *
     * @param left
     *            Left node.
     * @param right
     *            Right node.
     * @param address
     *            Address (L-tree or hash-tree address; other address types are
     *            used unchanged for all three PRF calls).
     * @return Randomized hash of parent of left / right node, carrying the
     *         height of {@code left}.
     */
    protected XMSSNode randomizeHash(XMSSNode left, XMSSNode right, XMSSAddress address) {
        if (left == null) {
            throw new NullPointerException("left == null");
        }
        if (right == null) {
            throw new NullPointerException("right == null");
        }
        if (left.getHeight() != right.getHeight()) {
            throw new IllegalStateException("height of both nodes must be equal");
        }
        if (address == null) {
            throw new NullPointerException("address == null");
        }
        byte[] publicSeed = getPublicSeed();
        address = copyWithKeyAndMask(address, 0);
        byte[] key = khf.PRF(publicSeed, address.toByteArray());
        address = copyWithKeyAndMask(address, 1);
        byte[] bitmask0 = khf.PRF(publicSeed, address.toByteArray());
        address = copyWithKeyAndMask(address, 2);
        byte[] bitmask1 = khf.PRF(publicSeed, address.toByteArray());
        int n = params.getDigestSize();
        byte[] leftValue = left.getValue();
        byte[] rightValue = right.getValue();
        byte[] tmpMask = new byte[2 * n];
        for (int i = 0; i < n; i++) {
            tmpMask[i] = (byte) (leftValue[i] ^ bitmask0[i]);
            tmpMask[i + n] = (byte) (rightValue[i] ^ bitmask1[i]);
        }
        byte[] out = khf.H(key, tmpMask);
        return new XMSSNode(left.getHeight(), out);
    }

    /**
     * Returns a copy of an L-tree or hash-tree address with its keyAndMask
     * field replaced; any other address type is returned unchanged.
     */
    private static XMSSAddress copyWithKeyAndMask(XMSSAddress address, int keyAndMask) {
        if (address instanceof LTreeAddress) {
            LTreeAddress tmpAddress = (LTreeAddress) address;
            return (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(tmpAddress.getLayerAddress())
                    .withTreeAddress(tmpAddress.getTreeAddress()).withLTreeAddress(tmpAddress.getLTreeAddress())
                    .withTreeHeight(tmpAddress.getTreeHeight()).withTreeIndex(tmpAddress.getTreeIndex())
                    .withKeyAndMask(keyAndMask).build();
        }
        if (address instanceof HashTreeAddress) {
            HashTreeAddress tmpAddress = (HashTreeAddress) address;
            return (HashTreeAddress) new HashTreeAddress.Builder().withLayerAddress(tmpAddress.getLayerAddress())
                    .withTreeAddress(tmpAddress.getTreeAddress()).withTreeHeight(tmpAddress.getTreeHeight())
                    .withTreeIndex(tmpAddress.getTreeIndex()).withKeyAndMask(keyAndMask).build();
        }
        return address;
    }

    /**
     * Compresses a WOTS+ public key to a single n-byte string (L-tree). Pairs
     * of nodes are hashed level by level; an odd node out is lifted unchanged
     * to the next level.
     *
     * @param publicKey
     *            WOTS+ public key to compress.
     * @param address
     *            Address.
     * @return Compressed n-byte string of public key.
     */
    protected XMSSNode lTree(WOTSPlusPublicKeyParameters publicKey, LTreeAddress address) {
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        if (address == null) {
            throw new NullPointerException("address == null");
        }
        int len = wotsPlus.getParams().getLen();
        /* duplicate public key to XMSSNode Array */
        byte[][] publicKeyBytes = publicKey.toByteArray();
        XMSSNode[] publicKeyNodes = new XMSSNode[publicKeyBytes.length];
        for (int i = 0; i < publicKeyBytes.length; i++) {
            publicKeyNodes[i] = new XMSSNode(0, publicKeyBytes[i]);
        }
        address = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())
                .withTreeAddress(address.getTreeAddress()).withLTreeAddress(address.getLTreeAddress())
                .withTreeHeight(0).withTreeIndex(address.getTreeIndex()).withKeyAndMask(address.getKeyAndMask())
                .build();
        while (len > 1) {
            int pairs = len / 2;
            for (int i = 0; i < pairs; i++) {
                address = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())
                        .withTreeAddress(address.getTreeAddress()).withLTreeAddress(address.getLTreeAddress())
                        .withTreeHeight(address.getTreeHeight()).withTreeIndex(i)
                        .withKeyAndMask(address.getKeyAndMask()).build();
                publicKeyNodes[i] = randomizeHash(publicKeyNodes[2 * i], publicKeyNodes[(2 * i) + 1], address);
            }
            if (len % 2 == 1) {
                /* odd node out moves up unchanged */
                publicKeyNodes[pairs] = publicKeyNodes[len - 1];
            }
            len = (len + 1) / 2;
            address = (LTreeAddress) new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())
                    .withTreeAddress(address.getTreeAddress()).withLTreeAddress(address.getLTreeAddress())
                    .withTreeHeight(address.getTreeHeight() + 1).withTreeIndex(address.getTreeIndex())
                    .withKeyAndMask(address.getKeyAndMask()).build();
        }
        return publicKeyNodes[0];
    }

    /**
     * Generates a WOTS+ signature on a message digest without the
     * corresponding authentication path. The WOTS+ instance is re-imported
     * with the secret key derived for {@code otsHashAddress}.
     *
     * @param messageDigest
     *            Message digest of length n.
     * @param otsHashAddress
     *            OTS hash address.
     * @return WOTS+ signature.
     */
    protected WOTSPlusSignature wotsSign(byte[] messageDigest, OTSHashAddress otsHashAddress) {
        if (messageDigest == null) {
            throw new NullPointerException("messageDigest == null");
        }
        if (messageDigest.length != params.getDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        /* (re)initialize WOTS+ instance */
        wotsPlus.importKeys(getWOTSPlusSecretKey(otsHashAddress), getPublicSeed());
        /* create WOTS+ signature */
        return wotsPlus.sign(messageDigest, otsHashAddress);
    }

    /**
     * Computes a tree root from a (reduced) signature: recovers the WOTS+
     * public key, compresses it to the leaf and then climbs the tree using the
     * authentication path. The leaf index is taken from the OTS address of
     * {@code otsHashAddress}, so the result does not depend on this
     * instance's current index.
     *
     * @param messageDigest
     *            Message digest.
     * @param signature
     *            XMSS signature.
     * @param otsHashAddress
     *            OTS hash address of the signing leaf.
     * @return Root node calculated from signature.
     */
    protected XMSSNode getRootNodeFromSignature(byte[] messageDigest, XMSSReducedSignature signature,
            OTSHashAddress otsHashAddress) {
        if (messageDigest == null) {
            throw new NullPointerException("messageDigest == null");
        }
        if (messageDigest.length != params.getDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (signature == null) {
            throw new NullPointerException("signature == null");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        int leafIndex = otsHashAddress.getOTSAddress();
        /* prepare addresses */
        LTreeAddress lTreeAddress = (LTreeAddress) new LTreeAddress.Builder()
                .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
                .withLTreeAddress(leafIndex).build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
                .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
                .withTreeIndex(leafIndex).build();
        /* calculate WOTS+ public key and compress to obtain original leaf hash */
        WOTSPlusPublicKeyParameters wotsPlusPK = wotsPlus.getPublicKeyFromSignature(messageDigest,
                signature.getWOTSPlusSignature(), otsHashAddress);
        XMSSNode node = lTree(wotsPlusPK, lTreeAddress);
        for (int k = 0; k < params.getHeight(); k++) {
            XMSSNode authNode = signature.getAuthPath().get(k);
            /* bit k of the leaf index tells whether the current node is a left child */
            boolean isLeftChild = ((leafIndex >> k) & 1) == 0;
            hashTreeAddress = (HashTreeAddress) new HashTreeAddress.Builder()
                    .withLayerAddress(hashTreeAddress.getLayerAddress())
                    .withTreeAddress(hashTreeAddress.getTreeAddress()).withTreeHeight(k)
                    .withTreeIndex(isLeftChild ? hashTreeAddress.getTreeIndex() / 2
                            : (hashTreeAddress.getTreeIndex() - 1) / 2)
                    .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build();
            XMSSNode parent = isLeftChild ? randomizeHash(node, authNode, hashTreeAddress)
                    : randomizeHash(authNode, node, hashTreeAddress);
            node = new XMSSNode(parent.getHeight() + 1, parent.getValue());
        }
        return node;
    }

    /**
     * Derives the WOTS+ secret key for a specific index as in the XMSS
     * reference implementation by Andreas Huelsing: PRF(secret key seed,
     * address) where the address keeps only layer, tree and OTS address.
     *
     * @param otsHashAddress
     *            OTS hash address of the leaf.
     * @return WOTS+ secret key at index.
     */
    protected byte[] getWOTSPlusSecretKey(OTSHashAddress otsHashAddress) {
        OTSHashAddress keyAddress = (OTSHashAddress) new OTSHashAddress.Builder()
                .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress())
                .withOTSAddress(otsHashAddress.getOTSAddress()).build();
        return khf.PRF(privateKey.getSecretKeySeed(), keyAddress.toByteArray());
    }

    /**
     * Getter XMSS params.
     *
     * @return XMSS params.
     */
    public XMSSParameters getParams() {
        return params;
    }

    /**
     * Getter WOTS+.
     *
     * @return WOTS+ instance.
     */
    protected WOTSPlus getWOTSPlus() {
        return wotsPlus;
    }

    /**
     * Getter keyed hash functions.
     *
     * @return keyed hash functions of the WOTS+ instance.
     */
    protected KeyedHashFunctions getKhf() {
        return khf;
    }

    /**
     * Getter XMSS root.
     *
     * @return Root of binary tree.
     */
    public byte[] getRoot() {
        return privateKey.getRoot();
    }

    /**
     * Replaces the root in both private and public key.
     *
     * @param root
     *            new root.
     */
    protected void setRoot(byte[] root) {
        rebuildKeys(privateKey.getIndex(), getPublicSeed(), root);
    }

    /**
     * Getter XMSS index.
     *
     * @return Index.
     */
    public int getIndex() {
        return privateKey.getIndex();
    }

    /**
     * Replaces the index of the private key (the public key is untouched).
     * Failures are rethrown so a stale index is never kept silently.
     *
     * @param index
     *            new index.
     */
    protected void setIndex(int index) {
        try {
            privateKey = new XMSSPrivateKeyParameters.Builder(params).withIndex(index)
                    .withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF())
                    .withPublicSeed(privateKey.getPublicSeed()).withRoot(privateKey.getRoot())
                    .withBDSState(privateKey.getBDSState()).build();
        } catch (ParseException ex) {
            throw unexpected(ex);
        } catch (ClassNotFoundException e) {
            throw unexpected(e);
        } catch (IOException e) {
            throw unexpected(e);
        }
    }

    /**
     * Getter XMSS public seed.
     *
     * @return Public seed.
     */
    public byte[] getPublicSeed() {
        return privateKey.getPublicSeed();
    }

    /**
     * Replaces the public seed in both keys and re-imports it into the WOTS+
     * instance (with an all-zero secret key).
     *
     * @param publicSeed
     *            new public seed.
     */
    protected void setPublicSeed(byte[] publicSeed) {
        rebuildKeys(privateKey.getIndex(), publicSeed, getRoot());
        wotsPlus.importKeys(new byte[params.getDigestSize()], publicSeed);
    }

    /**
     * Getter BDS state of the private key.
     *
     * @return BDS state.
     */
    protected BDS getBDSState() {
        return privateKey.getBDSState();
    }
}