| package org.bouncycastle.pqc.crypto.xmss; |
| |
| import java.io.IOException; |
| import java.text.ParseException; |
| import java.util.Map; |
| import java.util.TreeMap; |
| |
| /** |
| * XMSSMT Private Key. |
| * |
| */ |
| public final class XMSSMTPrivateKeyParameters implements XMSSStoreableObjectInterface { |
| |
| private final XMSSMTParameters params; |
| private final long index; |
| private final byte[] secretKeySeed; |
| private final byte[] secretKeyPRF; |
| private final byte[] publicSeed; |
| private final byte[] root; |
| private final Map<Integer, BDS> bdsState; |
| |
| private XMSSMTPrivateKeyParameters(Builder builder) throws ParseException, ClassNotFoundException, IOException { |
| super(); |
| params = builder.params; |
| if (params == null) { |
| throw new NullPointerException("params == null"); |
| } |
| int n = params.getDigestSize(); |
| byte[] privateKey = builder.privateKey; |
| if (privateKey != null) { |
| if (builder.xmss == null) { |
| throw new NullPointerException("xmss == null"); |
| } |
| /* import */ |
| int totalHeight = params.getHeight(); |
| int indexSize = (int) Math.ceil(totalHeight / (double) 8); |
| int secretKeySize = n; |
| int secretKeyPRFSize = n; |
| int publicSeedSize = n; |
| int rootSize = n; |
| /* |
| int totalSize = indexSize + secretKeySize + secretKeyPRFSize + publicSeedSize + rootSize; |
| if (privateKey.length != totalSize) { |
| throw new ParseException("private key has wrong size", 0); |
| } |
| */ |
| int position = 0; |
| index = XMSSUtil.bytesToXBigEndian(privateKey, position, indexSize); |
| if (!XMSSUtil.isIndexValid(totalHeight, index)) { |
| throw new ParseException("index out of bounds", 0); |
| } |
| position += indexSize; |
| secretKeySeed = XMSSUtil.extractBytesAtOffset(privateKey, position, secretKeySize); |
| position += secretKeySize; |
| secretKeyPRF = XMSSUtil.extractBytesAtOffset(privateKey, position, secretKeyPRFSize); |
| position += secretKeyPRFSize; |
| publicSeed = XMSSUtil.extractBytesAtOffset(privateKey, position, publicSeedSize); |
| position += publicSeedSize; |
| root = XMSSUtil.extractBytesAtOffset(privateKey, position, rootSize); |
| position += rootSize; |
| /* import BDS state */ |
| byte[] bdsStateBinary = XMSSUtil.extractBytesAtOffset(privateKey, position, privateKey.length - position); |
| @SuppressWarnings("unchecked") |
| Map<Integer, BDS> bdsImport = (TreeMap<Integer, BDS>) XMSSUtil.deserialize(bdsStateBinary); |
| for (Integer key : bdsImport.keySet()) { |
| BDS bds = bdsImport.get(key); |
| bds.setXMSS(builder.xmss); |
| bds.validate(); |
| } |
| bdsState = bdsImport; |
| } else { |
| /* set */ |
| index = builder.index; |
| byte[] tmpSecretKeySeed = builder.secretKeySeed; |
| if (tmpSecretKeySeed != null) { |
| if (tmpSecretKeySeed.length != n) { |
| throw new IllegalArgumentException("size of secretKeySeed needs to be equal size of digest"); |
| } |
| secretKeySeed = tmpSecretKeySeed; |
| } else { |
| secretKeySeed = new byte[n]; |
| } |
| byte[] tmpSecretKeyPRF = builder.secretKeyPRF; |
| if (tmpSecretKeyPRF != null) { |
| if (tmpSecretKeyPRF.length != n) { |
| throw new IllegalArgumentException("size of secretKeyPRF needs to be equal size of digest"); |
| } |
| secretKeyPRF = tmpSecretKeyPRF; |
| } else { |
| secretKeyPRF = new byte[n]; |
| } |
| byte[] tmpPublicSeed = builder.publicSeed; |
| if (tmpPublicSeed != null) { |
| if (tmpPublicSeed.length != n) { |
| throw new IllegalArgumentException("size of publicSeed needs to be equal size of digest"); |
| } |
| publicSeed = tmpPublicSeed; |
| } else { |
| publicSeed = new byte[n]; |
| } |
| byte[] tmpRoot = builder.root; |
| if (tmpRoot != null) { |
| if (tmpRoot.length != n) { |
| throw new IllegalArgumentException("size of root needs to be equal size of digest"); |
| } |
| root = tmpRoot; |
| } else { |
| root = new byte[n]; |
| } |
| Map<Integer, BDS> tmpBDSState = builder.bdsState; |
| if (tmpBDSState != null) { |
| bdsState = tmpBDSState; |
| } else { |
| bdsState = new TreeMap<Integer, BDS>(); |
| } |
| } |
| } |
| |
| public static class Builder { |
| |
| /* mandatory */ |
| private final XMSSMTParameters params; |
| /* optional */ |
| private long index = 0L; |
| private byte[] secretKeySeed = null; |
| private byte[] secretKeyPRF = null; |
| private byte[] publicSeed = null; |
| private byte[] root = null; |
| private Map<Integer, BDS> bdsState = null; |
| private byte[] privateKey = null; |
| private XMSS xmss = null; |
| |
| public Builder(XMSSMTParameters params) { |
| super(); |
| this.params = params; |
| } |
| |
| public Builder withIndex(long val) { |
| index = val; |
| return this; |
| } |
| |
| public Builder withSecretKeySeed(byte[] val) { |
| secretKeySeed = XMSSUtil.cloneArray(val); |
| return this; |
| } |
| |
| public Builder withSecretKeyPRF(byte[] val) { |
| secretKeyPRF = XMSSUtil.cloneArray(val); |
| return this; |
| } |
| |
| public Builder withPublicSeed(byte[] val) { |
| publicSeed = XMSSUtil.cloneArray(val); |
| return this; |
| } |
| |
| public Builder withRoot(byte[] val) { |
| root = XMSSUtil.cloneArray(val); |
| return this; |
| } |
| |
| public Builder withBDSState(Map<Integer, BDS> val) { |
| bdsState = val; |
| return this; |
| } |
| |
| public Builder withPrivateKey(byte[] privateKeyVal, XMSS xmssVal) { |
| privateKey = XMSSUtil.cloneArray(privateKeyVal); |
| xmss = xmssVal; |
| return this; |
| } |
| |
| public XMSSMTPrivateKeyParameters build() throws ParseException, ClassNotFoundException, IOException { |
| return new XMSSMTPrivateKeyParameters(this); |
| } |
| } |
| |
| public byte[] toByteArray() { |
| /* index || secretKeySeed || secretKeyPRF || publicSeed || root */ |
| int n = params.getDigestSize(); |
| int indexSize = (int) Math.ceil(params.getHeight() / (double) 8); |
| int secretKeySize = n; |
| int secretKeyPRFSize = n; |
| int publicSeedSize = n; |
| int rootSize = n; |
| int totalSize = indexSize + secretKeySize + secretKeyPRFSize + publicSeedSize + rootSize; |
| byte[] out = new byte[totalSize]; |
| int position = 0; |
| /* copy index */ |
| byte[] indexBytes = XMSSUtil.toBytesBigEndian(index, indexSize); |
| XMSSUtil.copyBytesAtOffset(out, indexBytes, position); |
| position += indexSize; |
| /* copy secretKeySeed */ |
| XMSSUtil.copyBytesAtOffset(out, secretKeySeed, position); |
| position += secretKeySize; |
| /* copy secretKeyPRF */ |
| XMSSUtil.copyBytesAtOffset(out, secretKeyPRF, position); |
| position += secretKeyPRFSize; |
| /* copy publicSeed */ |
| XMSSUtil.copyBytesAtOffset(out, publicSeed, position); |
| position += publicSeedSize; |
| /* copy root */ |
| XMSSUtil.copyBytesAtOffset(out, root, position); |
| /* concatenate bdsState */ |
| byte[] bdsStateOut = null; |
| try { |
| bdsStateOut = XMSSUtil.serialize(bdsState); |
| } catch (IOException e) { |
| e.printStackTrace(); |
| throw new RuntimeException("error serializing bds state"); |
| } |
| return XMSSUtil.concat(out, bdsStateOut); |
| } |
| |
| /* |
| protected void increaseIndex(XMSSMT mt) { |
| if (mt == null) { |
| throw new NullPointerException("mt == null"); |
| } |
| ZonedDateTime currentTime = ZonedDateTime.now(ZoneOffset.UTC); |
| long differenceHours = Duration.between(lastUsage, currentTime).toHours(); |
| if (differenceHours >= 24) { |
| mt.getXMSS().setPublicSeed(getPublicSeed()); |
| |
| Map<Integer, BDS> bdsStates = mt.getBDS(); |
| int xmssHeight = params.getXMSS().getParams().getHeight(); |
| long keyIncreaseCount = differenceHours * indexIncreaseCountPerHour; |
| long oldGlobalIndex = getIndex(); |
| long newGlobalIndex = oldGlobalIndex + keyIncreaseCount; |
| long oldIndexTree = XMSSUtil.getTreeIndex(oldGlobalIndex, xmssHeight); |
| long newIndexTree = XMSSUtil.getTreeIndex(newGlobalIndex, xmssHeight); |
| int newIndexLeaf = XMSSUtil.getLeafIndex(newGlobalIndex, xmssHeight); |
| |
| // adjust bds instances |
| for (int layer = 0; layer < params.getLayers(); layer++) { |
| OTSHashAddress otsHashAddress = new OTSHashAddress(); |
| otsHashAddress.setLayerAddress(layer); |
| otsHashAddress.setTreeAddress(newIndexTree); |
| |
| if (newIndexLeaf != 0) { |
| if (oldIndexTree != newIndexTree || bdsStates.get(layer) == null) { |
| bdsStates.put(layer, new BDS(mt.getXMSS())); |
| bdsStates.get(layer).initialize(otsHashAddress); |
| } |
| for (int indexLeaf = bdsStates.get(layer).getIndex(); indexLeaf < newIndexLeaf; indexLeaf++) { |
| if (indexLeaf < ((1 << xmssHeight) - 1)) { |
| bdsStates.get(layer).nextAuthenticationPath(otsHashAddress); |
| } |
| } |
| } |
| oldIndexTree = XMSSUtil.getTreeIndex(oldIndexTree, xmssHeight); |
| newIndexLeaf = XMSSUtil.getLeafIndex(newIndexTree, xmssHeight); |
| newIndexTree = XMSSUtil.getTreeIndex(newIndexTree, xmssHeight); |
| } |
| setIndex(newGlobalIndex); |
| } |
| } |
| */ |
| |
| public long getIndex() { |
| return index; |
| } |
| |
| public byte[] getSecretKeySeed() { |
| return XMSSUtil.cloneArray(secretKeySeed); |
| } |
| |
| public byte[] getSecretKeyPRF() { |
| return XMSSUtil.cloneArray(secretKeyPRF); |
| } |
| |
| public byte[] getPublicSeed() { |
| return XMSSUtil.cloneArray(publicSeed); |
| } |
| |
| public byte[] getRoot() { |
| return XMSSUtil.cloneArray(root); |
| } |
| |
| public Map<Integer, BDS> getBDSState() { |
| return bdsState; |
| } |
| } |