// blob: d1bb7c674bd8fff02f0337fe7c920d60618231c5 [file] [log] [blame]
package org.bouncycastle.pqc.crypto.xmss;
import java.io.IOException;
import java.text.ParseException;
import java.util.Map;
import java.util.TreeMap;
/**
 * XMSS^MT private key.
 * <p>
 * Holds the global signature index, the secret key seed, the secret key PRF,
 * the public seed, the root and the per-layer BDS state. Instances are created
 * through {@link Builder}, either from individual components or by importing a
 * byte array previously produced by {@link #toByteArray()}.
 * <p>
 * Serialized layout:
 * {@code index || secretKeySeed || secretKeyPRF || publicSeed || root || serialized BDS state},
 * where the index takes {@code ceil(height / 8)} bytes and each of the four
 * following fields takes one digest length.
 */
public final class XMSSMTPrivateKeyParameters implements XMSSStoreableObjectInterface {

    /** Parameter set that supplies the digest size and the total tree height. */
    private final XMSSMTParameters params;
    /** Global index stored with the key. */
    private final long index;
    private final byte[] secretKeySeed;
    private final byte[] secretKeyPRF;
    private final byte[] publicSeed;
    private final byte[] root;
    /** BDS state per Integer key; shared by reference with the builder and with callers of the getter. */
    private final Map<Integer, BDS> bdsState;

    /**
     * Creates the key from the builder, importing a serialized key if one was
     * supplied and otherwise using the individually configured components.
     *
     * @param builder source of all key components
     * @throws NullPointerException     if the parameters are {@code null}, or if a serialized key
     *                                  was supplied without an XMSS instance
     * @throws IllegalArgumentException if an individually supplied byte array does not have
     *                                  digest length
     * @throws ParseException           if the index of an imported key is out of bounds
     * @throws ClassNotFoundException   declared for the import path (presumably raised while
     *                                  deserializing the BDS state)
     * @throws IOException              declared for the import path (presumably raised while
     *                                  deserializing the BDS state)
     */
    private XMSSMTPrivateKeyParameters(Builder builder) throws ParseException, ClassNotFoundException, IOException {
        params = builder.params;
        if (params == null) {
            throw new NullPointerException("params == null");
        }
        int n = params.getDigestSize();
        byte[] privateKey = builder.privateKey;

        if (privateKey != null) {
            /* import from serialized form */
            if (builder.xmss == null) {
                throw new NullPointerException("xmss == null");
            }
            int totalHeight = params.getHeight();
            int indexSize = indexSizeInBytes(totalHeight);

            /*
             * The total length is not compared against a fixed size: the serialized
             * BDS state that follows the fixed-size fields occupies the remainder of
             * the array.
             */
            int position = 0;
            index = XMSSUtil.bytesToXBigEndian(privateKey, position, indexSize);
            if (!XMSSUtil.isIndexValid(totalHeight, index)) {
                throw new ParseException("index out of bounds", 0);
            }
            position += indexSize;

            secretKeySeed = XMSSUtil.extractBytesAtOffset(privateKey, position, n);
            position += n;
            secretKeyPRF = XMSSUtil.extractBytesAtOffset(privateKey, position, n);
            position += n;
            publicSeed = XMSSUtil.extractBytesAtOffset(privateKey, position, n);
            position += n;
            root = XMSSUtil.extractBytesAtOffset(privateKey, position, n);
            position += n;

            /* import BDS state: everything after the fixed-size fields */
            byte[] bdsStateBinary = XMSSUtil.extractBytesAtOffset(privateKey, position, privateKey.length - position);
            /*
             * NOTE(review): XMSSUtil.deserialize appears to perform Java object
             * deserialization (it declares ClassNotFoundException/IOException) - confirm.
             * If so, only import private keys that come from a trusted source.
             */
            @SuppressWarnings("unchecked")
            Map<Integer, BDS> bdsImport = (TreeMap<Integer, BDS>) XMSSUtil.deserialize(bdsStateBinary);
            /* re-attach the XMSS instance to every restored BDS and validate it */
            for (BDS bds : bdsImport.values()) {
                bds.setXMSS(builder.xmss);
                bds.validate();
            }
            bdsState = bdsImport;
        } else {
            /* set from individual components; missing arrays default to n zero bytes */
            index = builder.index;
            secretKeySeed = digestSizedOrZeros(builder.secretKeySeed, n, "secretKeySeed");
            secretKeyPRF = digestSizedOrZeros(builder.secretKeyPRF, n, "secretKeyPRF");
            publicSeed = digestSizedOrZeros(builder.publicSeed, n, "publicSeed");
            root = digestSizedOrZeros(builder.root, n, "root");
            Map<Integer, BDS> suppliedState = builder.bdsState;
            bdsState = (suppliedState != null) ? suppliedState : new TreeMap<Integer, BDS>();
        }
    }

    /**
     * Number of bytes used to encode the index for the given total height.
     *
     * @param totalHeight total tree height
     * @return {@code ceil(totalHeight / 8)}
     */
    private static int indexSizeInBytes(int totalHeight) {
        return (int) Math.ceil(totalHeight / (double) 8);
    }

    /**
     * Returns {@code value} if it has exactly {@code n} bytes, or a fresh zero-filled
     * array of {@code n} bytes if {@code value} is {@code null}.
     *
     * @param value candidate array, may be {@code null}
     * @param n     required length (digest size)
     * @param name  field name used in the error message
     * @return {@code value} itself, or a new {@code byte[n]}
     * @throws IllegalArgumentException if {@code value} is non-null and its length differs from {@code n}
     */
    private static byte[] digestSizedOrZeros(byte[] value, int n, String name) {
        if (value == null) {
            return new byte[n];
        }
        if (value.length != n) {
            throw new IllegalArgumentException("size of " + name + " needs to be equal size of digest");
        }
        return value;
    }

    /**
     * Builder for {@link XMSSMTPrivateKeyParameters}.
     * <p>
     * If a serialized private key is supplied via
     * {@link #withPrivateKey(byte[], XMSS)}, it takes precedence and the
     * individually configured components are ignored.
     */
    public static class Builder {

        /* mandatory */
        private final XMSSMTParameters params;
        /* optional */
        private long index = 0L;
        private byte[] secretKeySeed = null;
        private byte[] secretKeyPRF = null;
        private byte[] publicSeed = null;
        private byte[] root = null;
        private Map<Integer, BDS> bdsState = null;
        private byte[] privateKey = null;
        private XMSS xmss = null;

        /**
         * @param params parameter set; a {@code null} value is rejected when {@link #build()} is called
         */
        public Builder(XMSSMTParameters params) {
            this.params = params;
        }

        /**
         * @param val index to store in the key (defaults to 0)
         * @return this builder
         */
        public Builder withIndex(long val) {
            index = val;
            return this;
        }

        /**
         * @param val secret key seed; a clone is stored
         * @return this builder
         */
        public Builder withSecretKeySeed(byte[] val) {
            secretKeySeed = XMSSUtil.cloneArray(val);
            return this;
        }

        /**
         * @param val secret key PRF; a clone is stored
         * @return this builder
         */
        public Builder withSecretKeyPRF(byte[] val) {
            secretKeyPRF = XMSSUtil.cloneArray(val);
            return this;
        }

        /**
         * @param val public seed; a clone is stored
         * @return this builder
         */
        public Builder withPublicSeed(byte[] val) {
            publicSeed = XMSSUtil.cloneArray(val);
            return this;
        }

        /**
         * @param val root; a clone is stored
         * @return this builder
         */
        public Builder withRoot(byte[] val) {
            root = XMSSUtil.cloneArray(val);
            return this;
        }

        /**
         * @param val BDS state map; stored by reference, not copied
         * @return this builder
         */
        public Builder withBDSState(Map<Integer, BDS> val) {
            bdsState = val;
            return this;
        }

        /**
         * Supplies a serialized private key to import.
         *
         * @param privateKeyVal serialized private key; a clone is stored
         * @param xmssVal       XMSS instance attached to every imported BDS; must not be
         *                      {@code null} when {@code privateKeyVal} is non-null
         * @return this builder
         */
        public Builder withPrivateKey(byte[] privateKeyVal, XMSS xmssVal) {
            privateKey = XMSSUtil.cloneArray(privateKeyVal);
            xmss = xmssVal;
            return this;
        }

        /**
         * Builds the private key from the configured values.
         *
         * @return the new private key
         * @throws ParseException         if the index of an imported key is out of bounds
         * @throws ClassNotFoundException declared for the import of a serialized key
         * @throws IOException            declared for the import of a serialized key
         */
        public XMSSMTPrivateKeyParameters build() throws ParseException, ClassNotFoundException, IOException {
            return new XMSSMTPrivateKeyParameters(this);
        }
    }

    /**
     * Serializes this key as
     * {@code index || secretKeySeed || secretKeyPRF || publicSeed || root || serialized BDS state}.
     *
     * @return the serialized private key
     * @throws RuntimeException if the BDS state cannot be serialized; the underlying
     *                          {@link IOException} is attached as the cause
     */
    public byte[] toByteArray() {
        int n = params.getDigestSize();
        int indexSize = indexSizeInBytes(params.getHeight());
        /* fixed-size part: index followed by four digest-sized fields */
        byte[] out = new byte[indexSize + n + n + n + n];
        int position = 0;

        /* copy index */
        XMSSUtil.copyBytesAtOffset(out, XMSSUtil.toBytesBigEndian(index, indexSize), position);
        position += indexSize;
        /* copy secretKeySeed */
        XMSSUtil.copyBytesAtOffset(out, secretKeySeed, position);
        position += n;
        /* copy secretKeyPRF */
        XMSSUtil.copyBytesAtOffset(out, secretKeyPRF, position);
        position += n;
        /* copy publicSeed */
        XMSSUtil.copyBytesAtOffset(out, publicSeed, position);
        position += n;
        /* copy root */
        XMSSUtil.copyBytesAtOffset(out, root, position);

        /* concatenate bdsState */
        byte[] bdsStateOut;
        try {
            bdsStateOut = XMSSUtil.serialize(bdsState);
        } catch (IOException e) {
            /* keep the original failure as the cause instead of printing it to stderr */
            throw new RuntimeException("error serializing bds state", e);
        }
        return XMSSUtil.concat(out, bdsStateOut);
    }

    /**
     * @return the index stored in this key
     */
    public long getIndex() {
        return index;
    }

    /**
     * @return a clone of the secret key seed
     */
    public byte[] getSecretKeySeed() {
        return XMSSUtil.cloneArray(secretKeySeed);
    }

    /**
     * @return a clone of the secret key PRF
     */
    public byte[] getSecretKeyPRF() {
        return XMSSUtil.cloneArray(secretKeyPRF);
    }

    /**
     * @return a clone of the public seed
     */
    public byte[] getPublicSeed() {
        return XMSSUtil.cloneArray(publicSeed);
    }

    /**
     * @return a clone of the root
     */
    public byte[] getRoot() {
        return XMSSUtil.cloneArray(root);
    }

    /**
     * @return the internal BDS state map itself (not a copy); modifications are
     *         visible to this key
     */
    public Map<Integer, BDS> getBDSState() {
        return bdsState;
    }
}