| package org.bouncycastle.pqc.crypto.xmss; |
| |
| import java.io.Serializable; |
| import java.util.Stack; |
| |
| |
| class BDSTreeHash |
| implements Serializable |
| { |
| private static final long serialVersionUID = 1L; |
| |
| private XMSSNode tailNode; |
| private final int initialHeight; |
| private int height; |
| private int nextIndex; |
| private boolean initialized; |
| private boolean finished; |
| |
| BDSTreeHash(int initialHeight) |
| { |
| super(); |
| this.initialHeight = initialHeight; |
| initialized = false; |
| finished = false; |
| } |
| |
| void initialize(int nextIndex) |
| { |
| tailNode = null; |
| height = initialHeight; |
| this.nextIndex = nextIndex; |
| initialized = true; |
| finished = false; |
| } |
| |
| void update(Stack<XMSSNode> stack, WOTSPlus wotsPlus, byte[] publicSeed, byte[] secretSeed, OTSHashAddress otsHashAddress) |
| { |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| if (finished || !initialized) |
| { |
| throw new IllegalStateException("finished or not initialized"); |
| } |
| /* prepare addresses */ |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(nextIndex).withChainAddress(otsHashAddress.getChainAddress()) |
| .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) |
| .build(); |
| LTreeAddress lTreeAddress = (LTreeAddress)new LTreeAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withLTreeAddress(nextIndex).build(); |
| HashTreeAddress hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withTreeIndex(nextIndex).build(); |
| /* calculate leaf node */ |
| wotsPlus.importKeys(wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed); |
| WOTSPlusPublicKeyParameters wotsPlusPublicKey = wotsPlus.getPublicKey(otsHashAddress); |
| XMSSNode node = XMSSNodeUtil.lTree(wotsPlus, wotsPlusPublicKey, lTreeAddress); |
| |
| while (!stack.isEmpty() && stack.peek().getHeight() == node.getHeight() |
| && stack.peek().getHeight() != initialHeight) |
| { |
| hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder() |
| .withLayerAddress(hashTreeAddress.getLayerAddress()) |
| .withTreeAddress(hashTreeAddress.getTreeAddress()) |
| .withTreeHeight(hashTreeAddress.getTreeHeight()) |
| .withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2) |
| .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build(); |
| node = XMSSNodeUtil.randomizeHash(wotsPlus, stack.pop(), node, hashTreeAddress); |
| node = new XMSSNode(node.getHeight() + 1, node.getValue()); |
| hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder() |
| .withLayerAddress(hashTreeAddress.getLayerAddress()) |
| .withTreeAddress(hashTreeAddress.getTreeAddress()) |
| .withTreeHeight(hashTreeAddress.getTreeHeight() + 1) |
| .withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask()) |
| .build(); |
| } |
| |
| if (tailNode == null) |
| { |
| tailNode = node; |
| } |
| else |
| { |
| if (tailNode.getHeight() == node.getHeight()) |
| { |
| hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder() |
| .withLayerAddress(hashTreeAddress.getLayerAddress()) |
| .withTreeAddress(hashTreeAddress.getTreeAddress()) |
| .withTreeHeight(hashTreeAddress.getTreeHeight()) |
| .withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2) |
| .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build(); |
| node = XMSSNodeUtil.randomizeHash(wotsPlus, tailNode, node, hashTreeAddress); |
| node = new XMSSNode(tailNode.getHeight() + 1, node.getValue()); |
| tailNode = node; |
| hashTreeAddress = (HashTreeAddress)new HashTreeAddress.Builder() |
| .withLayerAddress(hashTreeAddress.getLayerAddress()) |
| .withTreeAddress(hashTreeAddress.getTreeAddress()) |
| .withTreeHeight(hashTreeAddress.getTreeHeight() + 1) |
| .withTreeIndex(hashTreeAddress.getTreeIndex()) |
| .withKeyAndMask(hashTreeAddress.getKeyAndMask()).build(); |
| } |
| else |
| { |
| stack.push(node); |
| } |
| } |
| |
| if (tailNode.getHeight() == initialHeight) |
| { |
| finished = true; |
| } |
| else |
| { |
| height = node.getHeight(); |
| nextIndex++; |
| } |
| } |
| |
| int getHeight() |
| { |
| if (!initialized || finished) |
| { |
| return Integer.MAX_VALUE; |
| } |
| return height; |
| } |
| |
| int getIndexLeaf() |
| { |
| return nextIndex; |
| } |
| |
| void setNode(XMSSNode node) |
| { |
| tailNode = node; |
| height = node.getHeight(); |
| if (height == initialHeight) |
| { |
| finished = true; |
| } |
| } |
| |
| boolean isFinished() |
| { |
| return finished; |
| } |
| |
| boolean isInitialized() |
| { |
| return initialized; |
| } |
| |
| public XMSSNode getTailNode() |
| { |
| return tailNode.clone(); |
| } |
| } |
| |