]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - sshbn.c
Giant const-correctness patch of doom!
[PuTTY.git] / sshbn.c
diff --git a/sshbn.c b/sshbn.c
index a2067831baed6011928d826b3aa7e9c36638885a..a35bf31444a23e788e1fb09fa16e7d4d7b2861c9 100644 (file)
--- a/sshbn.c
+++ b/sshbn.c
@@ -6,6 +6,8 @@
 #include <assert.h>
 #include <stdlib.h>
 #include <string.h>
+#include <limits.h>
+#include <ctype.h>
 
 #include "misc.h"
 
@@ -101,6 +103,7 @@ typedef BignumInt *Bignum;
 
 BignumInt bnZero[1] = { 0 };
 BignumInt bnOne[2] = { 1, 1 };
+BignumInt bnTen[2] = { 1, 10 };
 
 /*
  * The Bignum format is an array of `BignumInt'. The first
@@ -116,11 +119,15 @@ BignumInt bnOne[2] = { 1, 1 };
  * nonzero.
  */
 
-Bignum Zero = bnZero, One = bnOne;
+Bignum Zero = bnZero, One = bnOne, Ten = bnTen;
 
 static Bignum newbn(int length)
 {
-    Bignum b = snewn(length + 1, BignumInt);
+    Bignum b;
+
+    assert(length >= 0 && length < INT_MAX / BIGNUM_INT_BITS);
+
+    b = snewn(length + 1, BignumInt);
     if (!b)
        abort();                       /* FIXME */
     memset(b, 0, (length + 1) * sizeof(*b));
@@ -154,7 +161,11 @@ void freebn(Bignum b)
 
 Bignum bn_power_2(int n)
 {
-    Bignum ret = newbn(n / BIGNUM_INT_BITS + 1);
+    Bignum ret;
+
+    assert(n >= 0);
+
+    ret = newbn(n / BIGNUM_INT_BITS + 1);
     bignum_set_bit(ret, n, 1);
     return ret;
 }
@@ -598,6 +609,7 @@ static void internal_add_shifted(BignumInt *number,
     addend = (BignumDblInt)n << bshift;
 
     while (addend) {
+        assert(word <= number[0]);
        addend += number[word];
        number[word] = (BignumInt) addend & BIGNUM_INT_MASK;
        addend >>= BIGNUM_INT_BITS;
@@ -1082,6 +1094,38 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
     return result;
 }
 
+Bignum modsub(const Bignum a, const Bignum b, const Bignum n)
+{
+    Bignum a1, b1, ret;
+
+    if (bignum_cmp(a, n) >= 0) a1 = bigmod(a, n);
+    else a1 = a;
+    if (bignum_cmp(b, n) >= 0) b1 = bigmod(b, n);
+    else b1 = b;
+
+    if (bignum_cmp(a1, b1) >= 0) /* a >= b */
+    {
+        ret = bigsub(a1, b1);
+    }
+    else
+    {
+        /* Handle going round the corner of the modulus without having
+         * negative support in Bignum */
+        Bignum tmp = bigsub(n, b1);
+        if (tmp) {
+            ret = bigadd(tmp, a1);
+            freebn(tmp);
+        } else {
+            ret = NULL;
+        }
+    }
+
+    if (a != a1) freebn(a1);
+    if (b != b1) freebn(b1);
+
+    return ret;
+}
+
 /*
  * Compute p % mod.
  * The most significant word of mod MUST be non-zero.
@@ -1174,6 +1218,8 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
     Bignum result;
     int w, i;
 
+    assert(nbytes >= 0 && nbytes < INT_MAX/8);
+
     w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
 
     result = newbn(w);
@@ -1189,6 +1235,85 @@ Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
     return result;
 }
 
+Bignum bignum_from_bytes_le(const unsigned char *data, int nbytes)
+{
+    Bignum result;
+    int w, i;
+
+    assert(nbytes >= 0 && nbytes < INT_MAX/8);
+
+    w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
+
+    result = newbn(w);
+    for (i = 1; i <= w; i++)
+        result[i] = 0;
+    for (i = 0; i < nbytes; ++i) {
+        unsigned char byte = *data++;
+        result[1 + i / BIGNUM_INT_BYTES] |= byte << (8*i % BIGNUM_INT_BITS);
+    }
+
+    while (result[0] > 1 && result[result[0]] == 0)
+        result[0]--;
+    return result;
+}
+
+Bignum bignum_from_decimal(const char *decimal)
+{
+    Bignum result = copybn(Zero);
+
+    while (*decimal) {
+        Bignum tmp, tmp2;
+
+        if (!isdigit((unsigned char)*decimal)) {
+            freebn(result);
+            return 0;
+        }
+
+        tmp = bigmul(result, Ten);
+        tmp2 = bignum_from_long(*decimal - '0');
+        result = bigadd(tmp, tmp2);
+        freebn(tmp);
+        freebn(tmp2);
+
+        decimal++;
+    }
+
+    return result;
+}
+
+Bignum bignum_random_in_range(const Bignum lower, const Bignum upper)
+{
+    Bignum ret = NULL;
+    unsigned char *bytes;
+    int upper_len = bignum_bitcount(upper);
+    int upper_bytes = upper_len / 8;
+    int upper_bits = upper_len % 8;
+    if (upper_bits) ++upper_bytes;
+
+    bytes = snewn(upper_bytes, unsigned char);
+    do {
+        int i;
+
+        if (ret) freebn(ret);
+
+        for (i = 0; i < upper_bytes; ++i)
+        {
+            bytes[i] = (unsigned char)random_byte();
+        }
+        /* Mask the top to reduce failure rate to 50/50 */
+        if (upper_bits)
+        {
+            bytes[i - 1] &= 0xFF >> (8 - upper_bits);
+        }
+
+        ret = bignum_from_bytes(bytes, upper_bytes);
+    } while (bignum_cmp(ret, lower) < 0 || bignum_cmp(ret, upper) > 0);
+    smemclr(bytes, upper_bytes);
+    sfree(bytes);
+
+    return ret;
+}
+
 /*
  * Read an SSH-1-format bignum from a data buffer. Return the number
  * of bytes consumed, or -1 if there wasn't enough data.
@@ -1250,7 +1375,7 @@ int ssh2_bignum_length(Bignum bn)
  */
 int bignum_byte(Bignum bn, int i)
 {
-    if (i >= (int)(BIGNUM_INT_BYTES * bn[0]))
+    if (i < 0 || i >= (int)(BIGNUM_INT_BYTES * bn[0]))
        return 0;                      /* beyond the end */
     else
        return (bn[i / BIGNUM_INT_BYTES + 1] >>
@@ -1262,7 +1387,7 @@ int bignum_byte(Bignum bn, int i)
  */
 int bignum_bit(Bignum bn, int i)
 {
-    if (i >= (int)(BIGNUM_INT_BITS * bn[0]))
+    if (i < 0 || i >= (int)(BIGNUM_INT_BITS * bn[0]))
        return 0;                      /* beyond the end */
     else
        return (bn[i / BIGNUM_INT_BITS + 1] >> (i % BIGNUM_INT_BITS)) & 1;
@@ -1273,7 +1398,7 @@ int bignum_bit(Bignum bn, int i)
  */
 void bignum_set_bit(Bignum bn, int bitnum, int value)
 {
-    if (bitnum >= (int)(BIGNUM_INT_BITS * bn[0]))
+    if (bitnum < 0 || bitnum >= (int)(BIGNUM_INT_BITS * bn[0]))
        abort();                       /* beyond the end */
     else {
        int v = bitnum / BIGNUM_INT_BITS + 1;
@@ -1309,7 +1434,18 @@ int ssh1_write_bignum(void *data, Bignum bn)
 int bignum_cmp(Bignum a, Bignum b)
 {
     int amax = a[0], bmax = b[0];
-    int i = (amax > bmax ? amax : bmax);
+    int i;
+
+    /* Annoyingly we have two representations of zero */
+    if (amax == 1 && a[amax] == 0)
+        amax = 0;
+    if (bmax == 1 && b[bmax] == 0)
+        bmax = 0;
+
+    assert(amax == 0 || a[amax] != 0);
+    assert(bmax == 0 || b[bmax] != 0);
+
+    i = (amax > bmax ? amax : bmax);
     while (i) {
        BignumInt aval = (i > amax ? 0 : a[i]);
        BignumInt bval = (i > bmax ? 0 : b[i]);
@@ -1331,6 +1467,8 @@ Bignum bignum_rshift(Bignum a, int shift)
     int i, shiftw, shiftb, shiftbb, bits;
     BignumInt ai, ai1;
 
+    assert(shift >= 0);
+
     bits = bignum_bitcount(a) - shift;
     ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
 
@@ -1350,6 +1488,45 @@ Bignum bignum_rshift(Bignum a, int shift)
     return ret;
 }
 
+/*
+ * Left-shift one bignum to form another.
+ */
+Bignum bignum_lshift(Bignum a, int shift)
+{
+    Bignum ret;
+    int bits, shiftWords, shiftBits;
+
+    assert(shift >= 0);
+
+    bits = bignum_bitcount(a) + shift;
+    ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
+    if (!ret) return NULL;
+
+    shiftWords = shift / BIGNUM_INT_BITS;
+    shiftBits = shift % BIGNUM_INT_BITS;
+
+    if (shiftBits == 0)
+    {
+        memcpy(&ret[1 + shiftWords], &a[1], sizeof(BignumInt) * a[0]);
+    }
+    else
+    {
+        int i;
+        BignumInt carry = 0;
+
+        /* Remember that Bignum[0] is length, so add 1 */
+        for (i = shiftWords + 1; i < ((int)a[0]) + shiftWords + 1; ++i)
+        {
+            BignumInt from = a[i - shiftWords];
+            ret[i] = (from << shiftBits) | carry;
+            carry = from >> (BIGNUM_INT_BITS - shiftBits);
+        }
+        if (carry) ret[i] = carry;
+    }
+
+    return ret;
+}
+
 /*
  * Non-modular multiplication and addition.
  */
@@ -1585,6 +1762,8 @@ Bignum bigdiv(Bignum a, Bignum b)
 {
     Bignum q = newbn(a[0]);
     bigdivmod(a, b, NULL, q);
+    while (q[0] > 1 && q[q[0]] == 0)
+        q[0]--;
     return q;
 }
 
@@ -1595,6 +1774,8 @@ Bignum bigmod(Bignum a, Bignum b)
 {
     Bignum r = newbn(b[0]);
     bigdivmod(a, b, r, NULL);
+    while (r[0] > 1 && r[r[0]] == 0)
+        r[0]--;
     return r;
 }
 
@@ -1642,6 +1823,10 @@ Bignum modinv(Bignum number, Bignum modulus)
              * Found a common factor between the inputs, so we cannot
              * return a modular inverse at all.
              */
+            freebn(b);
+            freebn(a);
+            freebn(xp);
+            freebn(x);
             return NULL;
         }
 
@@ -1650,6 +1835,8 @@ Bignum modinv(Bignum number, Bignum modulus)
        bigdivmod(a, b, t, q);
        while (t[0] > 1 && t[t[0]] == 0)
            t[0]--;
+       while (q[0] > 1 && q[q[0]] == 0)
+           q[0]--;
        freebn(a);
        a = b;
        b = t;
@@ -1783,7 +1970,7 @@ char *bignum_decimal(Bignum x)
  * testdata/bignum.py .
  */
 
-void modalfatalbox(char *p, ...)
+void modalfatalbox(const char *p, ...)
 {
     va_list ap;
     fprintf(stderr, "FATAL ERROR: ");