blob: 8a4796f0fb0b8f5ed6f835474cee71f2927b064c [file] [log] [blame]
package org.bouncycastle.pqc.crypto.gmss;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.pqc.crypto.gmss.util.GMSSRandom;
import org.bouncycastle.util.encoders.Hex;
/**
 * This class implements the distributed signature generation of the Winternitz
 * one-time signature scheme (OTSS), described in C.Dods, N.P. Smart, and M.
 * Stam, "Hash Based Digital Signature Schemes", LNCS 3796, pages 96–115,
 * 2005. The class is used by the GMSS classes.
 * <p>
 * The signature is computed incrementally: {@link #initSign(byte[], byte[])}
 * prepares the state and every {@link #updateSign()} call performs at most
 * {@code steps} PRNG/hash steps. The complete state can be exported with
 * {@link #getStatByte()} / {@link #getStatInt()} and restored with
 * {@link #GMSSRootSig(Digest, byte[][], int[])}.
 * <p>
 * NOTE(review): several expressions build a {@code long} accumulator from
 * {@code int}-typed shifts such as {@code (b & 0xff) << (j << 3)}, and the
 * digit mask {@code k = (1 << w) - 1} is computed in {@code int} arithmetic.
 * Java masks {@code int} shift distances to 5 bits, so for shift distances of
 * 32 or more these do not behave like 64-bit shifts. This is intentionally
 * left unchanged: the digit extraction here has to stay identical to what the
 * verifying side computes for the same parameters. Confirm against the
 * verifier before changing it.
 * <p>
 * Instances are not thread-safe: all signing state lives in unsynchronized
 * mutable fields.
 */
public class GMSSRootSig
{
    /**
     * The hash function used by the OTS
     */
    private Digest messDigestOTS;

    /**
     * The digest size in bytes ({@code mdsize}) and the number of digest-sized
     * chunks in the signature ({@code keysize})
     */
    private int mdsize, keysize;

    /**
     * The current OTS private key element (re-hashed while signing)
     */
    private byte[] privateKeyOTS;

    /**
     * The digest of the message being signed. In the case {@code 8 % w == 0}
     * its bytes are shifted in place as digits are consumed.
     */
    private byte[] hash;

    /**
     * The signature bytes
     */
    private byte[] sign;

    /**
     * The Winternitz parameter
     */
    private int w;

    /**
     * The source of randomness for OTS private key generation
     */
    private GMSSRandom gmssRandom;

    /**
     * Number of w-bit digits needed to cover the message digest
     */
    private int messagesize;

    /**
     * Digit mask {@code (1 << w) - 1}
     */
    private int k;

    /**
     * Status of the distributed signing: bit position into {@code hash}
     * ({@code r}), remaining hashes of the current chunk ({@code test}),
     * number of finished chunks ({@code counter}) and byte index into
     * {@code hash} / scratch byte counter ({@code ii})
     */
    private int r, test, counter, ii;

    /**
     * Status of the distributed signing stored as longs: remaining hashes of
     * the current chunk for the large-w path ({@code test8}) and the current
     * digit accumulator ({@code big8})
     */
    private long test8, big8;

    /**
     * The maximum number of steps performed by each updateSign() call
     */
    private int steps;

    /**
     * The checksum part (consumed w bits at a time while signing)
     */
    private int checksum;

    /**
     * The height of the tree
     */
    private int height;

    /**
     * The current internal OTS seed
     */
    private byte[] seed;

    /**
     * This constructor regenerates a prior GMSSRootSig object used by the
     * GMSSPrivateKeyASN.1 class. The given arrays are stored by reference,
     * not copied.
     *
     * @param digest   the hash function used by the OTS (also passed to the PRNG)
     * @param statByte status byte array as produced by {@link #getStatByte()}:
     *                 private key, seed, message hash, signature and the 16-byte
     *                 little-endian encoding of {@code test8} and {@code big8}
     * @param statInt  status int array as produced by {@link #getStatInt()}
     */
    public GMSSRootSig(Digest digest, byte[][] statByte, int[] statInt)
    {
        messDigestOTS = digest;
        gmssRandom = new GMSSRandom(messDigestOTS);

        this.counter = statInt[0];
        this.test = statInt[1];
        this.ii = statInt[2];
        this.r = statInt[3];
        this.steps = statInt[4];
        this.keysize = statInt[5];
        this.height = statInt[6];
        this.w = statInt[7];
        this.checksum = statInt[8];

        deriveSizes();

        this.privateKeyOTS = statByte[0];
        this.seed = statByte[1];
        this.hash = statByte[2];
        this.sign = statByte[3];
        this.test8 = decodeLong(statByte[4], 0);
        this.big8 = decodeLong(statByte[4], 8);
    }

    /**
     * The constructor generates the PRNG and initializes some variables.
     *
     * @param digest the hash function used by the OTS (also passed to the PRNG)
     * @param w      the Winternitz parameter
     * @param height the height of the tree
     */
    public GMSSRootSig(Digest digest, int w, int height)
    {
        messDigestOTS = digest;
        gmssRandom = new GMSSRandom(messDigestOTS);
        this.w = w;
        this.height = height;
        deriveSizes();
    }

    /**
     * Computes the values that depend only on the digest and {@code w}:
     * {@code mdsize}, {@code k} and {@code messagesize}.
     */
    private void deriveSizes()
    {
        this.mdsize = messDigestOTS.getDigestSize();
        this.k = (1 << w) - 1;
        int mdsizeBit = mdsize << 3;
        this.messagesize = (int)Math.ceil((double)(mdsizeBit) / (double)w);
    }

    /**
     * This method initializes the distributed signature calculation. Variables
     * are reset and the number of steps per {@link #updateSign()} call is
     * calculated.
     *
     * @param seed0   the initial OTS seed (its first {@code mdsize} bytes are copied)
     * @param message the message which will be signed
     */
    public void initSign(byte[] seed0, byte[] message)
    {
        // create hash of message m
        messDigestOTS.update(message, 0, message.length);
        this.hash = new byte[messDigestOTS.getDigestSize()];
        messDigestOTS.doFinal(this.hash, 0);

        // working copy: the byte-aligned digit summation shifts it in place
        byte[] messPart = new byte[mdsize];
        System.arraycopy(hash, 0, messPart, 0, mdsize);

        int sumH = 0;
        int checksumsize = getLog((messagesize << w) + 1);

        // ------- calculation of necessary steps ------
        // sumH = total number of hash applications (message digits + checksum digits)
        if (8 % w == 0)
        {
            sumH = addChecksumDigits(sumDigitsByteAligned(messPart), checksumsize);
        }
        else if (w < 8)
        {
            sumH = addChecksumDigits(sumDigitsSmallW(messPart), checksumsize);
        }
        else if (w < 57)
        {
            sumH = addChecksumDigits(sumDigitsLargeW(messPart), checksumsize);
        }

        // calculate keysize
        this.keysize = messagesize
            + (int)Math.ceil((double)checksumsize / (double)w);
        // calculate steps: 'keysize' times PRNG plus 'sumH' times hashing,
        // spread over (1 << height) updateSign() calls
        this.steps = (int)Math.ceil((double)(keysize + sumH)
            / (double)((1 << height)));

        // reset variables
        this.sign = new byte[keysize * mdsize];
        this.counter = 0;
        this.test = 0;
        this.ii = 0;
        this.test8 = 0;
        this.r = 0;
        // private key buffer of digest size
        this.privateKeyOTS = new byte[mdsize];
        // copy the seed
        this.seed = new byte[mdsize];
        System.arraycopy(seed0, 0, this.seed, 0, mdsize);
    }

    /**
     * Sums the w-bit digits of the message part when w divides 8 (each byte
     * holds exactly {@code 8 / w} digits). Shifts {@code messPart} in place.
     *
     * @param messPart working copy of the message digest
     * @return the sum of the message digits
     */
    private int sumDigitsByteAligned(byte[] messPart)
    {
        int digitsPerByte = 8 / w;
        int sum = 0;
        for (int a = 0; a < mdsize; a++)
        {
            for (int b = 0; b < digitsPerByte; b++)
            {
                sum += messPart[a] & k;
                messPart[a] = (byte)(messPart[a] >>> w);
            }
        }
        return sum;
    }

    /**
     * Sums the w-bit digits of the message part for {@code w < 8} with
     * {@code 8 % w != 0}. Groups of w bytes are packed into a long and split
     * into 8 digits. The remaining {@code mdsize % w} bytes are handled once
     * at the end.
     *
     * @param messPart working copy of the message digest
     * @return the sum of the message digits
     */
    private int sumDigitsSmallW(byte[] messPart)
    {
        int sum = 0;
        int pos = 0;
        long acc;

        // first (mdsize / w) * w bytes of hash (main message part)
        int groups = mdsize / w;
        for (int i = 0; i < groups; i++)
        {
            acc = 0;
            for (int j = 0; j < w; j++)
            {
                // int-typed shift, kept as-is (see class note)
                acc ^= (messPart[pos] & 0xff) << (j << 3);
                pos++;
            }
            for (int j = 0; j < 8; j++)
            {
                sum += (int)(acc & k);
                acc >>>= w;
            }
        }

        // rest of message part
        int restBytes = mdsize % w;
        acc = 0;
        for (int j = 0; j < restBytes; j++)
        {
            acc ^= (messPart[pos] & 0xff) << (j << 3);
            pos++;
        }
        int restBits = restBytes << 3;
        for (int j = 0; j < restBits; j += w)
        {
            sum += (int)(acc & k);
            acc >>>= w;
        }
        return sum;
    }

    /**
     * Sums the w-bit digits of the message part for {@code 9 <= w <= 56}.
     * Each digit is read from the bytes covering bit positions
     * {@code [bitPos, bitPos + w)}. A final partial digit is read from any
     * remaining bytes.
     *
     * @param messPart working copy of the message digest
     * @return the sum of the message digits
     */
    private int sumDigitsLargeW(byte[] messPart)
    {
        int sum = 0;
        int bitPos = 0;
        int mdsizeBit = mdsize << 3;
        long acc;
        int first, rest, shift;

        // first a*w bits of hash where a*w <= 8*mdsize < (a+1)*w (main message part)
        while (bitPos <= (mdsizeBit - w))
        {
            first = bitPos >>> 3;
            rest = bitPos % 8;
            bitPos += w;
            int end = (bitPos + 7) >>> 3;
            acc = 0;
            shift = 0;
            for (int j = first; j < end; j++)
            {
                // int-typed shift, kept as-is (see class note)
                acc ^= (messPart[j] & 0xff) << (shift << 3);
                shift++;
            }
            // drop the bits already used by the previous digit
            acc >>>= rest;
            sum += (acc & k);
        }

        // rest of message part
        first = bitPos >>> 3;
        if (first < mdsize)
        {
            rest = bitPos % 8;
            acc = 0;
            shift = 0;
            for (int j = first; j < mdsize; j++)
            {
                acc ^= (messPart[j] & 0xff) << (shift << 3);
                shift++;
            }
            acc >>>= rest;
            sum += (acc & k);
        }
        return sum;
    }

    /**
     * Stores the checksum {@code (messagesize << w) - messageSum} in
     * {@code checksum}. It then adds the w-bit digits of its
     * {@code checksumsize}-bit representation to the running total.
     *
     * @param messageSum   sum of the message digits
     * @param checksumsize bit length of the checksum
     * @return {@code messageSum} plus the sum of the checksum digits
     */
    private int addChecksumDigits(int messageSum, int checksumsize)
    {
        this.checksum = (messagesize << w) - messageSum;
        int sum = messageSum;
        int checkPart = checksum;
        for (int b = 0; b < checksumsize; b += w)
        {
            sum += checkPart & k;
            checkPart >>>= w;
        }
        return sum;
    }

    /**
     * This method performs up to {@code steps} steps of the distributed
     * signature calculation.
     *
     * @return true if the signature is completely generated, else false
     */
    public boolean updateSign()
    {
        for (int s = 0; s < steps; s++)
        {
            if (counter < keysize)
            {
                // generate the next private key element or perform the next hash
                oneStep();
            }
            if (counter == keysize)
            {
                // finished
                return true;
            }
        }
        return false; // leaf not finished yet
    }

    /**
     * @return the signature bytes; complete only once {@link #updateSign()}
     *         has returned true
     */
    public byte[] getSig()
    {
        return sign;
    }

    /**
     * Performs one step of the signature calculation. A step either obtains
     * the next OTS private key element via {@code gmssRandom.nextSeed(seed)}
     * and determines how often it must be hashed, or performs one of those
     * hashes. When no hashes remain, the element is appended to the signature.
     */
    private void oneStep()
    {
        if (8 % w == 0)
        {
            oneStepByteAligned();
        }
        else if (w < 8)
        {
            oneStepSmallW();
        }
        else if (w < 57)
        {
            oneStepLargeW();
        }
    }

    /**
     * One signing step for the case {@code 8 % w == 0}.
     */
    private void oneStepByteAligned()
    {
        if (test == 0)
        {
            // get current OTS private key element
            this.privateKeyOTS = gmssRandom.nextSeed(seed);
            if (ii < mdsize)
            {
                // main message part
                test = hash[ii] & k;
                hash[ii] = (byte)(hash[ii] >>> w);
            }
            else
            {
                // checksum part
                test = checksum & k;
                checksum >>>= w;
            }
        }
        else if (test > 0)
        {
            hashPrivateKeyOnce();
            test--;
        }
        if (test == 0)
        {
            appendToSignature();
            if (counter % (8 / w) == 0)
            {
                // all digits of the current hash byte consumed
                ii++;
            }
        }
    }

    /**
     * One signing step for the case {@code w < 8} with {@code 8 % w != 0}.
     */
    private void oneStepSmallW()
    {
        if (test == 0)
        {
            if (counter % 8 == 0 && ii < mdsize)
            {
                // every 8th chunk: reload the accumulator with the next bytes.
                // Main message part uses w bytes (8 digits); the rest of the
                // message part uses mdsize % w bytes, once.
                big8 = 0;
                int byteCount = (counter < ((mdsize / w) << 3)) ? w : mdsize % w;
                for (int j = 0; j < byteCount; j++)
                {
                    // int-typed shift, kept as-is (see class note)
                    big8 ^= (hash[ii] & 0xff) << (j << 3);
                    ii++;
                }
            }
            if (counter == messagesize)
            {
                // checksum part (once)
                big8 = checksum;
            }
            test = (int)(big8 & k);
            // generate current OTS private key element
            this.privateKeyOTS = gmssRandom.nextSeed(seed);
        }
        else if (test > 0)
        {
            hashPrivateKeyOnce();
            test--;
        }
        if (test == 0)
        {
            appendToSignature();
            big8 >>>= w;
        }
    }

    /**
     * One signing step for the case {@code 9 <= w <= 56}.
     */
    private void oneStepLargeW()
    {
        if (test8 == 0)
        {
            int first, end, rest;
            big8 = 0;
            ii = 0;
            rest = r % 8;
            first = r >>> 3;
            if (first < mdsize)
            {
                // --- message part ---
                if (r <= ((mdsize << 3) - w))
                {
                    // full digit
                    r += w;
                    end = (r + 7) >>> 3;
                }
                else
                {
                    // rest of message part (once)
                    end = mdsize;
                    r += w;
                }
                // gather at least the next w bits of the message digest
                for (int i = first; i < end; i++)
                {
                    // int-typed shift, kept as-is (see class note)
                    big8 ^= (hash[i] & 0xff) << (ii << 3);
                    ii++;
                }
                // drop the bits already used by the previous digit
                big8 >>>= rest;
                test8 = (big8 & k);
            }
            else
            {
                // --- checksum part ---
                test8 = (checksum & k);
                checksum >>>= w;
            }
            // generate current OTS private key element
            this.privateKeyOTS = gmssRandom.nextSeed(seed);
        }
        else if (test8 > 0)
        {
            hashPrivateKeyOnce();
            test8--;
        }
        if (test8 == 0)
        {
            appendToSignature();
        }
    }

    /**
     * Replaces {@code privateKeyOTS} with its digest.
     */
    private void hashPrivateKeyOnce()
    {
        messDigestOTS.update(privateKeyOTS, 0, privateKeyOTS.length);
        privateKeyOTS = new byte[messDigestOTS.getDigestSize()];
        messDigestOTS.doFinal(privateKeyOTS, 0);
    }

    /**
     * Copies {@code mdsize} bytes of {@code privateKeyOTS} into signature
     * chunk number {@code counter} and advances {@code counter}.
     */
    private void appendToSignature()
    {
        System.arraycopy(privateKeyOTS, 0, sign, counter * mdsize, mdsize);
        counter++;
    }

    /**
     * This method returns the smallest integer {@code log >= 1} with
     * {@code 2^log >= intValue}. For {@code intValue >= 2} this is
     * {@code ceil(log2(intValue))}; for {@code intValue <= 2} it is 1.
     *
     * @param intValue an integer
     * @return the least integer {@code log >= 1} such that
     *         {@code 2^log >= intValue}
     */
    public int getLog(int intValue)
    {
        int log = 1;
        // long accumulator: an int would overflow for intValue > 2^30 and loop forever
        long i = 2;
        while (i < intValue)
        {
            i <<= 1;
            log++;
        }
        return log;
    }

    /**
     * This method returns the status byte array. The arrays are returned by
     * reference, not copied.
     *
     * @return statBytes: private key, seed, hash, signature and the encoded longs
     */
    public byte[][] getStatByte()
    {
        byte[][] statByte = new byte[5][];
        statByte[0] = privateKeyOTS;
        statByte[1] = seed;
        statByte[2] = hash;
        statByte[3] = sign;
        statByte[4] = this.getStatLong();
        return statByte;
    }

    /**
     * This method returns the status int array.
     *
     * @return statInt: counter, test, ii, r, steps, keysize, height, w, checksum
     */
    public int[] getStatInt()
    {
        int[] statInt = new int[9];
        statInt[0] = counter;
        statInt[1] = test;
        statInt[2] = ii;
        statInt[3] = r;
        statInt[4] = steps;
        statInt[5] = keysize;
        statInt[6] = height;
        statInt[7] = w;
        statInt[8] = checksum;
        return statInt;
    }

    /**
     * Converts the long parameters into a byte array to store it in the
     * statByte array: {@code test8} little-endian in bytes 0..7, then
     * {@code big8} little-endian in bytes 8..15.
     *
     * @return the 16-byte encoding
     */
    public byte[] getStatLong()
    {
        byte[] bytes = new byte[16];
        encodeLong(test8, bytes, 0);
        encodeLong(big8, bytes, 8);
        return bytes;
    }

    /**
     * Writes {@code value} little-endian into {@code buf[off..off+7]}.
     */
    private static void encodeLong(long value, byte[] buf, int off)
    {
        for (int i = 0; i < 8; i++)
        {
            buf[off + i] = (byte)(value >> (i << 3));
        }
    }

    /**
     * Reads a little-endian long from {@code buf[off..off+7]}.
     */
    private static long decodeLong(byte[] buf, int off)
    {
        long value = 0;
        for (int i = 7; i >= 0; i--)
        {
            value = (value << 8) | (buf[off + i] & 0xff);
        }
        return value;
    }

    /**
     * returns a string representation of the instance
     *
     * @return a string representation of the instance
     */
    public String toString()
    {
        StringBuilder out = new StringBuilder();
        out.append(this.big8).append(' ');

        int[] statInt = this.getStatInt();
        for (int i = 0; i < statInt.length; i++)
        {
            out.append(statInt[i]).append(' ');
        }

        byte[][] statByte = this.getStatByte();
        for (int i = 0; i < statByte.length; i++)
        {
            // after GMSSRootSig(Digest, int, int) and before initSign() these are still null
            byte[] part = statByte[i];
            out.append(part == null ? "null" : new String(Hex.encode(part))).append(' ');
        }
        return out.toString();
    }
}