]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - sshbn.c
Expand comment on BUG_SSH2_OLDGEX to make it clear why it's necessary.
[PuTTY.git] / sshbn.c
diff --git a/sshbn.c b/sshbn.c
index a2067831baed6011928d826b3aa7e9c36638885a..b5c99125b2633645d83c06d28e842b8740fb5099 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -6,6 +6,7 @@
 #include <assert.h>
 #include <stdlib.h>
 #include <string.h>
+#include <limits.h>
 
 #include "misc.h"
 
@@ -120,7 +121,11 @@ Bignum Zero = bnZero, One = bnOne;
 
 static Bignum newbn(int length)
 {
-    Bignum b = snewn(length + 1, BignumInt);
+    Bignum b;
+
+    assert(length >= 0 && length < INT_MAX / BIGNUM_INT_BITS);
+
+    b = snewn(length + 1, BignumInt);
     if (!b)
        abort();                       /* FIXME */
     memset(b, 0, (length + 1) * sizeof(*b));
@@ -154,7 +159,11 @@ void freebn(Bignum b)
 
 Bignum bn_power_2(int n)
 {
-    Bignum ret = newbn(n / BIGNUM_INT_BITS + 1);
+    Bignum ret;
+
+    assert(n >= 0);
+
+    ret = newbn(n / BIGNUM_INT_BITS + 1);
     bignum_set_bit(ret, n, 1);
     return ret;
 }
@@ -598,6 +607,7 @@ static void internal_add_shifted(BignumInt *number,
     addend = (BignumDblInt)n << bshift;
 
     while (addend) {
+        assert(word <= number[0]);
        addend += number[word];
        number[word] = (BignumInt) addend & BIGNUM_INT_MASK;
        addend >>= BIGNUM_INT_BITS;
@@ -1082,6 +1092,38 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
     return result;
 }
 
+Bignum modsub(const Bignum a, const Bignum b, const Bignum n)
+{
+    Bignum a1, b1, ret;
+
+    if (bignum_cmp(a, n) >= 0) a1 = bigmod(a, n);
+    else a1 = a;
+    if (bignum_cmp(b, n) >= 0) b1 = bigmod(b, n);
+    else b1 = b;
+
+    if (bignum_cmp(a1, b1) >= 0) /* a >= b */
+    {
+        ret = bigsub(a1, b1);
+    }
+    else
+    {
+        /* Handle going round the corner of the modulus without having
+         * negative support in Bignum */
+        Bignum tmp = bigsub(n, b1);
+        if (tmp) {
+            ret = bigadd(tmp, a1);
+            freebn(tmp);
+        } else {
+            ret = NULL;
+        }
+    }
+
+    if (a != a1) freebn(a1);
+    if (b != b1) freebn(b1);
+
+    return ret;
+}
+
 /*
  * Compute p % mod.
  * The most significant word of mod MUST be non-zero.
@@ -1174,6 +1216,8 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
     Bignum result;
     int w, i;
 
+    assert(nbytes >= 0 && nbytes < INT_MAX/8);
+
     w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
 
     result = newbn(w);
@@ -1189,6 +1233,61 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
     return result;
 }
 
+Bignum bignum_from_bytes_le(const unsigned char *data, int nbytes)
+{
+    Bignum result;
+    int w, i;
+
+    assert(nbytes >= 0 && nbytes < INT_MAX/8);
+
+    w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
+
+    result = newbn(w);
+    for (i = 1; i <= w; i++)
+        result[i] = 0;
+    for (i = 0; i < nbytes; ++i) {
+        unsigned char byte = *data++;
+        result[1 + i / BIGNUM_INT_BYTES] |= byte << (8*i % BIGNUM_INT_BITS);
+    }
+
+    while (result[0] > 1 && result[result[0]] == 0)
+        result[0]--;
+    return result;
+}
+
+Bignum bignum_random_in_range(const Bignum lower, const Bignum upper)
+{
+    Bignum ret = NULL;
+    unsigned char *bytes;
+    int upper_len = bignum_bitcount(upper);
+    int upper_bytes = upper_len / 8;
+    int upper_bits = upper_len % 8;
+    if (upper_bits) ++upper_bytes;
+
+    bytes = snewn(upper_bytes, unsigned char);
+    do {
+        int i;
+
+        if (ret) freebn(ret);
+
+        for (i = 0; i < upper_bytes; ++i)
+        {
+            bytes[i] = (unsigned char)random_byte();
+        }
+        /* Mask the top to reduce failure rate to 50/50 */
+        if (upper_bits)
+        {
+            bytes[i - 1] &= 0xFF >> (8 - upper_bits);
+        }
+
+        ret = bignum_from_bytes(bytes, upper_bytes);
+    } while (bignum_cmp(ret, lower) < 0 || bignum_cmp(ret, upper) > 0);
+    smemclr(bytes, upper_bytes);
+    sfree(bytes);
+
+    return ret;
+}
+
 /*
  * Read an SSH-1-format bignum from a data buffer. Return the number
  * of bytes consumed, or -1 if there wasn't enough data.
@@ -1250,7 +1349,7 @@ int ssh2_bignum_length(Bignum bn)
  */
 int bignum_byte(Bignum bn, int i)
 {
-    if (i >= (int)(BIGNUM_INT_BYTES * bn[0]))
+    if (i < 0 || i >= (int)(BIGNUM_INT_BYTES * bn[0]))
        return 0;                      /* beyond the end */
     else
        return (bn[i / BIGNUM_INT_BYTES + 1] >>
@@ -1262,7 +1361,7 @@ int bignum_byte(Bignum bn, int i)
  */
 int bignum_bit(Bignum bn, int i)
 {
-    if (i >= (int)(BIGNUM_INT_BITS * bn[0]))
+    if (i < 0 || i >= (int)(BIGNUM_INT_BITS * bn[0]))
        return 0;                      /* beyond the end */
     else
        return (bn[i / BIGNUM_INT_BITS + 1] >> (i % BIGNUM_INT_BITS)) & 1;
@@ -1273,7 +1372,7 @@ int bignum_bit(Bignum bn, int i)
  */
 void bignum_set_bit(Bignum bn, int bitnum, int value)
 {
-    if (bitnum >= (int)(BIGNUM_INT_BITS * bn[0]))
+    if (bitnum < 0 || bitnum >= (int)(BIGNUM_INT_BITS * bn[0]))
        abort();                       /* beyond the end */
     else {
        int v = bitnum / BIGNUM_INT_BITS + 1;
@@ -1309,7 +1408,18 @@ int ssh1_write_bignum(void *data, Bignum bn)
 int bignum_cmp(Bignum a, Bignum b)
 {
     int amax = a[0], bmax = b[0];
-    int i = (amax > bmax ? amax : bmax);
+    int i;
+
+    /* Annoyingly we have two representations of zero */
+    if (amax == 1 && a[amax] == 0)
+        amax = 0;
+    if (bmax == 1 && b[bmax] == 0)
+        bmax = 0;
+
+    assert(amax == 0 || a[amax] != 0);
+    assert(bmax == 0 || b[bmax] != 0);
+
+    i = (amax > bmax ? amax : bmax);
     while (i) {
        BignumInt aval = (i > amax ? 0 : a[i]);
        BignumInt bval = (i > bmax ? 0 : b[i]);
@@ -1331,6 +1441,8 @@ Bignum bignum_rshift(Bignum a, int shift)
     int i, shiftw, shiftb, shiftbb, bits;
     BignumInt ai, ai1;
 
+    assert(shift >= 0);
+
     bits = bignum_bitcount(a) - shift;
     ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
 
@@ -1350,6 +1462,45 @@ Bignum bignum_rshift(Bignum a, int shift)
     return ret;
 }
 
+/*
+ * Left-shift one bignum to form another.
+ */
+Bignum bignum_lshift(Bignum a, int shift)
+{
+    Bignum ret;
+    int bits, shiftWords, shiftBits;
+
+    assert(shift >= 0);
+
+    bits = bignum_bitcount(a) + shift;
+    ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
+    if (!ret) return NULL;
+
+    shiftWords = shift / BIGNUM_INT_BITS;
+    shiftBits = shift % BIGNUM_INT_BITS;
+
+    if (shiftBits == 0)
+    {
+        memcpy(&ret[1 + shiftWords], &a[1], sizeof(BignumInt) * a[0]);
+    }
+    else
+    {
+        int i;
+        BignumInt carry = 0;
+
+        /* Remember that Bignum[0] is length, so add 1 */
+        for (i = shiftWords + 1; i < ((int)a[0]) + shiftWords + 1; ++i)
+        {
+            BignumInt from = a[i - shiftWords];
+            ret[i] = (from << shiftBits) | carry;
+            carry = from >> (BIGNUM_INT_BITS - shiftBits);
+        }
+        if (carry) ret[i] = carry;
+    }
+
+    return ret;
+}
+
 /*
  * Non-modular multiplication and addition.
  */
@@ -1585,6 +1736,8 @@ Bignum bigdiv(Bignum a, Bignum b)
 {
     Bignum q = newbn(a[0]);
     bigdivmod(a, b, NULL, q);
+    while (q[0] > 1 && q[q[0]] == 0)
+        q[0]--;
     return q;
 }
 
@@ -1595,6 +1748,8 @@ Bignum bigmod(Bignum a, Bignum b)
 {
     Bignum r = newbn(b[0]);
     bigdivmod(a, b, r, NULL);
+    while (r[0] > 1 && r[r[0]] == 0)
+        r[0]--;
     return r;
 }
 
@@ -1642,6 +1797,10 @@ Bignum modinv(Bignum number, Bignum modulus)
              * Found a common factor between the inputs, so we cannot
              * return a modular inverse at all.
              */
+            freebn(b);
+            freebn(a);
+            freebn(xp);
+            freebn(x);
             return NULL;
         }
 
@@ -1650,6 +1809,8 @@ Bignum modinv(Bignum number, Bignum modulus)
        bigdivmod(a, b, t, q);
        while (t[0] > 1 && t[t[0]] == 0)
            t[0]--;
+       while (q[0] > 1 && q[q[0]] == 0)
+           q[0]--;
        freebn(a);
        a = b;
        b = t;