X-Git-Url: https://asedeno.scripts.mit.edu/gitweb/?a=blobdiff_plain;f=sshbn.c;h=6768204bc4cd2ec930799bcc3d250407928d63a2;hb=510f49e405e71ba5c97875e7a019364e1ef5fac9;hp=ba3d5b635456091d89415bb5fb1ec48ceb3218e0;hpb=42801b7e9ee410379563b46fde0656596f807451;p=PuTTY.git diff --git a/sshbn.c b/sshbn.c index ba3d5b63..6768204b 100644 --- a/sshbn.c +++ b/sshbn.c @@ -6,66 +6,12 @@ #include #include #include +#include +#include #include "misc.h" -/* - * Usage notes: - * * Do not call the DIVMOD_WORD macro with expressions such as array - * subscripts, as some implementations object to this (see below). - * * Note that none of the division methods below will cope if the - * quotient won't fit into BIGNUM_INT_BITS. Callers should be careful - * to avoid this case. - * If this condition occurs, in the case of the x86 DIV instruction, - * an overflow exception will occur, which (according to a correspondent) - * will manifest on Windows as something like - * 0xC0000095: Integer overflow - * The C variant won't give the right answer, either. - */ - -#if defined __GNUC__ && defined __i386__ -typedef unsigned long BignumInt; -typedef unsigned long long BignumDblInt; -#define BIGNUM_INT_MASK 0xFFFFFFFFUL -#define BIGNUM_TOP_BIT 0x80000000UL -#define BIGNUM_INT_BITS 32 -#define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) \ - __asm__("div %2" : \ - "=d" (r), "=a" (q) : \ - "r" (w), "d" (hi), "a" (lo)) -#elif defined _MSC_VER && defined _M_IX86 -typedef unsigned __int32 BignumInt; -typedef unsigned __int64 BignumDblInt; -#define BIGNUM_INT_MASK 0xFFFFFFFFUL -#define BIGNUM_TOP_BIT 0x80000000UL -#define BIGNUM_INT_BITS 32 -#define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -/* Note: MASM interprets array subscripts in the macro arguments as - * assembler syntax, which gives the wrong answer. Don't supply them. - * */ -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - __asm mov edx, hi \ - __asm mov eax, lo \ - __asm div w \ - __asm mov r, edx \ - __asm mov q, eax \ -} while(0) -#else -typedef unsigned short BignumInt; -typedef unsigned long BignumDblInt; -#define BIGNUM_INT_MASK 0xFFFFU -#define BIGNUM_TOP_BIT 0x8000U -#define BIGNUM_INT_BITS 16 -#define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \ - q = n / w; \ - r = n % w; \ -} while (0) -#endif - -#define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8) +#include "sshbn.h" #define BIGNUM_INTERNAL typedef BignumInt *Bignum; @@ -74,6 +20,7 @@ typedef BignumInt *Bignum; BignumInt bnZero[1] = { 0 }; BignumInt bnOne[2] = { 1, 1 }; +BignumInt bnTen[2] = { 1, 10 }; /* * The Bignum format is an array of `BignumInt'. The first @@ -89,13 +36,15 @@ BignumInt bnOne[2] = { 1, 1 }; * nonzero. */ -Bignum Zero = bnZero, One = bnOne; +Bignum Zero = bnZero, One = bnOne, Ten = bnTen; static Bignum newbn(int length) { - Bignum b = snewn(length + 1, BignumInt); - if (!b) - abort(); /* FIXME */ + Bignum b; + + assert(length >= 0 && length < INT_MAX / BIGNUM_INT_BITS); + + b = snewn(length + 1, BignumInt); memset(b, 0, (length + 1) * sizeof(*b)); b[0] = length; return b; @@ -121,58 +70,568 @@ void freebn(Bignum b) /* * Burn the evidence, just in case. */ - memset(b, 0, sizeof(b[0]) * (b[0] + 1)); + smemclr(b, sizeof(b[0]) * (b[0] + 1)); sfree(b); } Bignum bn_power_2(int n) { - Bignum ret = newbn(n / BIGNUM_INT_BITS + 1); + Bignum ret; + + assert(n >= 0); + + ret = newbn(n / BIGNUM_INT_BITS + 1); bignum_set_bit(ret, n, 1); return ret; } +/* + * Internal addition. Sets c = a - b, where 'a', 'b' and 'c' are all + * big-endian arrays of 'len' BignumInts. Returns the carry off the + * top. + */ +static BignumCarry internal_add(const BignumInt *a, const BignumInt *b, + BignumInt *c, int len) +{ + int i; + BignumCarry carry = 0; + + for (i = len-1; i >= 0; i--) + BignumADC(c[i], carry, a[i], b[i], carry); + + return (BignumInt)carry; +} + +/* + * Internal subtraction. Sets c = a - b, where 'a', 'b' and 'c' are + * all big-endian arrays of 'len' BignumInts. Any borrow from the top + * is ignored. + */ +static void internal_sub(const BignumInt *a, const BignumInt *b, + BignumInt *c, int len) +{ + int i; + BignumCarry carry = 1; + + for (i = len-1; i >= 0; i--) + BignumADC(c[i], carry, a[i], ~b[i], carry); +} + /* * Compute c = a * b. * Input is in the first len words of a and b. * Result is returned in the first 2*len words of c. + * + * 'scratch' must point to an array of BignumInt of size at least + * mul_compute_scratch(len). (This covers the needs of internal_mul + * and all its recursive calls to itself.) */ -static void internal_mul(BignumInt *a, BignumInt *b, - BignumInt *c, int len) +#define KARATSUBA_THRESHOLD 50 +static int mul_compute_scratch(int len) { - int i, j; - BignumDblInt t; - - for (j = 0; j < 2 * len; j++) - c[j] = 0; - - for (i = len - 1; i >= 0; i--) { - t = 0; - for (j = len - 1; j >= 0; j--) { - t += MUL_WORD(a[i], (BignumDblInt) b[j]); - t += (BignumDblInt) c[i + j + 1]; - c[i + j + 1] = (BignumInt) t; - t = t >> BIGNUM_INT_BITS; - } - c[i] = (BignumInt) t; + int ret = 0; + while (len > KARATSUBA_THRESHOLD) { + int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */ + int midlen = botlen + 1; + ret += 4*midlen; + len = midlen; } + return ret; +} +static void internal_mul(const BignumInt *a, const BignumInt *b, + BignumInt *c, int len, BignumInt *scratch) +{ + if (len > KARATSUBA_THRESHOLD) { + int i; + + /* + * Karatsuba divide-and-conquer algorithm. Cut each input in + * half, so that it's expressed as two big 'digits' in a giant + * base D: + * + * a = a_1 D + a_0 + * b = b_1 D + b_0 + * + * Then the product is of course + * + * ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0 + * + * and we compute the three coefficients by recursively + * calling ourself to do half-length multiplications. + * + * The clever bit that makes this worth doing is that we only + * need _one_ half-length multiplication for the central + * coefficient rather than the two that it obviouly looks + * like, because we can use a single multiplication to compute + * + * (a_1 + a_0) (b_1 + b_0) = a_1 b_1 + a_1 b_0 + a_0 b_1 + a_0 b_0 + * + * and then we subtract the other two coefficients (a_1 b_1 + * and a_0 b_0) which we were computing anyway. + * + * Hence we get to multiply two numbers of length N in about + * three times as much work as it takes to multiply numbers of + * length N/2, which is obviously better than the four times + * as much work it would take if we just did a long + * conventional multiply. + */ + + int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */ + int midlen = botlen + 1; + BignumCarry carry; +#ifdef KARA_DEBUG + int i; +#endif + + /* + * The coefficients a_1 b_1 and a_0 b_0 just avoid overlapping + * in the output array, so we can compute them immediately in + * place. + */ + +#ifdef KARA_DEBUG + printf("a1,a0 = 0x"); + for (i = 0; i < len; i++) { + if (i == toplen) printf(", 0x"); + printf("%0*x", BIGNUM_INT_BITS/4, a[i]); + } + printf("\n"); + printf("b1,b0 = 0x"); + for (i = 0; i < len; i++) { + if (i == toplen) printf(", 0x"); + printf("%0*x", BIGNUM_INT_BITS/4, b[i]); + } + printf("\n"); +#endif + + /* a_1 b_1 */ + internal_mul(a, b, c, toplen, scratch); +#ifdef KARA_DEBUG + printf("a1b1 = 0x"); + for (i = 0; i < 2*toplen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, c[i]); + } + printf("\n"); +#endif + + /* a_0 b_0 */ + internal_mul(a + toplen, b + toplen, c + 2*toplen, botlen, scratch); +#ifdef KARA_DEBUG + printf("a0b0 = 0x"); + for (i = 0; i < 2*botlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, c[2*toplen+i]); + } + printf("\n"); +#endif + + /* Zero padding. midlen exceeds toplen by at most 2, so just + * zero the first two words of each input and the rest will be + * copied over. */ + scratch[0] = scratch[1] = scratch[midlen] = scratch[midlen+1] = 0; + + for (i = 0; i < toplen; i++) { + scratch[midlen - toplen + i] = a[i]; /* a_1 */ + scratch[2*midlen - toplen + i] = b[i]; /* b_1 */ + } + + /* compute a_1 + a_0 */ + scratch[0] = internal_add(scratch+1, a+toplen, scratch+1, botlen); +#ifdef KARA_DEBUG + printf("a1plusa0 = 0x"); + for (i = 0; i < midlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]); + } + printf("\n"); +#endif + /* compute b_1 + b_0 */ + scratch[midlen] = internal_add(scratch+midlen+1, b+toplen, + scratch+midlen+1, botlen); +#ifdef KARA_DEBUG + printf("b1plusb0 = 0x"); + for (i = 0; i < midlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, scratch[midlen+i]); + } + printf("\n"); +#endif + + /* + * Now we can do the third multiplication. + */ + internal_mul(scratch, scratch + midlen, scratch + 2*midlen, midlen, + scratch + 4*midlen); +#ifdef KARA_DEBUG + printf("a1plusa0timesb1plusb0 = 0x"); + for (i = 0; i < 2*midlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]); + } + printf("\n"); +#endif + + /* + * Now we can reuse the first half of 'scratch' to compute the + * sum of the outer two coefficients, to subtract from that + * product to obtain the middle one. + */ + scratch[0] = scratch[1] = scratch[2] = scratch[3] = 0; + for (i = 0; i < 2*toplen; i++) + scratch[2*midlen - 2*toplen + i] = c[i]; + scratch[1] = internal_add(scratch+2, c + 2*toplen, + scratch+2, 2*botlen); +#ifdef KARA_DEBUG + printf("a1b1plusa0b0 = 0x"); + for (i = 0; i < 2*midlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]); + } + printf("\n"); +#endif + + internal_sub(scratch + 2*midlen, scratch, + scratch + 2*midlen, 2*midlen); +#ifdef KARA_DEBUG + printf("a1b0plusa0b1 = 0x"); + for (i = 0; i < 2*midlen; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]); + } + printf("\n"); +#endif + + /* + * And now all we need to do is to add that middle coefficient + * back into the output. We may have to propagate a carry + * further up the output, but we can be sure it won't + * propagate right the way off the top. + */ + carry = internal_add(c + 2*len - botlen - 2*midlen, + scratch + 2*midlen, + c + 2*len - botlen - 2*midlen, 2*midlen); + i = 2*len - botlen - 2*midlen - 1; + while (carry) { + assert(i >= 0); + BignumADC(c[i], carry, c[i], 0, carry); + i--; + } +#ifdef KARA_DEBUG + printf("ab = 0x"); + for (i = 0; i < 2*len; i++) { + printf("%0*x", BIGNUM_INT_BITS/4, c[i]); + } + printf("\n"); +#endif + + } else { + int i; + BignumInt carry; + const BignumInt *ap, *bp; + BignumInt *cp, *cps; + + /* + * Multiply in the ordinary O(N^2) way. + */ + + for (i = 0; i < 2 * len; i++) + c[i] = 0; + + for (cps = c + 2*len, ap = a + len; ap-- > a; cps--) { + carry = 0; + for (cp = cps, bp = b + len; cp--, bp-- > b ;) + BignumMULADD2(carry, *cp, *ap, *bp, *cp, carry); + *cp = carry; + } + } +} + +/* + * Variant form of internal_mul used for the initial step of + * Montgomery reduction. Only bothers outputting 'len' words + * (everything above that is thrown away). + */ +static void internal_mul_low(const BignumInt *a, const BignumInt *b, + BignumInt *c, int len, BignumInt *scratch) +{ + if (len > KARATSUBA_THRESHOLD) { + int i; + + /* + * Karatsuba-aware version of internal_mul_low. As before, we + * express each input value as a shifted combination of two + * halves: + * + * a = a_1 D + a_0 + * b = b_1 D + b_0 + * + * Then the full product is, as before, + * + * ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0 + * + * Provided we choose D on the large side (so that a_0 and b_0 + * are _at least_ as long as a_1 and b_1), we don't need the + * topmost term at all, and we only need half of the middle + * term. So there's no point in doing the proper Karatsuba + * optimisation which computes the middle term using the top + * one, because we'd take as long computing the top one as + * just computing the middle one directly. + * + * So instead, we do a much more obvious thing: we call the + * fully optimised internal_mul to compute a_0 b_0, and we + * recursively call ourself to compute the _bottom halves_ of + * a_1 b_0 and a_0 b_1, each of which we add into the result + * in the obvious way. + * + * In other words, there's no actual Karatsuba _optimisation_ + * in this function; the only benefit in doing it this way is + * that we call internal_mul proper for a large part of the + * work, and _that_ can optimise its operation. + */ + + int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */ + + /* + * Scratch space for the various bits and pieces we're going + * to be adding together: we need botlen*2 words for a_0 b_0 + * (though we may end up throwing away its topmost word), and + * toplen words for each of a_1 b_0 and a_0 b_1. That adds up + * to exactly 2*len. + */ + + /* a_0 b_0 */ + internal_mul(a + toplen, b + toplen, scratch + 2*toplen, botlen, + scratch + 2*len); + + /* a_1 b_0 */ + internal_mul_low(a, b + len - toplen, scratch + toplen, toplen, + scratch + 2*len); + + /* a_0 b_1 */ + internal_mul_low(a + len - toplen, b, scratch, toplen, + scratch + 2*len); + + /* Copy the bottom half of the big coefficient into place */ + for (i = 0; i < botlen; i++) + c[toplen + i] = scratch[2*toplen + botlen + i]; + + /* Add the two small coefficients, throwing away the returned carry */ + internal_add(scratch, scratch + toplen, scratch, toplen); + + /* And add that to the large coefficient, leaving the result in c. */ + internal_add(scratch, scratch + 2*toplen + botlen - toplen, + c, toplen); + + } else { + int i; + BignumInt carry; + const BignumInt *ap, *bp; + BignumInt *cp, *cps; + + /* + * Multiply in the ordinary O(N^2) way. + */ + + for (i = 0; i < len; i++) + c[i] = 0; + + for (cps = c + len, ap = a + len; ap-- > a; cps--) { + carry = 0; + for (cp = cps, bp = b + len; bp--, cp-- > c ;) + BignumMULADD2(carry, *cp, *ap, *bp, *cp, carry); + } + } +} + +/* + * Montgomery reduction. Expects x to be a big-endian array of 2*len + * BignumInts whose value satisfies 0 <= x < rn (where r = 2^(len * + * BIGNUM_INT_BITS) is the Montgomery base). Returns in the same array + * a value x' which is congruent to xr^{-1} mod n, and satisfies 0 <= + * x' < n. + * + * 'n' and 'mninv' should be big-endian arrays of 'len' BignumInts + * each, containing respectively n and the multiplicative inverse of + * -n mod r. + * + * 'tmp' is an array of BignumInt used as scratch space, of length at + * least 3*len + mul_compute_scratch(len). + */ +static void monty_reduce(BignumInt *x, const BignumInt *n, + const BignumInt *mninv, BignumInt *tmp, int len) +{ + int i; + BignumInt carry; + + /* + * Multiply x by (-n)^{-1} mod r. This gives us a value m such + * that mn is congruent to -x mod r. Hence, mn+x is an exact + * multiple of r, and is also (obviously) congruent to x mod n. + */ + internal_mul_low(x + len, mninv, tmp, len, tmp + 3*len); + + /* + * Compute t = (mn+x)/r in ordinary, non-modular, integer + * arithmetic. By construction this is exact, and is congruent mod + * n to x * r^{-1}, i.e. the answer we want. + * + * The following multiply leaves that answer in the _most_ + * significant half of the 'x' array, so then we must shift it + * down. + */ + internal_mul(tmp, n, tmp+len, len, tmp + 3*len); + carry = internal_add(x, tmp+len, x, 2*len); + for (i = 0; i < len; i++) + x[len + i] = x[i], x[i] = 0; + + /* + * Reduce t mod n. This doesn't require a full-on division by n, + * but merely a test and single optional subtraction, since we can + * show that 0 <= t < 2n. + * + * Proof: + * + we computed m mod r, so 0 <= m < r. + * + so 0 <= mn < rn, obviously + * + hence we only need 0 <= x < rn to guarantee that 0 <= mn+x < 2rn + * + yielding 0 <= (mn+x)/r < 2n as required. + */ + if (!carry) { + for (i = 0; i < len; i++) + if (x[len + i] != n[i]) + break; + } + if (carry || i >= len || x[len + i] > n[i]) + internal_sub(x+len, n, x+len, len); } static void internal_add_shifted(BignumInt *number, - unsigned n, int shift) + BignumInt n, int shift) { int word = 1 + (shift / BIGNUM_INT_BITS); int bshift = shift % BIGNUM_INT_BITS; - BignumDblInt addend; + BignumInt addendh, addendl; + BignumCarry carry; + + addendl = n << bshift; + addendh = (bshift == 0 ? 0 : n >> (BIGNUM_INT_BITS - bshift)); + + assert(word <= number[0]); + BignumADC(number[word], carry, number[word], addendl, 0); + word++; + if (!addendh && !carry) + return; + assert(word <= number[0]); + BignumADC(number[word], carry, number[word], addendh, carry); + word++; + while (carry) { + assert(word <= number[0]); + BignumADC(number[word], carry, number[word], 0, carry); + word++; + } +} - addend = (BignumDblInt)n << bshift; +static int bn_clz(BignumInt x) +{ + /* + * Count the leading zero bits in x. Equivalently, how far left + * would we need to shift x to make its top bit set? + * + * Precondition: x != 0. + */ - while (addend) { - addend += number[word]; - number[word] = (BignumInt) addend & BIGNUM_INT_MASK; - addend >>= BIGNUM_INT_BITS; - word++; + /* FIXME: would be nice to put in some compiler intrinsics under + * ifdef here */ + int i, ret = 0; + for (i = BIGNUM_INT_BITS / 2; i != 0; i >>= 1) { + if ((x >> (BIGNUM_INT_BITS-i)) == 0) { + x <<= i; + ret += i; + } } + return ret; +} + +static BignumInt reciprocal_word(BignumInt d) +{ + BignumInt dshort, recip, prodh, prodl; + int corrections; + + /* + * Input: a BignumInt value d, with its top bit set. + */ + assert(d >> (BIGNUM_INT_BITS-1) == 1); + + /* + * Output: a value, shifted to fill a BignumInt, which is strictly + * less than 1/(d+1), i.e. is an *under*-estimate (but by as + * little as possible within the constraints) of the reciprocal of + * any number whose first BIGNUM_INT_BITS bits match d. + * + * Ideally we'd like to _totally_ fill BignumInt, i.e. always + * return a value with the top bit set. Unfortunately we can't + * quite guarantee that for all inputs and also return a fixed + * exponent. So instead we take our reciprocal to be + * 2^(BIGNUM_INT_BITS*2-1) / d, so that it has the top bit clear + * only in the exceptional case where d takes exactly the maximum + * value BIGNUM_INT_MASK; in that case, the top bit is clear and + * the next bit down is set. + */ + + /* + * Start by computing a half-length version of the answer, by + * straightforward division within a BignumInt. + */ + dshort = (d >> (BIGNUM_INT_BITS/2)) + 1; + recip = (BIGNUM_TOP_BIT + dshort - 1) / dshort; + recip <<= BIGNUM_INT_BITS - BIGNUM_INT_BITS/2; + + /* + * Newton-Raphson iteration to improve that starting reciprocal + * estimate: take f(x) = d - 1/x, and then the N-R formula gives + * x_new = x - f(x)/f'(x) = x - (d-1/x)/(1/x^2) = x(2-d*x). Or, + * taking our fixed-point representation into account, take f(x) + * to be d - K/x (where K = 2^(BIGNUM_INT_BITS*2-1) as discussed + * above) and then we get (2K - d*x) * x/K. + * + * Newton-Raphson doubles the number of correct bits at every + * iteration, and the initial division above already gave us half + * the output word, so it's only worth doing one iteration. + */ + BignumMULADD(prodh, prodl, recip, d, recip); + prodl = ~prodl; + prodh = ~prodh; + { + BignumCarry c; + BignumADC(prodl, c, prodl, 1, 0); + prodh += c; + } + BignumMUL(prodh, prodl, prodh, recip); + recip = (prodh << 1) | (prodl >> (BIGNUM_INT_BITS-1)); + + /* + * Now make sure we have the best possible reciprocal estimate, + * before we return it. We might have been off by a handful either + * way - not enough to bother with any better-thought-out kind of + * correction loop. + */ + BignumMULADD(prodh, prodl, recip, d, recip); + corrections = 0; + if (prodh >= BIGNUM_TOP_BIT) { + do { + BignumCarry c = 1; + BignumADC(prodl, c, prodl, ~d, c); prodh += BIGNUM_INT_MASK + c; + recip--; + corrections++; + } while (prodh >= ((BignumInt)1 << (BIGNUM_INT_BITS-1))); + } else { + while (1) { + BignumInt newprodh, newprodl; + BignumCarry c = 0; + BignumADC(newprodl, c, prodl, d, c); newprodh = prodh + c; + if (newprodh >= BIGNUM_TOP_BIT) + break; + prodh = newprodh; + prodl = newprodl; + recip++; + corrections++; + } + } + + return recip; } /* @@ -180,112 +639,311 @@ static void internal_add_shifted(BignumInt *number, * Input in first alen words of a and first mlen words of m. * Output in first alen words of a * (of which first alen-mlen words will be zero). - * The MSW of m MUST have its high bit set. * Quotient is accumulated in the `quotient' array, which is a Bignum - * rather than the internal bigendian format. Quotient parts are shifted - * left by `qshift' before adding into quot. + * rather than the internal bigendian format. + * + * 'recip' must be the result of calling reciprocal_word() on the top + * BIGNUM_INT_BITS of the modulus (denoted m0 in comments below), with + * the topmost set bit normalised to the MSB of the input to + * reciprocal_word. 'rshift' is how far left the top nonzero word of + * the modulus had to be shifted to set that top bit. */ static void internal_mod(BignumInt *a, int alen, BignumInt *m, int mlen, - BignumInt *quot, int qshift) + BignumInt *quot, BignumInt recip, int rshift) { - BignumInt m0, m1; - unsigned int h; int i, k; - m0 = m[0]; - if (mlen > 1) - m1 = m[1]; - else - m1 = 0; +#ifdef DIVISION_DEBUG + { + int d; + printf("start division, m=0x"); + for (d = 0; d < mlen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)m[d]); + printf(", recip=%#0*llx, rshift=%d\n", + BIGNUM_INT_BITS/4, (unsigned long long)recip, rshift); + } +#endif - for (i = 0; i <= alen - mlen; i++) { - BignumDblInt t; - unsigned int q, r, c, ai1; + /* + * Repeatedly use that reciprocal estimate to get a decent number + * of quotient bits, and subtract off the resulting multiple of m. + * + * Normally we expect to terminate this loop by means of finding + * out q=0 part way through, but one way in which we might not get + * that far in the first place is if the input a is actually zero, + * in which case we'll discard zero words from the front of a + * until we reach the termination condition in the for statement + * here. + */ + for (i = 0; i <= alen - mlen ;) { + BignumInt product; + BignumInt aword, q; + int shift, full_bitoffset, bitoffset, wordoffset; + +#ifdef DIVISION_DEBUG + { + int d; + printf("main loop, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + printf("\n"); + } +#endif - if (i == 0) { - h = 0; - } else { - h = a[i - 1]; - a[i - 1] = 0; - } + if (a[i] == 0) { +#ifdef DIVISION_DEBUG + printf("zero word at i=%d\n", i); +#endif + i++; + continue; + } + + aword = a[i]; + shift = bn_clz(aword); + aword <<= shift; + if (shift > 0 && i+1 < alen) + aword |= a[i+1] >> (BIGNUM_INT_BITS - shift); + + { + BignumInt unused; + BignumMUL(q, unused, recip, aword); + (void)unused; + } + +#ifdef DIVISION_DEBUG + printf("i=%d, aword=%#0*llx, shift=%d, q=%#0*llx\n", + i, BIGNUM_INT_BITS/4, (unsigned long long)aword, + shift, BIGNUM_INT_BITS/4, (unsigned long long)q); +#endif - if (i == alen - 1) - ai1 = 0; - else - ai1 = a[i + 1]; - - /* Find q = h:a[i] / m0 */ - if (h >= m0) { - /* - * Special case. - * - * To illustrate it, suppose a BignumInt is 8 bits, and - * we are dividing (say) A1:23:45:67 by A1:B2:C3. Then - * our initial division will be 0xA123 / 0xA1, which - * will give a quotient of 0x100 and a divide overflow. - * However, the invariants in this division algorithm - * are not violated, since the full number A1:23:... is - * _less_ than the quotient prefix A1:B2:... and so the - * following correction loop would have sorted it out. - * - * In this situation we set q to be the largest - * quotient we _can_ stomach (0xFF, of course). - */ - q = BIGNUM_INT_MASK; - } else { - /* Macro doesn't want an array subscript expression passed - * into it (see definition), so use a temporary. */ - BignumInt tmplo = a[i]; - DIVMOD_WORD(q, r, h, tmplo, m0); - - /* Refine our estimate of q by looking at - h:a[i]:a[i+1] / m0:m1 */ - t = MUL_WORD(m1, q); - if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) { - q--; - t -= m1; - r = (r + m0) & BIGNUM_INT_MASK; /* overflow? */ - if (r >= (BignumDblInt) m0 && - t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) q--; - } - } + /* + * Work out the right bit and word offsets to use when + * subtracting q*m from a. + * + * aword was taken from a[i], which means its LSB was at bit + * position (alen-1-i) * BIGNUM_INT_BITS. But then we shifted + * it left by 'shift', so now the low bit of aword corresponds + * to bit position (alen-1-i) * BIGNUM_INT_BITS - shift, i.e. + * aword is approximately equal to a / 2^(that). + * + * m0 comes from the top word of mod, so its LSB is at bit + * position (mlen-1) * BIGNUM_INT_BITS - rshift, i.e. it can + * be considered to be m / 2^(that power). 'recip' is the + * reciprocal of m0, times 2^(BIGNUM_INT_BITS*2-1), i.e. it's + * about 2^((mlen+1) * BIGNUM_INT_BITS - rshift - 1) / m. + * + * Hence, recip * aword is approximately equal to the product + * of those, which simplifies to + * + * a/m * 2^((mlen+2+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1) + * + * But we've also shifted recip*aword down by BIGNUM_INT_BITS + * to form q, so we have + * + * q ~= a/m * 2^((mlen+1+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1) + * + * and hence, when we now compute q*m, it will be about + * a*2^(all that lot), i.e. the negation of that expression is + * how far left we have to shift the product q*m to make it + * approximately equal to a. + */ + full_bitoffset = -((mlen+1+i-alen)*BIGNUM_INT_BITS + shift-rshift-1); +#ifdef DIVISION_DEBUG + printf("full_bitoffset=%d\n", full_bitoffset); +#endif - /* Subtract q * m from a[i...] */ - c = 0; - for (k = mlen - 1; k >= 0; k--) { - t = MUL_WORD(q, m[k]); - t += c; - c = (unsigned)(t >> BIGNUM_INT_BITS); - if ((BignumInt) t > a[i + k]) - c++; - a[i + k] -= (BignumInt) t; - } + if (full_bitoffset < 0) { + /* + * If we find ourselves needing to shift q*m _right_, that + * means we've reached the bottom of the quotient. Clip q + * so that its right shift becomes zero, and if that means + * q becomes _actually_ zero, this loop is done. + */ + if (full_bitoffset <= -BIGNUM_INT_BITS) + break; + q >>= -full_bitoffset; + full_bitoffset = 0; + if (!q) + break; +#ifdef DIVISION_DEBUG + printf("now full_bitoffset=%d, q=%#0*llx\n", + full_bitoffset, BIGNUM_INT_BITS/4, (unsigned long long)q); +#endif + } - /* Add back m in case of borrow */ - if (c != h) { - t = 0; - for (k = mlen - 1; k >= 0; k--) { - t += m[k]; - t += a[i + k]; - a[i + k] = (BignumInt) t; - t = t >> BIGNUM_INT_BITS; - } - q--; - } - if (quot) - internal_add_shifted(quot, q, qshift + BIGNUM_INT_BITS * (alen - mlen - i)); + wordoffset = full_bitoffset / BIGNUM_INT_BITS; + bitoffset = full_bitoffset % BIGNUM_INT_BITS; +#ifdef DIVISION_DEBUG + printf("wordoffset=%d, bitoffset=%d\n", wordoffset, bitoffset); +#endif + + /* wordoffset as computed above is the offset between the LSWs + * of m and a. But in fact m and a are stored MSW-first, so we + * need to adjust it to be the offset between the actual array + * indices, and flip the sign too. */ + wordoffset = alen - mlen - wordoffset; + + if (bitoffset == 0) { + BignumCarry c = 1; + BignumInt prev_hi_word = 0; + for (k = mlen - 1; wordoffset+k >= i; k--) { + BignumInt mword = k<0 ? 0 : m[k]; + BignumMULADD(prev_hi_word, product, q, mword, prev_hi_word); +#ifdef DIVISION_DEBUG + printf(" aligned sub: product word for m[%d] = %#0*llx\n", + k, BIGNUM_INT_BITS/4, + (unsigned long long)product); +#endif +#ifdef DIVISION_DEBUG + printf(" aligned sub: subtrahend for a[%d] = %#0*llx\n", + wordoffset+k, BIGNUM_INT_BITS/4, + (unsigned long long)product); +#endif + BignumADC(a[wordoffset+k], c, a[wordoffset+k], ~product, c); + } + } else { + BignumInt add_word = 0; + BignumInt c = 1; + BignumInt prev_hi_word = 0; + for (k = mlen - 1; wordoffset+k >= i; k--) { + BignumInt mword = k<0 ? 0 : m[k]; + BignumMULADD(prev_hi_word, product, q, mword, prev_hi_word); +#ifdef DIVISION_DEBUG + printf(" unaligned sub: product word for m[%d] = %#0*llx\n", + k, BIGNUM_INT_BITS/4, + (unsigned long long)product); +#endif + + add_word |= product << bitoffset; + +#ifdef DIVISION_DEBUG + printf(" unaligned sub: subtrahend for a[%d] = %#0*llx\n", + wordoffset+k, + BIGNUM_INT_BITS/4, (unsigned long long)add_word); +#endif + BignumADC(a[wordoffset+k], c, a[wordoffset+k], ~add_word, c); + + add_word = product >> (BIGNUM_INT_BITS - bitoffset); + } + } + + if (quot) { +#ifdef DIVISION_DEBUG + printf("adding quotient word %#0*llx << %d\n", + BIGNUM_INT_BITS/4, (unsigned long long)q, full_bitoffset); +#endif + internal_add_shifted(quot, q, full_bitoffset); +#ifdef DIVISION_DEBUG + { + int d; + printf("now quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + printf("\n"); + } +#endif + } + } + +#ifdef DIVISION_DEBUG + { + int d; + printf("end main loop, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + if (quot) { + printf(", quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + } + printf("\n"); + } +#endif + + /* + * The above loop should terminate with the remaining value in a + * being strictly less than 2*m (if a >= 2*m then we should always + * have managed to get a nonzero q word), but we can't guarantee + * that it will be strictly less than m: consider a case where the + * remainder is 1, and another where the remainder is m-1. By the + * time a contains a value that's _about m_, you clearly can't + * distinguish those cases by looking at only the top word of a - + * you have to go all the way down to the bottom before you find + * out whether it's just less or just more than m. + * + * Hence, we now do a final fixup in which we subtract one last + * copy of m, or don't, accordingly. We should never have to + * subtract more than one copy of m here. + */ + for (i = 0; i < alen; i++) { + /* Compare a with m, word by word, from the MSW down. As soon + * as we encounter a difference, we know whether we need the + * fixup. */ + int mindex = mlen-alen+i; + BignumInt mword = mindex < 0 ? 0 : m[mindex]; + if (a[i] < mword) { +#ifdef DIVISION_DEBUG + printf("final fixup not needed, a < m\n"); +#endif + return; + } else if (a[i] > mword) { +#ifdef DIVISION_DEBUG + printf("final fixup is needed, a > m\n"); +#endif + break; + } + /* If neither of those cases happened, the words are the same, + * so keep going and look at the next one. */ + } +#ifdef DIVISION_DEBUG + if (i == mlen) /* if we printed neither of the above diagnostics */ + printf("final fixup is needed, a == m\n"); +#endif + + /* + * If we got here without returning, then a >= m, so we must + * subtract m, and increment the quotient. + */ + { + BignumCarry c = 1; + for (i = alen - 1; i >= 0; i--) { + int mindex = mlen-alen+i; + BignumInt mword = mindex < 0 ? 0 : m[mindex]; + BignumADC(a[i], c, a[i], ~mword, c); + } + } + if (quot) + internal_add_shifted(quot, 1, 0); + +#ifdef DIVISION_DEBUG + { + int d; + printf("after final fixup, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + if (quot) { + printf(", quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + } + printf("\n"); } +#endif } /* - * Compute (base ^ exp) % mod. + * Compute (base ^ exp) % mod, the pedestrian way. */ -Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) +Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod) { - BignumInt *a, *b, *n, *m; - int mshift; - int mlen, i, j; + BignumInt *a, *b, *n, *m, *scratch; + BignumInt recip; + int rshift; + int mlen, scratchlen, i, j; Bignum base, result; /* @@ -307,16 +965,6 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - /* Allocate n of size mlen, copy base to n */ n = snewn(mlen, BignumInt); i = mlen - base[0]; @@ -332,10 +980,14 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) a[i] = 0; a[2 * mlen - 1] = 1; + /* Scratch space for multiplies */ + scratchlen = mul_compute_scratch(mlen); + scratch = snewn(scratchlen, BignumInt); + /* Skip leading zero bits of exp. */ i = 0; j = BIGNUM_INT_BITS-1; - while (i < (int)exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) { + while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) { j--; if (j < 0) { i++; @@ -343,14 +995,26 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) } } + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); + } + /* Main computation */ while (i < (int)exp[0]) { while (j >= 0) { - internal_mul(a + mlen, a + mlen, b, mlen); - internal_mod(b, mlen * 2, m, mlen, NULL, 0); - if ((exp[exp[0] - i] & (1 << j)) != 0) { - internal_mul(b + mlen, n, a, mlen); - internal_mod(a, mlen * 2, m, mlen, NULL, 0); + internal_mul(a + mlen, a + mlen, b, mlen, scratch); + internal_mod(b, mlen * 2, m, mlen, NULL, recip, rshift); + if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) { + internal_mul(b + mlen, n, a, mlen, scratch); + internal_mod(a, mlen * 2, m, mlen, NULL, recip, rshift); } else { BignumInt *t; t = a; @@ -363,16 +1027,6 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) j = BIGNUM_INT_BITS-1; } - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = mlen - 1; i < 2 * mlen - 1; i++) - a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift)); - a[2 * mlen - 1] = a[2 * mlen - 1] << mshift; - internal_mod(a, mlen * 2, m, mlen, NULL, 0); - for (i = 2 * mlen - 1; i >= mlen; i--) - a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift)); - } - /* Copy result to buffer */ result = newbn(mod[0]); for (i = 0; i < mlen; i++) @@ -381,17 +1035,15 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) result[0]--; /* Free temporary arrays */ - for (i = 0; i < 2 * mlen; i++) - a[i] = 0; + smemclr(a, 2 * mlen * sizeof(*a)); sfree(a); - for (i = 0; i < 2 * mlen; i++) - b[i] = 0; + smemclr(scratch, scratchlen * sizeof(*scratch)); + sfree(scratch); + smemclr(b, 2 * mlen * sizeof(*b)); sfree(b); - for (i = 0; i < mlen; i++) - m[i] = 0; + smemclr(m, mlen * sizeof(*m)); sfree(m); - for (i = 0; i < mlen; i++) - n[i] = 0; + smemclr(n, mlen * sizeof(*n)); sfree(n); freebn(base); @@ -399,6 +1051,152 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) return result; } +/* + * Compute (base ^ exp) % mod. Uses the Montgomery multiplication + * technique where possible, falling back to modpow_simple otherwise. + */ +Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) +{ + BignumInt *a, *b, *x, *n, *mninv, *scratch; + int len, scratchlen, i, j; + Bignum base, base2, r, rn, inv, result; + + /* + * The most significant word of mod needs to be non-zero. It + * should already be, but let's make sure. + */ + assert(mod[mod[0]] != 0); + + /* + * mod had better be odd, or we can't do Montgomery multiplication + * using a power of two at all. + */ + if (!(mod[1] & 1)) + return modpow_simple(base_in, exp, mod); + + /* + * Make sure the base is smaller than the modulus, by reducing + * it modulo the modulus if not. + */ + base = bigmod(base_in, mod); + + /* + * Compute the inverse of n mod r, for monty_reduce. (In fact we + * want the inverse of _minus_ n mod r, but we'll sort that out + * below.) + */ + len = mod[0]; + r = bn_power_2(BIGNUM_INT_BITS * len); + inv = modinv(mod, r); + assert(inv); /* cannot fail, since mod is odd and r is a power of 2 */ + + /* + * Multiply the base by r mod n, to get it into Montgomery + * representation. + */ + base2 = modmul(base, r, mod); + freebn(base); + base = base2; + + rn = bigmod(r, mod); /* r mod n, i.e. Montgomerified 1 */ + + freebn(r); /* won't need this any more */ + + /* + * Set up internal arrays of the right lengths, in big-endian + * format, containing the base, the modulus, and the modulus's + * inverse. + */ + n = snewn(len, BignumInt); + for (j = 0; j < len; j++) + n[len - 1 - j] = mod[j + 1]; + + mninv = snewn(len, BignumInt); + for (j = 0; j < len; j++) + mninv[len - 1 - j] = (j < (int)inv[0] ? inv[j + 1] : 0); + freebn(inv); /* we don't need this copy of it any more */ + /* Now negate mninv mod r, so it's the inverse of -n rather than +n. */ + x = snewn(len, BignumInt); + for (j = 0; j < len; j++) + x[j] = 0; + internal_sub(x, mninv, mninv, len); + + /* x = snewn(len, BignumInt); */ /* already done above */ + for (j = 0; j < len; j++) + x[len - 1 - j] = (j < (int)base[0] ? base[j + 1] : 0); + freebn(base); /* we don't need this copy of it any more */ + + a = snewn(2*len, BignumInt); + b = snewn(2*len, BignumInt); + for (j = 0; j < len; j++) + a[2*len - 1 - j] = (j < (int)rn[0] ? rn[j + 1] : 0); + freebn(rn); + + /* Scratch space for multiplies */ + scratchlen = 3*len + mul_compute_scratch(len); + scratch = snewn(scratchlen, BignumInt); + + /* Skip leading zero bits of exp. */ + i = 0; + j = BIGNUM_INT_BITS-1; + while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) { + j--; + if (j < 0) { + i++; + j = BIGNUM_INT_BITS-1; + } + } + + /* Main computation */ + while (i < (int)exp[0]) { + while (j >= 0) { + internal_mul(a + len, a + len, b, len, scratch); + monty_reduce(b, n, mninv, scratch, len); + if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) { + internal_mul(b + len, x, a, len, scratch); + monty_reduce(a, n, mninv, scratch, len); + } else { + BignumInt *t; + t = a; + a = b; + b = t; + } + j--; + } + i++; + j = BIGNUM_INT_BITS-1; + } + + /* + * Final monty_reduce to get back from the adjusted Montgomery + * representation. + */ + monty_reduce(a, n, mninv, scratch, len); + + /* Copy result to buffer */ + result = newbn(mod[0]); + for (i = 0; i < len; i++) + result[result[0] - i] = a[i + len]; + while (result[0] > 1 && result[result[0]] == 0) + result[0]--; + + /* Free temporary arrays */ + smemclr(scratch, scratchlen * sizeof(*scratch)); + sfree(scratch); + smemclr(a, 2 * len * sizeof(*a)); + sfree(a); + smemclr(b, 2 * len * sizeof(*b)); + sfree(b); + smemclr(mninv, len * sizeof(*mninv)); + sfree(mninv); + smemclr(n, len * sizeof(*n)); + sfree(n); + smemclr(x, len * sizeof(*x)); + sfree(x); + + return result; +} + /* * Compute (p * q) % mod. * The most significant word of mod MUST be non-zero. @@ -406,11 +1204,18 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) */ Bignum modmul(Bignum p, Bignum q, Bignum mod) { - BignumInt *a, *n, *m, *o; - int mshift; + BignumInt *a, *n, *m, *o, *scratch; + BignumInt recip; + int rshift, scratchlen; int pqlen, mlen, rlen, i, j; Bignum result; + /* + * The most significant word of mod needs to be non-zero. It + * should already be, but let's make sure. + */ + assert(mod[mod[0]] != 0); + /* Allocate m of size mlen, copy mod to m */ /* We use big endian internally */ mlen = mod[0]; @@ -418,18 +1223,15 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - pqlen = (p[0] > q[0] ? p[0] : q[0]); + /* + * Make sure that we're allowing enough space. The shifting below + * will underflow the vectors we allocate if pqlen is too small. + */ + if (2*pqlen <= mlen) + pqlen = mlen/2 + 1; + /* Allocate n of size pqlen, copy p to n */ n = snewn(pqlen, BignumInt); i = pqlen - p[0]; @@ -449,20 +1251,26 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) /* Allocate a of size 2*pqlen for result */ a = snewn(2 * pqlen, BignumInt); - /* Main computation */ - internal_mul(n, o, a, pqlen); - internal_mod(a, pqlen * 2, m, mlen, NULL, 0); - - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = 2 * pqlen - mlen - 1; i < 2 * pqlen - 1; i++) - a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift)); - a[2 * pqlen - 1] = a[2 * pqlen - 1] << mshift; - internal_mod(a, pqlen * 2, m, mlen, NULL, 0); - for (i = 2 * pqlen - 1; i >= 2 * pqlen - mlen; i--) - a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift)); + /* Scratch space for multiplies */ + scratchlen = mul_compute_scratch(pqlen); + scratch = snewn(scratchlen, BignumInt); + + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); } + /* Main computation */ + internal_mul(n, o, a, pqlen, scratch); + internal_mod(a, pqlen * 2, m, mlen, NULL, recip, rshift); + /* Copy result to buffer */ rlen = (mlen < pqlen * 2 ? mlen : pqlen * 2); result = newbn(rlen); @@ -472,22 +1280,49 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) result[0]--; /* Free temporary arrays */ - for (i = 0; i < 2 * pqlen; i++) - a[i] = 0; + smemclr(scratch, scratchlen * sizeof(*scratch)); + sfree(scratch); + smemclr(a, 2 * pqlen * sizeof(*a)); sfree(a); - for (i = 0; i < mlen; i++) - m[i] = 0; + smemclr(m, mlen * sizeof(*m)); sfree(m); - for (i = 0; i < pqlen; i++) - n[i] = 0; + smemclr(n, pqlen * sizeof(*n)); sfree(n); - for (i = 0; i < pqlen; i++) - o[i] = 0; + smemclr(o, pqlen * sizeof(*o)); sfree(o); return result; } +Bignum modsub(const Bignum a, const Bignum b, const Bignum n) +{ + Bignum a1, b1, ret; + + if (bignum_cmp(a, n) >= 0) a1 = bigmod(a, n); + else a1 = a; + if (bignum_cmp(b, n) >= 0) b1 = bigmod(b, n); + else b1 = b; + + if (bignum_cmp(a1, b1) >= 0) /* a >= b */ + { + ret = bigsub(a1, b1); + } + else + { + /* Handle going round the corner of the modulus without having + * negative support in Bignum */ + Bignum tmp = bigsub(n, b1); + assert(tmp); + ret = bigadd(tmp, a1); + freebn(tmp); + } + + if (a != a1) freebn(a1); + if (b != b1) freebn(b1); + + return ret; +} + /* * Compute p % mod. * The most significant word of mod MUST be non-zero. @@ -498,9 +1333,16 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) { BignumInt *n, *m; - int mshift; + BignumInt recip; + int rshift; int plen, mlen, i, j; + /* + * The most significant word of mod needs to be non-zero. It + * should already be, but let's make sure. + */ + assert(mod[mod[0]] != 0); + /* Allocate m of size mlen, copy mod to m */ /* We use big endian internally */ mlen = mod[0]; @@ -508,16 +1350,6 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - plen = p[0]; /* Ensure plen > mlen */ if (plen <= mlen) @@ -530,19 +1362,21 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) for (j = 1; j <= (int)p[0]; j++) n[plen - j] = p[j]; - /* Main computation */ - internal_mod(n, plen, m, mlen, quotient, mshift); - - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = plen - mlen - 1; i < plen - 1; i++) - n[i] = (n[i] << mshift) | (n[i + 1] >> (BIGNUM_INT_BITS - mshift)); - n[plen - 1] = n[plen - 1] << mshift; - internal_mod(n, plen, m, mlen, quotient, 0); - for (i = plen - 1; i >= plen - mlen; i--) - n[i] = (n[i] >> mshift) | (n[i - 1] << (BIGNUM_INT_BITS - mshift)); + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); } + /* Main computation */ + internal_mod(n, plen, m, mlen, quotient, recip, rshift); + /* Copy result to buffer */ if (result) { for (i = 1; i <= (int)result[0]; i++) { @@ -552,11 +1386,9 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) } /* Free temporary arrays */ - for (i = 0; i < mlen; i++) - m[i] = 0; + smemclr(m, mlen * sizeof(*m)); sfree(m); - for (i = 0; i < plen; i++) - n[i] = 0; + smemclr(n, plen * sizeof(*n)); sfree(n); } @@ -576,6 +1408,8 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes) Bignum result; int w, i; + assert(nbytes >= 0 && nbytes < INT_MAX/8); + w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */ result = newbn(w); @@ -583,14 +1417,94 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes) result[i] = 0; for (i = nbytes; i--;) { unsigned char byte = *data++; - result[1 + i / BIGNUM_INT_BYTES] |= byte << (8*i % BIGNUM_INT_BITS); + result[1 + i / BIGNUM_INT_BYTES] |= + (BignumInt)byte << (8*i % BIGNUM_INT_BITS); + } + + bn_restore_invariant(result); + return result; +} + +Bignum bignum_from_bytes_le(const unsigned char *data, int nbytes) +{ + Bignum result; + int w, i; + + assert(nbytes >= 0 && nbytes < INT_MAX/8); + + w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */ + + result = newbn(w); + for (i = 1; i <= w; i++) + result[i] = 0; + for (i = 0; i < nbytes; ++i) { + unsigned char byte = *data++; + result[1 + i / BIGNUM_INT_BYTES] |= + (BignumInt)byte << (8*i % BIGNUM_INT_BITS); + } + + bn_restore_invariant(result); + return result; +} + +Bignum bignum_from_decimal(const char *decimal) +{ + Bignum result = copybn(Zero); + + while (*decimal) { + Bignum tmp, tmp2; + + if (!isdigit((unsigned char)*decimal)) { + freebn(result); + return 0; + } + + tmp = bigmul(result, Ten); + tmp2 = bignum_from_long(*decimal - '0'); + freebn(result); + result = bigadd(tmp, tmp2); + freebn(tmp); + freebn(tmp2); + + decimal++; } - while (result[0] > 1 && result[result[0]] == 0) - result[0]--; return result; } +Bignum bignum_random_in_range(const Bignum lower, const Bignum upper) +{ + Bignum ret = NULL; + unsigned char *bytes; + int upper_len = bignum_bitcount(upper); + int upper_bytes = upper_len / 8; + int upper_bits = upper_len % 8; + if (upper_bits) ++upper_bytes; + + bytes = snewn(upper_bytes, unsigned char); + do { + int i; + + if (ret) freebn(ret); + + for (i = 0; i < upper_bytes; ++i) + { + bytes[i] = (unsigned char)random_byte(); + } + /* Mask the top to reduce failure rate to 50/50 */ + if (upper_bits) + { + bytes[i - 1] &= 0xFF >> (8 - upper_bits); + } + + ret = bignum_from_bytes(bytes, upper_bytes); + } while (bignum_cmp(ret, lower) < 0 || bignum_cmp(ret, upper) > 0); + smemclr(bytes, upper_bytes); + sfree(bytes); + + return ret; +} + /* * Read an SSH-1-format bignum from a data buffer. Return the number * of bytes consumed, or -1 if there wasn't enough data. @@ -652,7 +1566,7 @@ int ssh2_bignum_length(Bignum bn) */ int bignum_byte(Bignum bn, int i) { - if (i >= (int)(BIGNUM_INT_BYTES * bn[0])) + if (i < 0 || i >= (int)(BIGNUM_INT_BYTES * bn[0])) return 0; /* beyond the end */ else return (bn[i / BIGNUM_INT_BYTES + 1] >> @@ -664,7 +1578,7 @@ int bignum_byte(Bignum bn, int i) */ int bignum_bit(Bignum bn, int i) { - if (i >= (int)(BIGNUM_INT_BITS * bn[0])) + if (i < 0 || i >= (int)(BIGNUM_INT_BITS * bn[0])) return 0; /* beyond the end */ else return (bn[i / BIGNUM_INT_BITS + 1] >> (i % BIGNUM_INT_BITS)) & 1; @@ -675,11 +1589,11 @@ int bignum_bit(Bignum bn, int i) */ void bignum_set_bit(Bignum bn, int bitnum, int value) { - if (bitnum >= (int)(BIGNUM_INT_BITS * bn[0])) - abort(); /* beyond the end */ - else { + if (bitnum < 0 || bitnum >= (int)(BIGNUM_INT_BITS * bn[0])) { + if (value) abort(); /* beyond the end */ + } else { int v = bitnum / BIGNUM_INT_BITS + 1; - int mask = 1 << (bitnum % BIGNUM_INT_BITS); + BignumInt mask = (BignumInt)1 << (bitnum % BIGNUM_INT_BITS); if (value) bn[v] |= mask; else @@ -711,7 +1625,18 @@ int ssh1_write_bignum(void *data, Bignum bn) int bignum_cmp(Bignum a, Bignum b) { int amax = a[0], bmax = b[0]; - int i = (amax > bmax ? amax : bmax); + int i; + + /* Annoyingly we have two representations of zero */ + if (amax == 1 && a[amax] == 0) + amax = 0; + if (bmax == 1 && b[bmax] == 0) + bmax = 0; + + assert(amax == 0 || a[amax] != 0); + assert(bmax == 0 || b[bmax] != 0); + + i = (amax > bmax ? amax : bmax); while (i) { BignumInt aval = (i > amax ? 0 : a[i]); BignumInt bval = (i > bmax ? 0 : b[i]); @@ -733,6 +1658,8 @@ Bignum bignum_rshift(Bignum a, int shift) int i, shiftw, shiftb, shiftbb, bits; BignumInt ai, ai1; + assert(shift >= 0); + bits = bignum_bitcount(a) - shift; ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS); @@ -752,6 +1679,44 @@ Bignum bignum_rshift(Bignum a, int shift) return ret; } +/* + * Left-shift one bignum to form another. + */ +Bignum bignum_lshift(Bignum a, int shift) +{ + Bignum ret; + int bits, shiftWords, shiftBits; + + assert(shift >= 0); + + bits = bignum_bitcount(a) + shift; + ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS); + + shiftWords = shift / BIGNUM_INT_BITS; + shiftBits = shift % BIGNUM_INT_BITS; + + if (shiftBits == 0) + { + memcpy(&ret[1 + shiftWords], &a[1], sizeof(BignumInt) * a[0]); + } + else + { + int i; + BignumInt carry = 0; + + /* Remember that Bignum[0] is length, so add 1 */ + for (i = shiftWords + 1; i < ((int)a[0]) + shiftWords + 1; ++i) + { + BignumInt from = a[i - shiftWords]; + ret[i] = (from << shiftBits) | carry; + carry = from >> (BIGNUM_INT_BITS - shiftBits); + } + if (carry) ret[i] = carry; + } + + return ret; +} + /* * Non-modular multiplication and addition. */ @@ -760,18 +1725,21 @@ Bignum bigmuladd(Bignum a, Bignum b, Bignum addend) int alen = a[0], blen = b[0]; int mlen = (alen > blen ? alen : blen); int rlen, i, maxspot; + int wslen; BignumInt *workspace; Bignum ret; - /* mlen space for a, mlen space for b, 2*mlen for result */ - workspace = snewn(mlen * 4, BignumInt); + /* mlen space for a, mlen space for b, 2*mlen for result, + * plus scratch space for multiplication */ + wslen = mlen * 4 + mul_compute_scratch(mlen); + workspace = snewn(wslen, BignumInt); for (i = 0; i < mlen; i++) { workspace[0 * mlen + i] = (mlen - i <= (int)a[0] ? a[mlen - i] : 0); workspace[1 * mlen + i] = (mlen - i <= (int)b[0] ? b[mlen - i] : 0); } internal_mul(workspace + 0 * mlen, workspace + 1 * mlen, - workspace + 2 * mlen, mlen); + workspace + 2 * mlen, mlen, workspace + 4 * mlen); /* now just copy the result back */ rlen = alen + blen + 1; @@ -788,18 +1756,18 @@ Bignum bigmuladd(Bignum a, Bignum b, Bignum addend) /* now add in the addend, if any */ if (addend) { - BignumDblInt carry = 0; + BignumCarry carry = 0; for (i = 1; i <= rlen; i++) { - carry += (i <= (int)ret[0] ? ret[i] : 0); - carry += (i <= (int)addend[0] ? addend[i] : 0); - ret[i] = (BignumInt) carry & BIGNUM_INT_MASK; - carry >>= BIGNUM_INT_BITS; + BignumInt retword = (i <= (int)ret[0] ? ret[i] : 0); + BignumInt addword = (i <= (int)addend[0] ? addend[i] : 0); + BignumADC(ret[i], carry, retword, addword, carry); if (ret[i] != 0 && i > maxspot) maxspot = i; } } ret[0] = maxspot; + smemclr(workspace, wslen * sizeof(*workspace)); sfree(workspace); return ret; } @@ -812,6 +1780,67 @@ Bignum bigmul(Bignum a, Bignum b) return bigmuladd(a, b, NULL); } +/* + * Simple addition. + */ +Bignum bigadd(Bignum a, Bignum b) +{ + int alen = a[0], blen = b[0]; + int rlen = (alen > blen ? alen : blen) + 1; + int i, maxspot; + Bignum ret; + BignumCarry carry; + + ret = newbn(rlen); + + carry = 0; + maxspot = 0; + for (i = 1; i <= rlen; i++) { + BignumInt aword = (i <= (int)a[0] ? a[i] : 0); + BignumInt bword = (i <= (int)b[0] ? b[i] : 0); + BignumADC(ret[i], carry, aword, bword, carry); + if (ret[i] != 0 && i > maxspot) + maxspot = i; + } + ret[0] = maxspot; + + return ret; +} + +/* + * Subtraction. Returns a-b, or NULL if the result would come out + * negative (recall that this entire bignum module only handles + * positive numbers). + */ +Bignum bigsub(Bignum a, Bignum b) +{ + int alen = a[0], blen = b[0]; + int rlen = (alen > blen ? alen : blen); + int i, maxspot; + Bignum ret; + BignumCarry carry; + + ret = newbn(rlen); + + carry = 1; + maxspot = 0; + for (i = 1; i <= rlen; i++) { + BignumInt aword = (i <= (int)a[0] ? a[i] : 0); + BignumInt bword = (i <= (int)b[0] ? b[i] : 0); + BignumADC(ret[i], carry, aword, ~bword, carry); + if (ret[i] != 0 && i > maxspot) + maxspot = i; + } + ret[0] = maxspot; + + if (!carry) { + freebn(ret); + return NULL; + } + + return ret; +} + /* * Create a bignum which is the bitmask covering another one. That * is, the smallest integer which is >= N and is also one less than @@ -838,40 +1867,52 @@ Bignum bignum_bitmask(Bignum n) } /* - * Convert a (max 32-bit) long into a bignum. + * Convert an unsigned long into a bignum. */ -Bignum bignum_from_long(unsigned long nn) +Bignum bignum_from_long(unsigned long n) { + const int maxwords = + (sizeof(unsigned long) + sizeof(BignumInt) - 1) / sizeof(BignumInt); Bignum ret; - BignumDblInt n = nn; + int i; + + ret = newbn(maxwords); + ret[0] = 0; + for (i = 0; i < maxwords; i++) { + ret[i+1] = n >> (i * BIGNUM_INT_BITS); + if (ret[i+1] != 0) + ret[0] = i+1; + } - ret = newbn(3); - ret[1] = (BignumInt)(n & BIGNUM_INT_MASK); - ret[2] = (BignumInt)((n >> BIGNUM_INT_BITS) & BIGNUM_INT_MASK); - ret[3] = 0; - ret[0] = (ret[2] ? 2 : 1); return ret; } /* * Add a long to a bignum. */ -Bignum bignum_add_long(Bignum number, unsigned long addendx) +Bignum bignum_add_long(Bignum number, unsigned long n) { - Bignum ret = newbn(number[0] + 1); - int i, maxspot = 0; - BignumDblInt carry = 0, addend = addendx; - - for (i = 1; i <= (int)ret[0]; i++) { - carry += addend & BIGNUM_INT_MASK; - carry += (i <= (int)number[0] ? number[i] : 0); - addend >>= BIGNUM_INT_BITS; - ret[i] = (BignumInt) carry & BIGNUM_INT_MASK; - carry >>= BIGNUM_INT_BITS; - if (ret[i] != 0) - maxspot = i; + const int maxwords = + (sizeof(unsigned long) + sizeof(BignumInt) - 1) / sizeof(BignumInt); + Bignum ret; + int words, i; + BignumCarry carry; + + words = number[0]; + if (words < maxwords) + words = maxwords; + words++; + ret = newbn(words); + + carry = 0; + ret[0] = 0; + for (i = 0; i < words; i++) { + BignumInt nword = (i < maxwords ? n >> (i * BIGNUM_INT_BITS) : 0); + BignumInt numword = (i < number[0] ? number[i+1] : 0); + BignumADC(ret[i+1], carry, numword, nword, carry); + if (ret[i+1] != 0) + ret[0] = i+1; } - ret[0] = maxspot; return ret; } @@ -880,13 +1921,17 @@ Bignum bignum_add_long(Bignum number, unsigned long addendx) */ unsigned short bignum_mod_short(Bignum number, unsigned short modulus) { - BignumDblInt mod, r; + unsigned long mod = modulus, r = 0; + /* Precompute (BIGNUM_INT_MASK+1) % mod */ + unsigned long base_r = (BIGNUM_INT_MASK - modulus + 1) % mod; int i; - r = 0; - mod = modulus; - for (i = number[0]; i > 0; i--) - r = (r * (BIGNUM_TOP_BIT % mod) * 2 + number[i] % mod) % mod; + for (i = number[0]; i > 0; i--) { + /* + * Conceptually, ((r << BIGNUM_INT_BITS) + number[i]) % mod + */ + r = ((r * base_r) + (number[i] % mod)) % mod; + } return (unsigned short) r; } @@ -920,6 +1965,8 @@ Bignum bigdiv(Bignum a, Bignum b) { Bignum q = newbn(a[0]); bigdivmod(a, b, NULL, q); + while (q[0] > 1 && q[q[0]] == 0) + q[0]--; return q; } @@ -930,6 +1977,8 @@ Bignum bigmod(Bignum a, Bignum b) { Bignum r = newbn(b[0]); bigdivmod(a, b, r, NULL); + while (r[0] > 1 && r[r[0]] == 0) + r[0]--; return r; } @@ -966,12 +2015,31 @@ Bignum modinv(Bignum number, Bignum modulus) Bignum x = copybn(One); int sign = +1; + assert(number[number[0]] != 0); + assert(modulus[modulus[0]] != 0); + while (bignum_cmp(b, One) != 0) { - Bignum t = newbn(b[0]); - Bignum q = newbn(a[0]); + Bignum t, q; + + if (bignum_cmp(b, Zero) == 0) { + /* + * Found a common factor between the inputs, so we cannot + * return a modular inverse at all. + */ + freebn(b); + freebn(a); + freebn(xp); + freebn(x); + return NULL; + } + + t = newbn(b[0]); + q = newbn(a[0]); bigdivmod(a, b, t, q); while (t[0] > 1 && t[t[0]] == 0) t[0]--; + while (q[0] > 1 && q[q[0]] == 0) + q[0]--; freebn(a); a = b; b = t; @@ -1021,7 +2089,7 @@ char *bignum_decimal(Bignum x) { int ndigits, ndigit; int i, iszero; - BignumDblInt carry; + BignumInt carry; char *ret; BignumInt *workspace; @@ -1068,11 +2136,33 @@ char *bignum_decimal(Bignum x) iszero = 1; carry = 0; for (i = 0; i < (int)x[0]; i++) { - carry = (carry << BIGNUM_INT_BITS) + workspace[i]; - workspace[i] = (BignumInt) (carry / 10); + /* + * Conceptually, we want to compute + * + * (carry << BIGNUM_INT_BITS) + workspace[i] + * ----------------------------------------- + * 10 + * + * but we don't have an integer type longer than BignumInt + * to work with. So we have to do it in pieces. + */ + + BignumInt q, r; + q = workspace[i] / 10; + r = workspace[i] % 10; + + /* I want (BIGNUM_INT_MASK+1)/10 but can't say so directly! */ + q += carry * ((BIGNUM_INT_MASK-9) / 10 + 1); + r += carry * ((BIGNUM_INT_MASK-9) % 10); + + q += r / 10; + r %= 10; + + workspace[i] = q; + carry = r; + if (workspace[i]) iszero = 0; - carry %= 10; } ret[--ndigit] = (char) (carry + '0'); } while (!iszero); @@ -1087,6 +2177,7 @@ char *bignum_decimal(Bignum x) /* * Done. */ + smemclr(workspace, x[0] * sizeof(*workspace)); sfree(workspace); return ret; }