// blob: c3839be3ee3a34afbc1b86ae6bcebca0eb02ced5 [file] [log] [blame]
package org.bouncycastle.pqc.crypto.xmss;
import java.io.IOException;
import java.security.SecureRandom;
import java.text.ParseException;
import java.util.Map;
import java.util.TreeMap;
/**
* XMSS^MT.
*
*/
public final class XMSSMT {
private XMSSMTParameters params;
private XMSS xmss;
private SecureRandom prng;
private KeyedHashFunctions khf;
private XMSSMTPrivateKeyParameters privateKey;
private XMSSMTPublicKeyParameters publicKey;
/**
* XMSSMT constructor...
*
* @param params
* XMSSMTParameters.
*/
public XMSSMT(XMSSMTParameters params) {
super();
if (params == null) {
throw new NullPointerException("params == null");
}
this.params = params;
xmss = params.getXMSS();
prng = params.getXMSS().getParams().getPRNG();
khf = xmss.getKhf();
try {
privateKey = new XMSSMTPrivateKeyParameters.Builder(params).build();
publicKey = new XMSSMTPublicKeyParameters.Builder(params).build();
} catch (ParseException e) {
/* should not be possible */
e.printStackTrace();
} catch (ClassNotFoundException e) {
/* should not be possible */
e.printStackTrace();
} catch (IOException e) {
/* should not be possible */
e.printStackTrace();
}
}
/**
* Generate a new XMSSMT private key / public key pair.
*
*/
public void generateKeys() {
/* generate XMSSMT private key */
privateKey = generatePrivateKey();
/* init global xmss */
XMSSPrivateKeyParameters xmssPrivateKey = null;
XMSSPublicKeyParameters xmssPublicKey = null;
try {
xmssPrivateKey = new XMSSPrivateKeyParameters.Builder(xmss.getParams())
.withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF())
.withPublicSeed(privateKey.getPublicSeed()).withBDSState(new BDS(xmss)).build();
xmssPublicKey = new XMSSPublicKeyParameters.Builder(xmss.getParams()).withPublicSeed(getPublicSeed())
.build();
} catch (ParseException ex) {
/* should not be possible */
ex.printStackTrace();
} catch (ClassNotFoundException e) {
/* should not be possible */
e.printStackTrace();
} catch (IOException e) {
/* should not be possible */
e.printStackTrace();
}
/* import to xmss */
try {
xmss.importState(xmssPrivateKey.toByteArray(), xmssPublicKey.toByteArray());
} catch (ParseException e) {
e.printStackTrace();
} catch (ClassNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
/* get root */
int rootLayerIndex = params.getLayers() - 1;
OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(rootLayerIndex)
.build();
/* store BDS instance of root xmss instance */
BDS bdsRoot = new BDS(xmss);
XMSSNode root = bdsRoot.initialize(otsHashAddress);
getBDSState().put(rootLayerIndex, bdsRoot);
xmss.setRoot(root.getValue());
/* set XMSS^MT root / create public key */
try {
privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withSecretKeySeed(privateKey.getSecretKeySeed())
.withSecretKeyPRF(privateKey.getSecretKeyPRF()).withPublicSeed(privateKey.getPublicSeed())
.withRoot(xmss.getRoot()).withBDSState(privateKey.getBDSState()).build();
publicKey = new XMSSMTPublicKeyParameters.Builder(params).withRoot(root.getValue())
.withPublicSeed(getPublicSeed()).build();
} catch (ParseException e) {
/* should not be possible */
e.printStackTrace();
} catch (ClassNotFoundException e) {
/* should not be possible */
e.printStackTrace();
} catch (IOException e) {
/* should not be possible */
e.printStackTrace();
}
}
private XMSSMTPrivateKeyParameters generatePrivateKey() {
int n = params.getDigestSize();
byte[] secretKeySeed = new byte[n];
prng.nextBytes(secretKeySeed);
byte[] secretKeyPRF = new byte[n];
prng.nextBytes(secretKeyPRF);
byte[] publicSeed = new byte[n];
prng.nextBytes(publicSeed);
XMSSMTPrivateKeyParameters privateKey = null;
try {
privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withSecretKeySeed(secretKeySeed)
.withSecretKeyPRF(secretKeyPRF).withPublicSeed(publicSeed)
.withBDSState(this.privateKey.getBDSState()).build();
} catch (ParseException ex) {
/* should not be possible */
ex.printStackTrace();
} catch (ClassNotFoundException e) {
/* should not be possible */
e.printStackTrace();
} catch (IOException e) {
/* should not be possible */
e.printStackTrace();
}
return privateKey;
}
/**
* Import XMSSMT private key / public key pair.
*
* @param privateKey
* XMSSMT private key.
* @param publicKey
* XMSSMT public key.
* @throws ParseException
* @throws ClassNotFoundException
* @throws IOException
*/
public void importState(byte[] privateKey, byte[] publicKey)
throws ParseException, ClassNotFoundException, IOException {
if (privateKey == null) {
throw new NullPointerException("privateKey == null");
}
if (publicKey == null) {
throw new NullPointerException("publicKey == null");
}
XMSSMTPrivateKeyParameters xmssMTPrivateKey = new XMSSMTPrivateKeyParameters.Builder(params)
.withPrivateKey(privateKey, xmss).build();
XMSSMTPublicKeyParameters xmssMTPublicKey = new XMSSMTPublicKeyParameters.Builder(params)
.withPublicKey(publicKey).build();
if (!XMSSUtil.compareByteArray(xmssMTPrivateKey.getRoot(), xmssMTPublicKey.getRoot())) {
throw new IllegalStateException("root of private key and public key do not match");
}
if (!XMSSUtil.compareByteArray(xmssMTPrivateKey.getPublicSeed(), xmssMTPublicKey.getPublicSeed())) {
throw new IllegalStateException("public seed of private key and public key do not match");
}
/* init global xmss */
XMSSPrivateKeyParameters xmssPrivateKey = new XMSSPrivateKeyParameters.Builder(xmss.getParams())
.withSecretKeySeed(xmssMTPrivateKey.getSecretKeySeed())
.withSecretKeyPRF(xmssMTPrivateKey.getSecretKeyPRF()).withPublicSeed(xmssMTPrivateKey.getPublicSeed())
.withRoot(xmssMTPrivateKey.getRoot()).withBDSState(new BDS(xmss)).build();
XMSSPublicKeyParameters xmssPublicKey = new XMSSPublicKeyParameters.Builder(xmss.getParams())
.withRoot(xmssMTPrivateKey.getRoot()).withPublicSeed(getPublicSeed()).build();
/* import to xmss */
xmss.importState(xmssPrivateKey.toByteArray(), xmssPublicKey.toByteArray());
this.privateKey = xmssMTPrivateKey;
this.publicKey = xmssMTPublicKey;
}
/**
* Sign message.
*
* @param message
* Message to sign.
* @return XMSSMT signature on digest of message.
*/
public byte[] sign(byte[] message) {
if (message == null) {
throw new NullPointerException("message == null");
}
if (getBDSState().isEmpty()) {
throw new IllegalStateException("not initialized");
}
// privateKey.increaseIndex(this);
long globalIndex = getIndex();
int totalHeight = params.getHeight();
int xmssHeight = xmss.getParams().getHeight();
if (!XMSSUtil.isIndexValid(totalHeight, globalIndex)) {
throw new IllegalArgumentException("index out of bounds");
}
/* compress message */
byte[] random = khf.PRF(privateKey.getSecretKeyPRF(), XMSSUtil.toBytesBigEndian(globalIndex, 32));
byte[] concatenated = XMSSUtil.concat(random, privateKey.getRoot(),
XMSSUtil.toBytesBigEndian(globalIndex, params.getDigestSize()));
byte[] messageDigest = khf.HMsg(concatenated, message);
XMSSMTSignature signature = null;
try {
signature = new XMSSMTSignature.Builder(params).withIndex(globalIndex).withRandom(random).build();
} catch (ParseException ex) {
/* should not be possible */
ex.printStackTrace();
}
/* layer 0 */
long indexTree = XMSSUtil.getTreeIndex(globalIndex, xmssHeight);
int indexLeaf = XMSSUtil.getLeafIndex(globalIndex, xmssHeight);
/* reset xmss */
xmss.setIndex(indexLeaf);
xmss.setPublicSeed(getPublicSeed());
/* create signature with XMSS tree on layer 0 */
/* adjust addresses */
OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withTreeAddress(indexTree)
.withOTSAddress(indexLeaf).build();
/* sign message digest */
WOTSPlusSignature wotsPlusSignature = xmss.wotsSign(messageDigest, otsHashAddress);
/* get authentication path from BDS */
if (getBDSState().get(0) == null || indexLeaf == 0) {
getBDSState().put(0, new BDS(xmss));
getBDSState().get(0).initialize(otsHashAddress);
}
XMSSReducedSignature reducedSignature = null;
try {
reducedSignature = new XMSSReducedSignature.Builder(xmss.getParams())
.withWOTSPlusSignature(wotsPlusSignature).withAuthPath(getBDSState().get(0).getAuthenticationPath())
.build();
} catch (ParseException ex) {
/* should never happen */
ex.printStackTrace();
}
signature.getReducedSignatures().add(reducedSignature);
/* prepare authentication path for next leaf */
if (indexLeaf < ((1 << xmssHeight) - 1)) {
getBDSState().get(0).nextAuthenticationPath(otsHashAddress);
}
/* loop over remaining layers */
for (int layer = 1; layer < params.getLayers(); layer++) {
/* get root of layer - 1 */
XMSSNode root = getBDSState().get(layer - 1).getRoot();
indexLeaf = XMSSUtil.getLeafIndex(indexTree, xmssHeight);
indexTree = XMSSUtil.getTreeIndex(indexTree, xmssHeight);
xmss.setIndex(indexLeaf);
/* adjust addresses */
otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(layer)
.withTreeAddress(indexTree).withOTSAddress(indexLeaf).build();
/* sign root digest of layer - 1 */
wotsPlusSignature = xmss.wotsSign(root.getValue(), otsHashAddress);
/* get authentication path from BDS */
if (getBDSState().get(layer) == null || XMSSUtil.isNewBDSInitNeeded(globalIndex, xmssHeight, layer)) {
getBDSState().put(layer, new BDS(xmss));
getBDSState().get(layer).initialize(otsHashAddress);
}
try {
reducedSignature = new XMSSReducedSignature.Builder(xmss.getParams())
.withWOTSPlusSignature(wotsPlusSignature)
.withAuthPath(getBDSState().get(layer).getAuthenticationPath()).build();
} catch (ParseException ex) {
/* should never happen */
ex.printStackTrace();
}
signature.getReducedSignatures().add(reducedSignature);
/* prepare authentication path for next leaf */
if (indexLeaf < ((1 << xmssHeight) - 1)
&& XMSSUtil.isNewAuthenticationPathNeeded(globalIndex, xmssHeight, layer)) {
getBDSState().get(layer).nextAuthenticationPath(otsHashAddress);
}
}
/* update private key */
try {
privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withIndex(globalIndex + 1)
.withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF())
.withPublicSeed(privateKey.getPublicSeed()).withRoot(privateKey.getRoot())
.withBDSState(privateKey.getBDSState()).build();
} catch (ParseException e) {
/* should not be possible */
e.printStackTrace();
} catch (ClassNotFoundException e) {
/* should not be possible */
e.printStackTrace();
} catch (IOException e) {
/* should not be possible */
e.printStackTrace();
}
return signature.toByteArray();
}
/**
* Verify an XMSSMT signature.
*
* @param message
* Message.
* @param signature
* XMSSMT signature.
* @param publicKey
* XMSSMT public key.
* @return true if signature is valid false else.
* @throws ParseException
*/
public boolean verifySignature(byte[] message, byte[] signature, byte[] publicKey) throws ParseException {
if (message == null) {
throw new NullPointerException("message == null");
}
if (signature == null) {
throw new NullPointerException("signature == null");
}
if (publicKey == null) {
throw new NullPointerException("publicKey == null");
}
/* (re)create compressed message */
XMSSMTSignature sig = new XMSSMTSignature.Builder(params).withSignature(signature).build();
XMSSMTPublicKeyParameters pubKey = new XMSSMTPublicKeyParameters.Builder(params).withPublicKey(publicKey)
.build();
byte[] concatenated = XMSSUtil.concat(sig.getRandom(), pubKey.getRoot(),
XMSSUtil.toBytesBigEndian(sig.getIndex(), params.getDigestSize()));
byte[] messageDigest = khf.HMsg(concatenated, message);
long globalIndex = sig.getIndex();
int xmssHeight = xmss.getParams().getHeight();
long indexTree = XMSSUtil.getTreeIndex(globalIndex, xmssHeight);
int indexLeaf = XMSSUtil.getLeafIndex(globalIndex, xmssHeight);
/* adjust xmss */
xmss.setIndex(indexLeaf);
xmss.setPublicSeed(pubKey.getPublicSeed());
/* prepare addresses */
OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withTreeAddress(indexTree)
.withOTSAddress(indexLeaf).build();
/* get root node on layer 0 */
XMSSReducedSignature xmssMTSignature = sig.getReducedSignatures().get(0);
XMSSNode rootNode = xmss.getRootNodeFromSignature(messageDigest, xmssMTSignature, otsHashAddress);
for (int layer = 1; layer < params.getLayers(); layer++) {
xmssMTSignature = sig.getReducedSignatures().get(layer);
indexLeaf = XMSSUtil.getLeafIndex(indexTree, xmssHeight);
indexTree = XMSSUtil.getTreeIndex(indexTree, xmssHeight);
xmss.setIndex(indexLeaf);
/* adjust address */
otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(layer)
.withTreeAddress(indexTree).withOTSAddress(indexLeaf).build();
/* get root node */
rootNode = xmss.getRootNodeFromSignature(rootNode.getValue(), xmssMTSignature, otsHashAddress);
}
/* compare roots */
return XMSSUtil.compareByteArray(rootNode.getValue(), pubKey.getRoot());
}
/**
* Export XMSSMT private key.
*
* @return XMSSMT private key.
*/
public byte[] exportPrivateKey() {
return privateKey.toByteArray();
}
/**
* Export XMSSMT public key.
*
* @return XMSSMT public key.
*/
public byte[] exportPublicKey() {
return publicKey.toByteArray();
}
/**
* Getter XMSSMT params.
*
* @return XMSSMT params.
*/
public XMSSMTParameters getParams() {
return params;
}
/**
* Getter XMSSMT index.
*
* @return XMSSMT index.
*/
public long getIndex() {
return privateKey.getIndex();
}
/**
* Getter public seed.
*
* @return Public seed.
*/
public byte[] getPublicSeed() {
return privateKey.getPublicSeed();
}
protected Map<Integer, BDS> getBDSState() {
return privateKey.getBDSState();
}
protected XMSS getXMSS() {
return xmss;
}
}