| package org.bouncycastle.pqc.crypto.xmss; |
| |
| import java.util.ArrayList; |
| import java.util.List; |
| |
| import org.bouncycastle.util.Arrays; |
| |
| /** |
| * WOTS+. |
| */ |
| final class WOTSPlus |
| { |
| |
| /** |
| * WOTS+ parameters. |
| */ |
| private final WOTSPlusParameters params; |
| /** |
| * Randomization functions. |
| */ |
| private final KeyedHashFunctions khf; |
| /** |
| * WOTS+ secret key seed. |
| */ |
| private byte[] secretKeySeed; |
| /** |
| * WOTS+ public seed. |
| */ |
| private byte[] publicSeed; |
| |
| /** |
| * Constructs a new WOTS+ one-time signature system based on the given WOTS+ |
| * parameters. |
| * |
| * @param params Parameters for WOTSPlus object. |
| */ |
| protected WOTSPlus(WOTSPlusParameters params) |
| { |
| super(); |
| if (params == null) |
| { |
| throw new NullPointerException("params == null"); |
| } |
| this.params = params; |
| int n = params.getDigestSize(); |
| khf = new KeyedHashFunctions(params.getDigest(), n); |
| secretKeySeed = new byte[n]; |
| publicSeed = new byte[n]; |
| } |
| |
| /** |
| * Import keys to WOTS+ instance. |
| * |
| * @param secretKeySeed Secret key seed. |
| * @param publicSeed Public seed. |
| */ |
| void importKeys(byte[] secretKeySeed, byte[] publicSeed) |
| { |
| if (secretKeySeed == null) |
| { |
| throw new NullPointerException("secretKeySeed == null"); |
| } |
| if (secretKeySeed.length != params.getDigestSize()) |
| { |
| throw new IllegalArgumentException("size of secretKeySeed needs to be equal to size of digest"); |
| } |
| if (publicSeed == null) |
| { |
| throw new NullPointerException("publicSeed == null"); |
| } |
| if (publicSeed.length != params.getDigestSize()) |
| { |
| throw new IllegalArgumentException("size of publicSeed needs to be equal to size of digest"); |
| } |
| this.secretKeySeed = secretKeySeed; |
| this.publicSeed = publicSeed; |
| } |
| |
| /** |
| * Creates a signature for the n-byte messageDigest. |
| * |
| * @param messageDigest Digest to sign. |
| * @param otsHashAddress OTS hash address for randomization. |
| * @return WOTS+ signature. |
| */ |
| protected WOTSPlusSignature sign(byte[] messageDigest, OTSHashAddress otsHashAddress) |
| { |
| if (messageDigest == null) |
| { |
| throw new NullPointerException("messageDigest == null"); |
| } |
| if (messageDigest.length != params.getDigestSize()) |
| { |
| throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); |
| } |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| List<Integer> baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1()); |
| /* create checksum */ |
| int checksum = 0; |
| for (int i = 0; i < params.getLen1(); i++) |
| { |
| checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i); |
| } |
| checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8)); |
| int len2Bytes = (int)Math |
| .ceil((double)(params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8); |
| List<Integer> baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes), |
| params.getWinternitzParameter(), params.getLen2()); |
| |
| /* msg || checksum */ |
| baseWMessage.addAll(baseWChecksum); |
| |
| /* create signature */ |
| byte[][] signature = new byte[params.getLen()][]; |
| for (int i = 0; i < params.getLen(); i++) |
| { |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) |
| .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) |
| .build(); |
| signature[i] = chain(expandSecretKeySeed(i), 0, baseWMessage.get(i), otsHashAddress); |
| } |
| return new WOTSPlusSignature(params, signature); |
| } |
| |
| /** |
| * Verifies signature on message. |
| * |
| * @param messageDigest The digest that was signed. |
| * @param signature Signature on digest. |
| * @param otsHashAddress OTS hash address for randomization. |
| * @return true if signature was correct false else. |
| */ |
| protected boolean verifySignature(byte[] messageDigest, WOTSPlusSignature signature, |
| OTSHashAddress otsHashAddress) |
| { |
| if (messageDigest == null) |
| { |
| throw new NullPointerException("messageDigest == null"); |
| } |
| if (messageDigest.length != params.getDigestSize()) |
| { |
| throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); |
| } |
| if (signature == null) |
| { |
| throw new NullPointerException("signature == null"); |
| } |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| byte[][] tmpPublicKey = getPublicKeyFromSignature(messageDigest, signature, otsHashAddress).toByteArray(); |
| /* compare values */ |
| return XMSSUtil.areEqual(tmpPublicKey, getPublicKey(otsHashAddress).toByteArray()) ? true : false; |
| } |
| |
| /** |
| * Calculates a public key based on digest and signature. |
| * |
| * @param messageDigest The digest that was signed. |
| * @param signature Signarure on digest. |
| * @param otsHashAddress OTS hash address for randomization. |
| * @return WOTS+ public key derived from digest and signature. |
| */ |
| protected WOTSPlusPublicKeyParameters getPublicKeyFromSignature(byte[] messageDigest, WOTSPlusSignature signature, |
| OTSHashAddress otsHashAddress) |
| { |
| if (messageDigest == null) |
| { |
| throw new NullPointerException("messageDigest == null"); |
| } |
| if (messageDigest.length != params.getDigestSize()) |
| { |
| throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest"); |
| } |
| if (signature == null) |
| { |
| throw new NullPointerException("signature == null"); |
| } |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| List<Integer> baseWMessage = convertToBaseW(messageDigest, params.getWinternitzParameter(), params.getLen1()); |
| /* create checksum */ |
| int checksum = 0; |
| for (int i = 0; i < params.getLen1(); i++) |
| { |
| checksum += params.getWinternitzParameter() - 1 - baseWMessage.get(i); |
| } |
| checksum <<= (8 - ((params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) % 8)); |
| int len2Bytes = (int)Math |
| .ceil((double)(params.getLen2() * XMSSUtil.log2(params.getWinternitzParameter())) / 8); |
| List<Integer> baseWChecksum = convertToBaseW(XMSSUtil.toBytesBigEndian(checksum, len2Bytes), |
| params.getWinternitzParameter(), params.getLen2()); |
| |
| /* msg || checksum */ |
| baseWMessage.addAll(baseWChecksum); |
| |
| byte[][] publicKey = new byte[params.getLen()][]; |
| for (int i = 0; i < params.getLen(); i++) |
| { |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) |
| .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) |
| .build(); |
| publicKey[i] = chain(signature.toByteArray()[i], baseWMessage.get(i), |
| params.getWinternitzParameter() - 1 - baseWMessage.get(i), otsHashAddress); |
| } |
| return new WOTSPlusPublicKeyParameters(params, publicKey); |
| } |
| |
| /** |
| * Computes an iteration of F on an n-byte input using outputs of PRF. |
| * |
| * @param startHash Starting point. |
| * @param startIndex Start index. |
| * @param steps Steps to take. |
| * @param otsHashAddress OTS hash address for randomization. |
| * @return Value obtained by iterating F for steps times on input startHash, |
| * using the outputs of PRF. |
| */ |
| private byte[] chain(byte[] startHash, int startIndex, int steps, OTSHashAddress otsHashAddress) |
| { |
| int n = params.getDigestSize(); |
| if (startHash == null) |
| { |
| throw new NullPointerException("startHash == null"); |
| } |
| if (startHash.length != n) |
| { |
| throw new IllegalArgumentException("startHash needs to be " + n + "bytes"); |
| } |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| if (otsHashAddress.toByteArray() == null) |
| { |
| throw new NullPointerException("otsHashAddress byte array == null"); |
| } |
| if ((startIndex + steps) > params.getWinternitzParameter() - 1) |
| { |
| throw new IllegalArgumentException("max chain length must not be greater than w"); |
| } |
| |
| if (steps == 0) |
| { |
| return startHash; |
| } |
| |
| byte[] tmp = chain(startHash, startIndex, steps - 1, otsHashAddress); |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress()) |
| .withHashAddress(startIndex + steps - 1).withKeyAndMask(0).build(); |
| byte[] key = khf.PRF(publicSeed, otsHashAddress.toByteArray()); |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(otsHashAddress.getChainAddress()) |
| .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(1).build(); |
| byte[] bitmask = khf.PRF(publicSeed, otsHashAddress.toByteArray()); |
| byte[] tmpMasked = new byte[n]; |
| for (int i = 0; i < n; i++) |
| { |
| tmpMasked[i] = (byte)(tmp[i] ^ bitmask[i]); |
| } |
| tmp = khf.F(key, tmpMasked); |
| return tmp; |
| } |
| |
| /** |
| * Obtain base w values from Input. |
| * |
| * @param messageDigest Input data. |
| * @param w Base. |
| * @param outLength Length of output. |
| * @return outLength-length list of base w integers. |
| */ |
| private List<Integer> convertToBaseW(byte[] messageDigest, int w, int outLength) |
| { |
| if (messageDigest == null) |
| { |
| throw new NullPointerException("msg == null"); |
| } |
| if (w != 4 && w != 16) |
| { |
| throw new IllegalArgumentException("w needs to be 4 or 16"); |
| } |
| int logW = XMSSUtil.log2(w); |
| if (outLength > ((8 * messageDigest.length) / logW)) |
| { |
| throw new IllegalArgumentException("outLength too big"); |
| } |
| |
| ArrayList<Integer> res = new ArrayList<Integer>(); |
| for (int i = 0; i < messageDigest.length; i++) |
| { |
| for (int j = 8 - logW; j >= 0; j -= logW) |
| { |
| res.add((messageDigest[i] >> j) & (w - 1)); |
| if (res.size() == outLength) |
| { |
| return res; |
| } |
| } |
| } |
| return res; |
| } |
| |
| /** |
| * Derive WOTS+ secret key for specific index as in XMSS ref impl Andreas |
| * Huelsing. |
| * |
| * @param otsHashAddress |
| * @return WOTS+ secret key at index. |
| */ |
| protected byte[] getWOTSPlusSecretKey(byte[] secretKeySeed, OTSHashAddress otsHashAddress) |
| { |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).build(); |
| return khf.PRF(secretKeySeed, otsHashAddress.toByteArray()); |
| } |
| |
| /** |
| * Derive private key at index from secret key seed. |
| * |
| * @param index Index. |
| * @return Private key at index. |
| */ |
| private byte[] expandSecretKeySeed(int index) |
| { |
| if (index < 0 || index >= params.getLen()) |
| { |
| throw new IllegalArgumentException("index out of bounds"); |
| } |
| return khf.PRF(secretKeySeed, XMSSUtil.toBytesBigEndian(index, 32)); |
| } |
| |
| /** |
| * Getter parameters. |
| * |
| * @return params. |
| */ |
| protected WOTSPlusParameters getParams() |
| { |
| return params; |
| } |
| |
| /** |
| * Getter keyed hash functions. |
| * |
| * @return keyed hash functions. |
| */ |
| protected KeyedHashFunctions getKhf() |
| { |
| return khf; |
| } |
| |
| /** |
| * Getter secret key seed. |
| * |
| * @return secret key seed. |
| */ |
| protected byte[] getSecretKeySeed() |
| { |
| return Arrays.clone(secretKeySeed); |
| } |
| |
| /** |
| * Getter public seed. |
| * |
| * @return public seed. |
| */ |
| protected byte[] getPublicSeed() |
| { |
| return Arrays.clone(publicSeed); |
| } |
| |
| /** |
| * Getter private key. |
| * |
| * @return WOTS+ private key. |
| */ |
| protected WOTSPlusPrivateKeyParameters getPrivateKey() |
| { |
| byte[][] privateKey = new byte[params.getLen()][]; |
| for (int i = 0; i < privateKey.length; i++) |
| { |
| privateKey[i] = expandSecretKeySeed(i); |
| } |
| return new WOTSPlusPrivateKeyParameters(params, privateKey); |
| } |
| |
| /** |
| * Calculates a new public key based on the state of secretKeySeed, |
| * publicSeed and otsHashAddress. |
| * |
| * @param otsHashAddress OTS hash address for randomization. |
| * @return WOTS+ public key. |
| */ |
| protected WOTSPlusPublicKeyParameters getPublicKey(OTSHashAddress otsHashAddress) |
| { |
| if (otsHashAddress == null) |
| { |
| throw new NullPointerException("otsHashAddress == null"); |
| } |
| byte[][] publicKey = new byte[params.getLen()][]; |
| /* derive public key from secretKeySeed */ |
| for (int i = 0; i < params.getLen(); i++) |
| { |
| otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder() |
| .withLayerAddress(otsHashAddress.getLayerAddress()).withTreeAddress(otsHashAddress.getTreeAddress()) |
| .withOTSAddress(otsHashAddress.getOTSAddress()).withChainAddress(i) |
| .withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask()) |
| .build(); |
| publicKey[i] = chain(expandSecretKeySeed(i), 0, params.getWinternitzParameter() - 1, otsHashAddress); |
| } |
| return new WOTSPlusPublicKeyParameters(params, publicKey); |
| } |
| } |