]> asedeno.scripts.mit.edu Git - PuTTY.git/commitdiff
Rewrite the core divide function to not use DIVMOD_WORD.
authorSimon Tatham <anakin@pobox.com>
Sun, 13 Dec 2015 14:46:43 +0000 (14:46 +0000)
committerSimon Tatham <anakin@pobox.com>
Sun, 13 Dec 2015 14:48:39 +0000 (14:48 +0000)
DIVMOD_WORD is a portability hazard, because implementing it requires
either a way to get direct access to the x86 DIV instruction or
equivalent (be it inline assembler or a compiler intrinsic), or else
an integer type we can use as BignumDblInt. But I'm starting to think
about porting to 64-bit Visual Studio with a 64-bit BignumInt, and in
that situation neither of those options will be available.

I could write a piece of _out_-of-line x86-64 assembler in a separate
source file and put a function call in DIVMOD_WORD, but instead I've
decided to solve the problem in a more futureproof way: remove
DIVMOD_WORD totally and write a division function that doesn't need it
at all, solving not only today's porting headache but all future ones
in this area.

The new implementation works by precomputing (a good enough
approximation to) the leading word of the reciprocal of the modulus,
and then getting each word of quotient by multiplying by that
reciprocal, where we previously used DIVMOD_WORD to divide by the
leading word of the actual modulus. The reciprocal itself is computed
outside internal_mod() and passed in as a parameter, allowing me to
save time by only computing it once when I'm about to do a modpow.

To some extent this complicates the implementation: the advantage of
DIVMOD_WORD was that it yielded a full word q of quotient every time
it was used, so the subtraction of q*m from the input could be done in
a nicely word-aligned way. But the reciprocal multiply approach yields
_almost_ a full word of quotient, because you have to make the
reciprocal a bit short to avoid overflow at multiplication time. For a
start, this means we have to do fractionally more iterations of the
main loop; but more painfully, we can no longer depend on the
subtraction of q*m at every step being word-aligned, and instead we
have to be prepared to do it at any bit shift.

But the flip side is that once we've implemented that, the rest of the
algorithm becomes a lot less full of horrible special cases: in
particular, we can now completely throw away the horribleness at all
the call sites where we shift the modulus up by a fractional word to
set its top bit, and then have to do a little dance to get the last
few bits of quotient involving a second call to internal_mod.

So there are points both for and against the new implementation in
simplicity terms; but I think on balance it's more comprehensible than
the old one, and a quick timing test suggests it also ends up a touch
faster overall - the new testbn gets through the output of
testdata/bignum.py in 4.034s where the old one took 4.392s.

sshbn.c
sshbn.h

diff --git a/sshbn.c b/sshbn.c
index c98590938b5d98917348232b7bdc468877f686c6..455aa57abf98ef1ef45fd0e99c5fd870937e8bfb 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -532,107 +532,419 @@ static void internal_add_shifted(BignumInt *number,
     }
 }
 
+static int bn_clz(BignumInt x)
+{
+    /*
+     * Count the leading zero bits in x. Equivalently, how far left
+     * would we need to shift x to make its top bit set?
+     *
+     * Precondition: x != 0.
+     */
+
+    /* FIXME: would be nice to put in some compiler intrinsics under
+     * ifdef here */
+    int i, ret = 0;
+    for (i = BIGNUM_INT_BITS / 2; i != 0; i >>= 1) {
+        if ((x >> (BIGNUM_INT_BITS-i)) == 0) {
+            x <<= i;
+            ret += i;
+        }
+    }
+    return ret;
+}
+
+static BignumInt reciprocal_word(BignumInt d)
+{
+    BignumInt dshort, recip;
+    BignumDblInt product;
+    int corrections;
+
+    /*
+     * Input: a BignumInt value d, with its top bit set.
+     */
+    assert(d >> (BIGNUM_INT_BITS-1) == 1);
+
+    /*
+     * Output: a value, shifted to fill a BignumInt, which is strictly
+     * less than 1/(d+1), i.e. is an *under*-estimate (but by as
+     * little as possible within the constraints) of the reciprocal of
+     * any number whose first BIGNUM_INT_BITS bits match d.
+     *
+     * Ideally we'd like to _totally_ fill BignumInt, i.e. always
+     * return a value with the top bit set. Unfortunately we can't
+     * quite guarantee that for all inputs and also return a fixed
+     * exponent. So instead we take our reciprocal to be
+     * 2^(BIGNUM_INT_BITS*2-1) / d, so that it has the top bit clear
+     * only in the exceptional case where d takes exactly the maximum
+     * value BIGNUM_INT_MASK; in that case, the top bit is clear and
+     * the next bit down is set.
+     */
+
+    /*
+     * Start by computing a half-length version of the answer, by
+     * straightforward division within a BignumInt.
+     */
+    dshort = (d >> (BIGNUM_INT_BITS/2)) + 1;
+    recip = (BIGNUM_TOP_BIT + dshort - 1) / dshort;
+    recip <<= BIGNUM_INT_BITS - BIGNUM_INT_BITS/2;
+
+    /*
+     * Newton-Raphson iteration to improve that starting reciprocal
+     * estimate: take f(x) = d - 1/x, and then the N-R formula gives
+     * x_new = x - f(x)/f'(x) = x - (d-1/x)/(1/x^2) = x(2-d*x). Or,
+     * taking our fixed-point representation into account, take f(x)
+     * to be d - K/x (where K = 2^(BIGNUM_INT_BITS*2-1) as discussed
+     * above) and then we get (2K - d*x) * x/K.
+     *
+     * Newton-Raphson doubles the number of correct bits at every
+     * iteration, and the initial division above already gave us half
+     * the output word, so it's only worth doing one iteration.
+     */
+    product = MUL_WORD(recip, d);
+    product += recip;
+    product = -product;                /* the 2K shifts just off the top */
+    product &= (((BignumDblInt)BIGNUM_INT_MASK << BIGNUM_INT_BITS) +
+                BIGNUM_INT_MASK);
+    product >>= BIGNUM_INT_BITS;
+    product = MUL_WORD(product, recip);
+    product >>= (BIGNUM_INT_BITS-1);
+    recip = (BignumInt)product;
+
+    /*
+     * Now make sure we have the best possible reciprocal estimate,
+     * before we return it. We might have been off by a handful either
+     * way - not enough to bother with any better-thought-out kind of
+     * correction loop.
+     */
+    product = MUL_WORD(recip, d);
+    product += recip;
+    corrections = 0;
+    if (product >= ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1))) {
+        do {
+            product -= d;
+            recip--;
+            corrections++;
+        } while (product >= ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1)));
+    } else {
+        while (product < ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1)) - d) {
+            product += d;
+            recip++;
+            corrections++;
+        }
+    }
+
+    return recip;
+}
+
 /*
  * Compute a = a % m.
  * Input in first alen words of a and first mlen words of m.
  * Output in first alen words of a
  * (of which first alen-mlen words will be zero).
- * The MSW of m MUST have its high bit set.
  * Quotient is accumulated in the `quotient' array, which is a Bignum
- * rather than the internal bigendian format. Quotient parts are shifted
- * left by `qshift' before adding into quot.
+ * rather than the internal bigendian format.
+ *
+ * 'recip' must be the result of calling reciprocal_word() on the top
+ * BIGNUM_INT_BITS of the modulus (denoted m0 in comments below), with
+ * the topmost set bit normalised to the MSB of the input to
+ * reciprocal_word. 'rshift' is how far left the top nonzero word of
+ * the modulus had to be shifted to set that top bit.
  */
 static void internal_mod(BignumInt *a, int alen,
                         BignumInt *m, int mlen,
-                        BignumInt *quot, int qshift)
+                        BignumInt *quot, BignumInt recip, int rshift)
 {
-    BignumInt m0, m1, h;
     int i, k;
 
-    m0 = m[0];
-    assert(m0 >> (BIGNUM_INT_BITS-1) == 1);
-    if (mlen > 1)
-       m1 = m[1];
-    else
-       m1 = 0;
+#ifdef DIVISION_DEBUG
+    {
+        int d;
+        printf("start division, m=0x");
+        for (d = 0; d < mlen; d++)
+            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)m[d]);
+        printf(", recip=%#0*llx, rshift=%d\n",
+               BIGNUM_INT_BITS/4, (unsigned long long)recip, rshift);
+    }
+#endif
 
-    for (i = 0; i <= alen - mlen; i++) {
-       BignumDblInt t;
-        BignumInt q, r, c, ai1;
+    /*
+     * Repeatedly use that reciprocal estimate to get a decent number
+     * of quotient bits, and subtract off the resulting multiple of m.
+     *
+     * Normally we expect to terminate this loop by means of finding
+     * out q=0 part way through, but one way in which we might not get
+     * that far in the first place is if the input a is actually zero,
+     * in which case we'll discard zero words from the front of a
+     * until we reach the termination condition in the for statement
+     * here.
+     */
+    for (i = 0; i <= alen - mlen ;) {
+       BignumDblInt product, subtmp, t;
+        BignumInt aword, q;
+        int shift, full_bitoffset, bitoffset, wordoffset;
 
-       if (i == 0) {
-           h = 0;
-       } else {
-           h = a[i - 1];
-           a[i - 1] = 0;
-       }
+#ifdef DIVISION_DEBUG
+        {
+            int d;
+            printf("main loop, a=0x");
+            for (d = 0; d < alen; d++)
+                printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
+            printf("\n");
+        }
+#endif
 
-       if (i == alen - 1)
-           ai1 = 0;
-       else
-           ai1 = a[i + 1];
-
-       /* Find q = h:a[i] / m0 */
-       if (h >= m0) {
-           /*
-            * Special case.
-            * 
-            * To illustrate it, suppose a BignumInt is 8 bits, and
-            * we are dividing (say) A1:23:45:67 by A1:B2:C3. Then
-            * our initial division will be 0xA123 / 0xA1, which
-            * will give a quotient of 0x100 and a divide overflow.
-            * However, the invariants in this division algorithm
-            * are not violated, since the full number A1:23:... is
-            * _less_ than the quotient prefix A1:B2:... and so the
-            * following correction loop would have sorted it out.
-            * 
-            * In this situation we set q to be the largest
-            * quotient we _can_ stomach (0xFF, of course).
-            */
-           q = BIGNUM_INT_MASK;
-       } else {
-           /* Macro doesn't want an array subscript expression passed
-            * into it (see definition), so use a temporary. */
-           BignumInt tmplo = a[i];
-           DIVMOD_WORD(q, r, h, tmplo, m0);
-
-           /* Refine our estimate of q by looking at
-            h:a[i]:a[i+1] / m0:m1 */
-           t = MUL_WORD(m1, q);
-           if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) {
-               q--;
-               t -= m1;
-               r = (r + m0) & BIGNUM_INT_MASK;     /* overflow? */
-               if (r >= (BignumDblInt) m0 &&
-                   t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) q--;
-           }
-       }
+        if (a[i] == 0) {
+#ifdef DIVISION_DEBUG
+            printf("zero word at i=%d\n", i);
+#endif
+            i++;
+            continue;
+        }
 
-       /* Subtract q * m from a[i...] */
-       c = 0;
-       for (k = mlen - 1; k >= 0; k--) {
-           t = MUL_WORD(q, m[k]);
-           t += c;
-           c = (BignumInt)(t >> BIGNUM_INT_BITS);
-           if ((BignumInt) t > a[i + k])
-               c++;
-           a[i + k] -= (BignumInt) t;
-       }
+        aword = a[i];
+        shift = bn_clz(aword);
+        aword <<= shift;
+        if (shift > 0 && i+1 < alen)
+            aword |= a[i+1] >> (BIGNUM_INT_BITS - shift);
 
-       /* Add back m in case of borrow */
-       if (c != h) {
-           t = 0;
-           for (k = mlen - 1; k >= 0; k--) {
-               t += m[k];
-               t += a[i + k];
-               a[i + k] = (BignumInt) t;
-               t = t >> BIGNUM_INT_BITS;
-           }
-           q--;
-       }
-       if (quot)
-           internal_add_shifted(quot, q, qshift + BIGNUM_INT_BITS * (alen - mlen - i));
+        t = MUL_WORD(recip, aword);
+        q = (BignumInt)(t >> BIGNUM_INT_BITS);
+
+#ifdef DIVISION_DEBUG
+        printf("i=%d, aword=%#0*llx, shift=%d, q=%#0*llx\n",
+               i, BIGNUM_INT_BITS/4, (unsigned long long)aword,
+               shift, BIGNUM_INT_BITS/4, (unsigned long long)q);
+#endif
+
+        /*
+         * Work out the right bit and word offsets to use when
+         * subtracting q*m from a.
+         *
+         * aword was taken from a[i], which means its LSB was at bit
+         * position (alen-1-i) * BIGNUM_INT_BITS. But then we shifted
+         * it left by 'shift', so now the low bit of aword corresponds
+         * to bit position (alen-1-i) * BIGNUM_INT_BITS - shift, i.e.
+         * aword is approximately equal to a / 2^(that).
+         *
+         * m0 comes from the top word of mod, so its LSB is at bit
+         * position (mlen-1) * BIGNUM_INT_BITS - rshift, i.e. it can
+         * be considered to be m / 2^(that power). 'recip' is the
+         * reciprocal of m0, times 2^(BIGNUM_INT_BITS*2-1), i.e. it's
+         * about 2^((mlen+1) * BIGNUM_INT_BITS - rshift - 1) / m.
+         *
+         * Hence, recip * aword is approximately equal to the product
+         * of those, which simplifies to
+         *
+         * a/m * 2^((mlen+2+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1)
+         *
+         * But we've also shifted recip*aword down by BIGNUM_INT_BITS
+         * to form q, so we have
+         *
+         * q ~= a/m * 2^((mlen+1+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1)
+         *
+         * and hence, when we now compute q*m, it will be about
+         * a*2^(all that lot), i.e. the negation of that expression is
+         * how far left we have to shift the product q*m to make it
+         * approximately equal to a.
+         */
+        full_bitoffset = -((mlen+1+i-alen)*BIGNUM_INT_BITS + shift-rshift-1);
+#ifdef DIVISION_DEBUG
+        printf("full_bitoffset=%d\n", full_bitoffset);
+#endif
+
+        if (full_bitoffset < 0) {
+            /*
+             * If we find ourselves needing to shift q*m _right_, that
+             * means we've reached the bottom of the quotient. Clip q
+             * so that its right shift becomes zero, and if that means
+             * q becomes _actually_ zero, this loop is done.
+             */
+            if (full_bitoffset <= -BIGNUM_INT_BITS)
+                break;
+            q >>= -full_bitoffset;
+            full_bitoffset = 0;
+            if (!q)
+                break;
+#ifdef DIVISION_DEBUG
+            printf("now full_bitoffset=%d, q=%#0*llx\n",
+                   full_bitoffset, BIGNUM_INT_BITS/4, (unsigned long long)q);
+#endif
+        }
+
+        wordoffset = full_bitoffset / BIGNUM_INT_BITS;
+        bitoffset = full_bitoffset % BIGNUM_INT_BITS;
+#ifdef DIVISION_DEBUG
+        printf("wordoffset=%d, bitoffset=%d\n", wordoffset, bitoffset);
+#endif
+
+        /* wordoffset as computed above is the offset between the LSWs
+         * of m and a. But in fact m and a are stored MSW-first, so we
+         * need to adjust it to be the offset between the actual array
+         * indices, and flip the sign too. */
+        wordoffset = alen - mlen - wordoffset;
+
+        if (bitoffset == 0) {
+            BignumInt c = 1;
+            BignumInt prev_hi_word = 0;
+            for (k = mlen - 1; wordoffset+k >= i; k--) {
+                BignumInt mword = k<0 ? 0 : m[k];
+                product = MUL_WORD(q, mword);
+                product += prev_hi_word;
+                prev_hi_word = product >> BIGNUM_INT_BITS;
+#ifdef DIVISION_DEBUG
+                printf("  aligned sub: product word for m[%d] = %#0*llx\n",
+                       k, BIGNUM_INT_BITS/4,
+                       (unsigned long long)(BignumInt)product);
+#endif
+#ifdef DIVISION_DEBUG
+                printf("  aligned sub: subtrahend for a[%d] = %#0*llx\n",
+                       wordoffset+k, BIGNUM_INT_BITS/4,
+                       (unsigned long long)(BignumInt)product);
+#endif
+                subtmp = (BignumDblInt)a[wordoffset+k] +
+                    ((BignumInt)product ^ BIGNUM_INT_MASK) + c;
+                a[wordoffset+k] = (BignumInt)subtmp;
+                c = subtmp >> BIGNUM_INT_BITS;
+            }
+        } else {
+            BignumInt add_word = 0;
+            BignumInt c = 1;
+            BignumInt prev_hi_word = 0;
+            for (k = mlen - 1; wordoffset+k >= i; k--) {
+                BignumInt mword = k<0 ? 0 : m[k];
+                product = MUL_WORD(q, mword);
+                product += prev_hi_word;
+                prev_hi_word = product >> BIGNUM_INT_BITS;
+#ifdef DIVISION_DEBUG
+                printf("  unaligned sub: product word for m[%d] = %#0*llx\n",
+                       k, BIGNUM_INT_BITS/4,
+                       (unsigned long long)(BignumInt)product);
+#endif
+
+                add_word |= (BignumInt)product << bitoffset;
+
+#ifdef DIVISION_DEBUG
+                printf("  unaligned sub: subtrahend for a[%d] = %#0*llx\n",
+                       wordoffset+k,
+                       BIGNUM_INT_BITS/4, (unsigned long long)add_word);
+#endif
+                subtmp = (BignumDblInt)a[wordoffset+k] +
+                    (add_word ^ BIGNUM_INT_MASK) + c;
+                a[wordoffset+k] = (BignumInt)subtmp;
+                c = subtmp >> BIGNUM_INT_BITS;
+
+                add_word = (BignumInt)product >> (BIGNUM_INT_BITS - bitoffset);
+            }
+        }
+
+       if (quot) {
+#ifdef DIVISION_DEBUG
+            printf("adding quotient word %#0*llx << %d\n",
+                   BIGNUM_INT_BITS/4, (unsigned long long)q, full_bitoffset);
+#endif
+           internal_add_shifted(quot, q, full_bitoffset);
+#ifdef DIVISION_DEBUG
+            {
+                int d;
+                printf("now quot=0x");
+                for (d = quot[0]; d > 0; d--)
+                    printf("%0*llx", BIGNUM_INT_BITS/4,
+                           (unsigned long long)quot[d]);
+                printf("\n");
+            }
+#endif
+        }
+    }
+
+#ifdef DIVISION_DEBUG
+    {
+        int d;
+        printf("end main loop, a=0x");
+        for (d = 0; d < alen; d++)
+            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
+        if (quot) {
+            printf(", quot=0x");
+            for (d = quot[0]; d > 0; d--)
+                printf("%0*llx", BIGNUM_INT_BITS/4,
+                       (unsigned long long)quot[d]);
+        }
+        printf("\n");
+    }
+#endif
+
+    /*
+     * The above loop should terminate with the remaining value in a
+     * being strictly less than 2*m (if a >= 2*m then we should always
+     * have managed to get a nonzero q word), but we can't guarantee
+     * that it will be strictly less than m: consider a case where the
+     * remainder is 1, and another where the remainder is m-1. By the
+     * time a contains a value that's _about m_, you clearly can't
+     * distinguish those cases by looking at only the top word of a -
+     * you have to go all the way down to the bottom before you find
+     * out whether it's just less or just more than m.
+     *
+     * Hence, we now do a final fixup in which we subtract one last
+     * copy of m, or don't, accordingly. We should never have to
+     * subtract more than one copy of m here.
+     */
+    for (i = 0; i < alen; i++) {
+        /* Compare a with m, word by word, from the MSW down. As soon
+         * as we encounter a difference, we know whether we need the
+         * fixup. */
+        int mindex = mlen-alen+i;
+        BignumInt mword = mindex < 0 ? 0 : m[mindex];
+        if (a[i] < mword) {
+#ifdef DIVISION_DEBUG
+            printf("final fixup not needed, a < m\n");
+#endif
+            return;
+        } else if (a[i] > mword) {
+#ifdef DIVISION_DEBUG
+            printf("final fixup is needed, a > m\n");
+#endif
+            break;
+        }
+        /* If neither of those cases happened, the words are the same,
+         * so keep going and look at the next one. */
     }
+#ifdef DIVISION_DEBUG
+    if (i == mlen) /* if we printed neither of the above diagnostics */
+        printf("final fixup is needed, a == m\n");
+#endif
+
+    /*
+     * If we got here without returning, then a >= m, so we must
+     * subtract m, and increment the quotient.
+     */
+    {
+        BignumInt c = 1;
+        for (i = alen - 1; i >= 0; i--) {
+            int mindex = mlen-alen+i;
+            BignumInt mword = mindex < 0 ? 0 : m[mindex];
+            BignumDblInt subtmp = (BignumDblInt)a[i] +
+                ((BignumInt)mword ^ BIGNUM_INT_MASK) + c;
+            a[i] = (BignumInt)subtmp;
+            c = subtmp >> BIGNUM_INT_BITS;
+        }
+    }
+    if (quot)
+        internal_add_shifted(quot, 1, 0);
+
+#ifdef DIVISION_DEBUG
+    {
+        int d;
+        printf("after final fixup, a=0x");
+        for (d = 0; d < alen; d++)
+            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
+        if (quot) {
+            printf(", quot=0x");
+            for (d = quot[0]; d > 0; d--)
+                printf("%0*llx", BIGNUM_INT_BITS/4,
+                       (unsigned long long)quot[d]);
+        }
+        printf("\n");
+    }
+#endif
 }
 
 /*
@@ -641,7 +953,8 @@ static void internal_mod(BignumInt *a, int alen,
 Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
 {
     BignumInt *a, *b, *n, *m, *scratch;
-    int mshift;
+    BignumInt recip;
+    int rshift;
     int mlen, scratchlen, i, j;
     Bignum base, result;
 
@@ -664,16 +977,6 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
     for (j = 0; j < mlen; j++)
        m[j] = mod[mod[0] - j];
 
-    /* Shift m left to make msb bit set */
-    for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
-       if ((m[0] << mshift) & BIGNUM_TOP_BIT)
-           break;
-    if (mshift) {
-       for (i = 0; i < mlen - 1; i++)
-           m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       m[mlen - 1] = m[mlen - 1] << mshift;
-    }
-
     /* Allocate n of size mlen, copy base to n */
     n = snewn(mlen, BignumInt);
     i = mlen - base[0];
@@ -704,14 +1007,26 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
        }
     }
 
+    /* Compute reciprocal of the top full word of the modulus */
+    {
+        BignumInt m0 = m[0];
+        rshift = bn_clz(m0);
+        if (rshift) {
+            m0 <<= rshift;
+            if (mlen > 1)
+                m0 |= m[1] >> (BIGNUM_INT_BITS - rshift);
+        }
+        recip = reciprocal_word(m0);
+    }
+
     /* Main computation */
     while (i < (int)exp[0]) {
        while (j >= 0) {
            internal_mul(a + mlen, a + mlen, b, mlen, scratch);
-           internal_mod(b, mlen * 2, m, mlen, NULL, 0);
+           internal_mod(b, mlen * 2, m, mlen, NULL, recip, rshift);
            if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) {
                internal_mul(b + mlen, n, a, mlen, scratch);
-               internal_mod(a, mlen * 2, m, mlen, NULL, 0);
+               internal_mod(a, mlen * 2, m, mlen, NULL, recip, rshift);
            } else {
                BignumInt *t;
                t = a;
@@ -724,16 +1039,6 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
        j = BIGNUM_INT_BITS-1;
     }
 
-    /* Fixup result in case the modulus was shifted */
-    if (mshift) {
-       for (i = mlen - 1; i < 2 * mlen - 1; i++)
-           a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       a[2 * mlen - 1] = a[2 * mlen - 1] << mshift;
-       internal_mod(a, mlen * 2, m, mlen, NULL, 0);
-       for (i = 2 * mlen - 1; i >= mlen; i--)
-           a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
-    }
-
     /* Copy result to buffer */
     result = newbn(mod[0]);
     for (i = 0; i < mlen; i++)
@@ -912,7 +1217,8 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
 Bignum modmul(Bignum p, Bignum q, Bignum mod)
 {
     BignumInt *a, *n, *m, *o, *scratch;
-    int mshift, scratchlen;
+    BignumInt recip;
+    int rshift, scratchlen;
     int pqlen, mlen, rlen, i, j;
     Bignum result;
 
@@ -929,16 +1235,6 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
     for (j = 0; j < mlen; j++)
        m[j] = mod[mod[0] - j];
 
-    /* Shift m left to make msb bit set */
-    for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
-       if ((m[0] << mshift) & BIGNUM_TOP_BIT)
-           break;
-    if (mshift) {
-       for (i = 0; i < mlen - 1; i++)
-           m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       m[mlen - 1] = m[mlen - 1] << mshift;
-    }
-
     pqlen = (p[0] > q[0] ? p[0] : q[0]);
 
     /*
@@ -971,19 +1267,21 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
     scratchlen = mul_compute_scratch(pqlen);
     scratch = snewn(scratchlen, BignumInt);
 
+    /* Compute reciprocal of the top full word of the modulus */
+    {
+        BignumInt m0 = m[0];
+        rshift = bn_clz(m0);
+        if (rshift) {
+            m0 <<= rshift;
+            if (mlen > 1)
+                m0 |= m[1] >> (BIGNUM_INT_BITS - rshift);
+        }
+        recip = reciprocal_word(m0);
+    }
+
     /* Main computation */
     internal_mul(n, o, a, pqlen, scratch);
-    internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
-
-    /* Fixup result in case the modulus was shifted */
-    if (mshift) {
-       for (i = 2 * pqlen - mlen - 1; i < 2 * pqlen - 1; i++)
-           a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       a[2 * pqlen - 1] = a[2 * pqlen - 1] << mshift;
-       internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
-       for (i = 2 * pqlen - 1; i >= 2 * pqlen - mlen; i--)
-           a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
-    }
+    internal_mod(a, pqlen * 2, m, mlen, NULL, recip, rshift);
 
     /* Copy result to buffer */
     rlen = (mlen < pqlen * 2 ? mlen : pqlen * 2);
@@ -1047,7 +1345,8 @@ Bignum modsub(const Bignum a, const Bignum b, const Bignum n)
 static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
 {
     BignumInt *n, *m;
-    int mshift;
+    BignumInt recip;
+    int rshift;
     int plen, mlen, i, j;
 
     /*
@@ -1063,16 +1362,6 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
     for (j = 0; j < mlen; j++)
        m[j] = mod[mod[0] - j];
 
-    /* Shift m left to make msb bit set */
-    for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
-       if ((m[0] << mshift) & BIGNUM_TOP_BIT)
-           break;
-    if (mshift) {
-       for (i = 0; i < mlen - 1; i++)
-           m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       m[mlen - 1] = m[mlen - 1] << mshift;
-    }
-
     plen = p[0];
     /* Ensure plen > mlen */
     if (plen <= mlen)
@@ -1085,19 +1374,21 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
     for (j = 1; j <= (int)p[0]; j++)
        n[plen - j] = p[j];
 
-    /* Main computation */
-    internal_mod(n, plen, m, mlen, quotient, mshift);
-
-    /* Fixup result in case the modulus was shifted */
-    if (mshift) {
-       for (i = plen - mlen - 1; i < plen - 1; i++)
-           n[i] = (n[i] << mshift) | (n[i + 1] >> (BIGNUM_INT_BITS - mshift));
-       n[plen - 1] = n[plen - 1] << mshift;
-       internal_mod(n, plen, m, mlen, quotient, 0);
-       for (i = plen - 1; i >= plen - mlen; i--)
-           n[i] = (n[i] >> mshift) | (n[i - 1] << (BIGNUM_INT_BITS - mshift));
+    /* Compute reciprocal of the top full word of the modulus */
+    {
+        BignumInt m0 = m[0];
+        rshift = bn_clz(m0);
+        if (rshift) {
+            m0 <<= rshift;
+            if (mlen > 1)
+                m0 |= m[1] >> (BIGNUM_INT_BITS - rshift);
+        }
+        recip = reciprocal_word(m0);
     }
 
+    /* Main computation */
+    internal_mod(n, plen, m, mlen, quotient, recip, rshift);
+
     /* Copy result to buffer */
     if (result) {
        for (i = 1; i <= (int)result[0]; i++) {
diff --git a/sshbn.h b/sshbn.h
index 9366f614ae4cde0cb198626180758db235e4cee1..fc0e6b5a807b60fe27ea76f1f9459b80882bb812 100644 (file)
--- a/sshbn.h
+++ b/sshbn.h
@@ -1,23 +1,8 @@
 /*
  * sshbn.h: the assorted conditional definitions of BignumInt and
- * multiply/divide macros used throughout the bignum code to treat
- * numbers as arrays of the most conveniently sized word for the
- * target machine. Exported so that other code (e.g. poly1305) can use
- * it too.
- */
-
-/*
- * Usage notes:
- *  * Do not call the DIVMOD_WORD macro with expressions such as array
- *    subscripts, as some implementations object to this (see below).
- *  * Note that none of the division methods below will cope if the
- *    quotient won't fit into BIGNUM_INT_BITS. Callers should be careful
- *    to avoid this case.
- *    If this condition occurs, in the case of the x86 DIV instruction,
- *    an overflow exception will occur, which (according to a correspondent)
- *    will manifest on Windows as something like
- *      0xC0000095: Integer overflow
- *    The C variant won't give the right answer, either.
+ * multiply macros used throughout the bignum code to treat numbers as
+ * arrays of the most conveniently sized word for the target machine.
+ * Exported so that other code (e.g. poly1305) can use it too.
  */
 
 #if defined __SIZEOF_INT128__
@@ -32,11 +17,6 @@ typedef __uint128_t BignumDblInt;
 #define BIGNUM_TOP_BIT   0x8000000000000000ULL
 #define BIGNUM_INT_BITS  64
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-#define DIVMOD_WORD(q, r, hi, lo, w) do { \
-    BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
-    q = n / w; \
-    r = n % w; \
-} while (0)
 #elif defined __GNUC__ && defined __i386__
 typedef unsigned long BignumInt;
 typedef unsigned long long BignumDblInt;
@@ -44,10 +24,6 @@ typedef unsigned long long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x80000000UL
 #define BIGNUM_INT_BITS  32
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-#define DIVMOD_WORD(q, r, hi, lo, w) \
-    __asm__("div %2" : \
-           "=d" (r), "=a" (q) : \
-           "r" (w), "d" (hi), "a" (lo))
 #elif defined _MSC_VER && defined _M_IX86
 typedef unsigned __int32 BignumInt;
 typedef unsigned __int64 BignumDblInt;
@@ -55,16 +31,6 @@ typedef unsigned __int64 BignumDblInt;
 #define BIGNUM_TOP_BIT   0x80000000UL
 #define BIGNUM_INT_BITS  32
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-/* Note: MASM interprets array subscripts in the macro arguments as
- * assembler syntax, which gives the wrong answer. Don't supply them.
- * <http://msdn2.microsoft.com/en-us/library/bf1dw62z.aspx> */
-#define DIVMOD_WORD(q, r, hi, lo, w) do { \
-    __asm mov edx, hi \
-    __asm mov eax, lo \
-    __asm div w \
-    __asm mov r, edx \
-    __asm mov q, eax \
-} while(0)
 #elif defined _LP64
 /* 64-bit architectures can do 32x32->64 chunks at a time */
 typedef unsigned int BignumInt;
@@ -73,11 +39,6 @@ typedef unsigned long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x80000000U
 #define BIGNUM_INT_BITS  32
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-#define DIVMOD_WORD(q, r, hi, lo, w) do { \
-    BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
-    q = n / w; \
-    r = n % w; \
-} while (0)
 #elif defined _LLP64
 /* 64-bit architectures in which unsigned long is 32 bits, not 64 */
 typedef unsigned long BignumInt;
@@ -86,11 +47,6 @@ typedef unsigned long long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x80000000UL
 #define BIGNUM_INT_BITS  32
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-#define DIVMOD_WORD(q, r, hi, lo, w) do { \
-    BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
-    q = n / w; \
-    r = n % w; \
-} while (0)
 #else
 /* Fallback for all other cases */
 typedef unsigned short BignumInt;
@@ -99,11 +55,6 @@ typedef unsigned long BignumDblInt;
 #define BIGNUM_TOP_BIT   0x8000U
 #define BIGNUM_INT_BITS  16
 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
-#define DIVMOD_WORD(q, r, hi, lo, w) do { \
-    BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
-    q = n / w; \
-    r = n % w; \
-} while (0)
 #endif
 
 #define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8)