]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshbn.c
Merged SSH1 robustness changes from 0.55 release branch on to trunk.
[PuTTY.git] / sshbn.c
1 /*
2  * Bignum routines for RSA and DH and stuff.
3  */
4
5 #include <stdio.h>
6 #include <assert.h>
7 #include <stdlib.h>
8 #include <string.h>
9
10 #include "misc.h"
11
12 #if defined __GNUC__ && defined __i386__
13 typedef unsigned long BignumInt;
14 typedef unsigned long long BignumDblInt;
15 #define BIGNUM_INT_MASK  0xFFFFFFFFUL
16 #define BIGNUM_TOP_BIT   0x80000000UL
17 #define BIGNUM_INT_BITS  32
18 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
19 #define DIVMOD_WORD(q, r, hi, lo, w) \
20     __asm__("div %2" : \
21             "=d" (r), "=a" (q) : \
22             "r" (w), "d" (hi), "a" (lo))
23 #else
24 typedef unsigned short BignumInt;
25 typedef unsigned long BignumDblInt;
26 #define BIGNUM_INT_MASK  0xFFFFU
27 #define BIGNUM_TOP_BIT   0x8000U
28 #define BIGNUM_INT_BITS  16
29 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2)
30 #define DIVMOD_WORD(q, r, hi, lo, w) do { \
31     BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \
32     q = n / w; \
33     r = n % w; \
34 } while (0)
35 #endif
36
37 #define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8)
38
39 #define BIGNUM_INTERNAL
40 typedef BignumInt *Bignum;
41
42 #include "ssh.h"
43
44 BignumInt bnZero[1] = { 0 };
45 BignumInt bnOne[2] = { 1, 1 };
46
47 /*
48  * The Bignum format is an array of `BignumInt'. The first
49  * element of the array counts the remaining elements. The
50  * remaining elements express the actual number, base 2^BIGNUM_INT_BITS, _least_
51  * significant digit first. (So it's trivial to extract the bit
52  * with value 2^n for any n.)
53  *
54  * All Bignums in this module are positive. Negative numbers must
55  * be dealt with outside it.
56  *
57  * INVARIANT: the most significant word of any Bignum must be
58  * nonzero.
59  */
60
61 Bignum Zero = bnZero, One = bnOne;
62
63 static Bignum newbn(int length)
64 {
65     Bignum b = snewn(length + 1, BignumInt);
66     if (!b)
67         abort();                       /* FIXME */
68     memset(b, 0, (length + 1) * sizeof(*b));
69     b[0] = length;
70     return b;
71 }
72
73 void bn_restore_invariant(Bignum b)
74 {
75     while (b[0] > 1 && b[b[0]] == 0)
76         b[0]--;
77 }
78
79 Bignum copybn(Bignum orig)
80 {
81     Bignum b = snewn(orig[0] + 1, BignumInt);
82     if (!b)
83         abort();                       /* FIXME */
84     memcpy(b, orig, (orig[0] + 1) * sizeof(*b));
85     return b;
86 }
87
88 void freebn(Bignum b)
89 {
90     /*
91      * Burn the evidence, just in case.
92      */
93     memset(b, 0, sizeof(b[0]) * (b[0] + 1));
94     sfree(b);
95 }
96
97 Bignum bn_power_2(int n)
98 {
99     Bignum ret = newbn(n / BIGNUM_INT_BITS + 1);
100     bignum_set_bit(ret, n, 1);
101     return ret;
102 }
103
104 /*
105  * Compute c = a * b.
106  * Input is in the first len words of a and b.
107  * Result is returned in the first 2*len words of c.
108  */
109 static void internal_mul(BignumInt *a, BignumInt *b,
110                          BignumInt *c, int len)
111 {
112     int i, j;
113     BignumDblInt t;
114
115     for (j = 0; j < 2 * len; j++)
116         c[j] = 0;
117
118     for (i = len - 1; i >= 0; i--) {
119         t = 0;
120         for (j = len - 1; j >= 0; j--) {
121             t += MUL_WORD(a[i], (BignumDblInt) b[j]);
122             t += (BignumDblInt) c[i + j + 1];
123             c[i + j + 1] = (BignumInt) t;
124             t = t >> BIGNUM_INT_BITS;
125         }
126         c[i] = (BignumInt) t;
127     }
128 }
129
130 static void internal_add_shifted(BignumInt *number,
131                                  unsigned n, int shift)
132 {
133     int word = 1 + (shift / BIGNUM_INT_BITS);
134     int bshift = shift % BIGNUM_INT_BITS;
135     BignumDblInt addend;
136
137     addend = (BignumDblInt)n << bshift;
138
139     while (addend) {
140         addend += number[word];
141         number[word] = (BignumInt) addend & BIGNUM_INT_MASK;
142         addend >>= BIGNUM_INT_BITS;
143         word++;
144     }
145 }
146
147 /*
148  * Compute a = a % m.
149  * Input in first alen words of a and first mlen words of m.
150  * Output in first alen words of a
151  * (of which first alen-mlen words will be zero).
152  * The MSW of m MUST have its high bit set.
153  * Quotient is accumulated in the `quotient' array, which is a Bignum
154  * rather than the internal bigendian format. Quotient parts are shifted
155  * left by `qshift' before adding into quot.
156  */
157 static void internal_mod(BignumInt *a, int alen,
158                          BignumInt *m, int mlen,
159                          BignumInt *quot, int qshift)
160 {
161     BignumInt m0, m1;
162     unsigned int h;
163     int i, k;
164
165     m0 = m[0];
166     if (mlen > 1)
167         m1 = m[1];
168     else
169         m1 = 0;
170
171     for (i = 0; i <= alen - mlen; i++) {
172         BignumDblInt t;
173         unsigned int q, r, c, ai1;
174
175         if (i == 0) {
176             h = 0;
177         } else {
178             h = a[i - 1];
179             a[i - 1] = 0;
180         }
181
182         if (i == alen - 1)
183             ai1 = 0;
184         else
185             ai1 = a[i + 1];
186
187         /* Find q = h:a[i] / m0 */
188         DIVMOD_WORD(q, r, h, a[i], m0);
189
190         /* Refine our estimate of q by looking at
191            h:a[i]:a[i+1] / m0:m1 */
192         t = MUL_WORD(m1, q);
193         if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) {
194             q--;
195             t -= m1;
196             r = (r + m0) & BIGNUM_INT_MASK;     /* overflow? */
197             if (r >= (BignumDblInt) m0 &&
198                 t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) q--;
199         }
200
201         /* Subtract q * m from a[i...] */
202         c = 0;
203         for (k = mlen - 1; k >= 0; k--) {
204             t = MUL_WORD(q, m[k]);
205             t += c;
206             c = t >> BIGNUM_INT_BITS;
207             if ((BignumInt) t > a[i + k])
208                 c++;
209             a[i + k] -= (BignumInt) t;
210         }
211
212         /* Add back m in case of borrow */
213         if (c != h) {
214             t = 0;
215             for (k = mlen - 1; k >= 0; k--) {
216                 t += m[k];
217                 t += a[i + k];
218                 a[i + k] = (BignumInt) t;
219                 t = t >> BIGNUM_INT_BITS;
220             }
221             q--;
222         }
223         if (quot)
224             internal_add_shifted(quot, q, qshift + BIGNUM_INT_BITS * (alen - mlen - i));
225     }
226 }
227
228 /*
229  * Compute (base ^ exp) % mod.
230  */
231 Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
232 {
233     BignumInt *a, *b, *n, *m;
234     int mshift;
235     int mlen, i, j;
236     Bignum base, result;
237
238     /*
239      * The most significant word of mod needs to be non-zero. It
240      * should already be, but let's make sure.
241      */
242     assert(mod[mod[0]] != 0);
243
244     /*
245      * Make sure the base is smaller than the modulus, by reducing
246      * it modulo the modulus if not.
247      */
248     base = bigmod(base_in, mod);
249
250     /* Allocate m of size mlen, copy mod to m */
251     /* We use big endian internally */
252     mlen = mod[0];
253     m = snewn(mlen, BignumInt);
254     for (j = 0; j < mlen; j++)
255         m[j] = mod[mod[0] - j];
256
257     /* Shift m left to make msb bit set */
258     for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
259         if ((m[0] << mshift) & BIGNUM_TOP_BIT)
260             break;
261     if (mshift) {
262         for (i = 0; i < mlen - 1; i++)
263             m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
264         m[mlen - 1] = m[mlen - 1] << mshift;
265     }
266
267     /* Allocate n of size mlen, copy base to n */
268     n = snewn(mlen, BignumInt);
269     i = mlen - base[0];
270     for (j = 0; j < i; j++)
271         n[j] = 0;
272     for (j = 0; j < base[0]; j++)
273         n[i + j] = base[base[0] - j];
274
275     /* Allocate a and b of size 2*mlen. Set a = 1 */
276     a = snewn(2 * mlen, BignumInt);
277     b = snewn(2 * mlen, BignumInt);
278     for (i = 0; i < 2 * mlen; i++)
279         a[i] = 0;
280     a[2 * mlen - 1] = 1;
281
282     /* Skip leading zero bits of exp. */
283     i = 0;
284     j = BIGNUM_INT_BITS-1;
285     while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
286         j--;
287         if (j < 0) {
288             i++;
289             j = BIGNUM_INT_BITS-1;
290         }
291     }
292
293     /* Main computation */
294     while (i < exp[0]) {
295         while (j >= 0) {
296             internal_mul(a + mlen, a + mlen, b, mlen);
297             internal_mod(b, mlen * 2, m, mlen, NULL, 0);
298             if ((exp[exp[0] - i] & (1 << j)) != 0) {
299                 internal_mul(b + mlen, n, a, mlen);
300                 internal_mod(a, mlen * 2, m, mlen, NULL, 0);
301             } else {
302                 BignumInt *t;
303                 t = a;
304                 a = b;
305                 b = t;
306             }
307             j--;
308         }
309         i++;
310         j = BIGNUM_INT_BITS-1;
311     }
312
313     /* Fixup result in case the modulus was shifted */
314     if (mshift) {
315         for (i = mlen - 1; i < 2 * mlen - 1; i++)
316             a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
317         a[2 * mlen - 1] = a[2 * mlen - 1] << mshift;
318         internal_mod(a, mlen * 2, m, mlen, NULL, 0);
319         for (i = 2 * mlen - 1; i >= mlen; i--)
320             a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
321     }
322
323     /* Copy result to buffer */
324     result = newbn(mod[0]);
325     for (i = 0; i < mlen; i++)
326         result[result[0] - i] = a[i + mlen];
327     while (result[0] > 1 && result[result[0]] == 0)
328         result[0]--;
329
330     /* Free temporary arrays */
331     for (i = 0; i < 2 * mlen; i++)
332         a[i] = 0;
333     sfree(a);
334     for (i = 0; i < 2 * mlen; i++)
335         b[i] = 0;
336     sfree(b);
337     for (i = 0; i < mlen; i++)
338         m[i] = 0;
339     sfree(m);
340     for (i = 0; i < mlen; i++)
341         n[i] = 0;
342     sfree(n);
343
344     freebn(base);
345
346     return result;
347 }
348
349 /*
350  * Compute (p * q) % mod.
351  * The most significant word of mod MUST be non-zero.
352  * We assume that the result array is the same size as the mod array.
353  */
354 Bignum modmul(Bignum p, Bignum q, Bignum mod)
355 {
356     BignumInt *a, *n, *m, *o;
357     int mshift;
358     int pqlen, mlen, rlen, i, j;
359     Bignum result;
360
361     /* Allocate m of size mlen, copy mod to m */
362     /* We use big endian internally */
363     mlen = mod[0];
364     m = snewn(mlen, BignumInt);
365     for (j = 0; j < mlen; j++)
366         m[j] = mod[mod[0] - j];
367
368     /* Shift m left to make msb bit set */
369     for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
370         if ((m[0] << mshift) & BIGNUM_TOP_BIT)
371             break;
372     if (mshift) {
373         for (i = 0; i < mlen - 1; i++)
374             m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
375         m[mlen - 1] = m[mlen - 1] << mshift;
376     }
377
378     pqlen = (p[0] > q[0] ? p[0] : q[0]);
379
380     /* Allocate n of size pqlen, copy p to n */
381     n = snewn(pqlen, BignumInt);
382     i = pqlen - p[0];
383     for (j = 0; j < i; j++)
384         n[j] = 0;
385     for (j = 0; j < p[0]; j++)
386         n[i + j] = p[p[0] - j];
387
388     /* Allocate o of size pqlen, copy q to o */
389     o = snewn(pqlen, BignumInt);
390     i = pqlen - q[0];
391     for (j = 0; j < i; j++)
392         o[j] = 0;
393     for (j = 0; j < q[0]; j++)
394         o[i + j] = q[q[0] - j];
395
396     /* Allocate a of size 2*pqlen for result */
397     a = snewn(2 * pqlen, BignumInt);
398
399     /* Main computation */
400     internal_mul(n, o, a, pqlen);
401     internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
402
403     /* Fixup result in case the modulus was shifted */
404     if (mshift) {
405         for (i = 2 * pqlen - mlen - 1; i < 2 * pqlen - 1; i++)
406             a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
407         a[2 * pqlen - 1] = a[2 * pqlen - 1] << mshift;
408         internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
409         for (i = 2 * pqlen - 1; i >= 2 * pqlen - mlen; i--)
410             a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
411     }
412
413     /* Copy result to buffer */
414     rlen = (mlen < pqlen * 2 ? mlen : pqlen * 2);
415     result = newbn(rlen);
416     for (i = 0; i < rlen; i++)
417         result[result[0] - i] = a[i + 2 * pqlen - rlen];
418     while (result[0] > 1 && result[result[0]] == 0)
419         result[0]--;
420
421     /* Free temporary arrays */
422     for (i = 0; i < 2 * pqlen; i++)
423         a[i] = 0;
424     sfree(a);
425     for (i = 0; i < mlen; i++)
426         m[i] = 0;
427     sfree(m);
428     for (i = 0; i < pqlen; i++)
429         n[i] = 0;
430     sfree(n);
431     for (i = 0; i < pqlen; i++)
432         o[i] = 0;
433     sfree(o);
434
435     return result;
436 }
437
438 /*
439  * Compute p % mod.
440  * The most significant word of mod MUST be non-zero.
441  * We assume that the result array is the same size as the mod array.
442  * We optionally write out a quotient if `quotient' is non-NULL.
443  * We can avoid writing out the result if `result' is NULL.
444  */
445 static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
446 {
447     BignumInt *n, *m;
448     int mshift;
449     int plen, mlen, i, j;
450
451     /* Allocate m of size mlen, copy mod to m */
452     /* We use big endian internally */
453     mlen = mod[0];
454     m = snewn(mlen, BignumInt);
455     for (j = 0; j < mlen; j++)
456         m[j] = mod[mod[0] - j];
457
458     /* Shift m left to make msb bit set */
459     for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
460         if ((m[0] << mshift) & BIGNUM_TOP_BIT)
461             break;
462     if (mshift) {
463         for (i = 0; i < mlen - 1; i++)
464             m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
465         m[mlen - 1] = m[mlen - 1] << mshift;
466     }
467
468     plen = p[0];
469     /* Ensure plen > mlen */
470     if (plen <= mlen)
471         plen = mlen + 1;
472
473     /* Allocate n of size plen, copy p to n */
474     n = snewn(plen, BignumInt);
475     for (j = 0; j < plen; j++)
476         n[j] = 0;
477     for (j = 1; j <= p[0]; j++)
478         n[plen - j] = p[j];
479
480     /* Main computation */
481     internal_mod(n, plen, m, mlen, quotient, mshift);
482
483     /* Fixup result in case the modulus was shifted */
484     if (mshift) {
485         for (i = plen - mlen - 1; i < plen - 1; i++)
486             n[i] = (n[i] << mshift) | (n[i + 1] >> (BIGNUM_INT_BITS - mshift));
487         n[plen - 1] = n[plen - 1] << mshift;
488         internal_mod(n, plen, m, mlen, quotient, 0);
489         for (i = plen - 1; i >= plen - mlen; i--)
490             n[i] = (n[i] >> mshift) | (n[i - 1] << (BIGNUM_INT_BITS - mshift));
491     }
492
493     /* Copy result to buffer */
494     if (result) {
495         for (i = 1; i <= result[0]; i++) {
496             int j = plen - i;
497             result[i] = j >= 0 ? n[j] : 0;
498         }
499     }
500
501     /* Free temporary arrays */
502     for (i = 0; i < mlen; i++)
503         m[i] = 0;
504     sfree(m);
505     for (i = 0; i < plen; i++)
506         n[i] = 0;
507     sfree(n);
508 }
509
510 /*
511  * Decrement a number.
512  */
513 void decbn(Bignum bn)
514 {
515     int i = 1;
516     while (i < bn[0] && bn[i] == 0)
517         bn[i++] = BIGNUM_INT_MASK;
518     bn[i]--;
519 }
520
521 Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
522 {
523     Bignum result;
524     int w, i;
525
526     w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
527
528     result = newbn(w);
529     for (i = 1; i <= w; i++)
530         result[i] = 0;
531     for (i = nbytes; i--;) {
532         unsigned char byte = *data++;
533         result[1 + i / BIGNUM_INT_BYTES] |= byte << (8*i % BIGNUM_INT_BITS);
534     }
535
536     while (result[0] > 1 && result[result[0]] == 0)
537         result[0]--;
538     return result;
539 }
540
541 /*
542  * Read an ssh1-format bignum from a data buffer. Return the number
543  * of bytes consumed, or -1 if there wasn't enough data.
544  */
545 int ssh1_read_bignum(const unsigned char *data, int len, Bignum * result)
546 {
547     const unsigned char *p = data;
548     int i;
549     int w, b;
550
551     if (len < 2)
552         return -1;
553
554     w = 0;
555     for (i = 0; i < 2; i++)
556         w = (w << 8) + *p++;
557     b = (w + 7) / 8;                   /* bits -> bytes */
558
559     if (len < b+2)
560         return -1;
561
562     if (!result)                       /* just return length */
563         return b + 2;
564
565     *result = bignum_from_bytes(p, b);
566
567     return p + b - data;
568 }
569
570 /*
571  * Return the bit count of a bignum, for ssh1 encoding.
572  */
573 int bignum_bitcount(Bignum bn)
574 {
575     int bitcount = bn[0] * BIGNUM_INT_BITS - 1;
576     while (bitcount >= 0
577            && (bn[bitcount / BIGNUM_INT_BITS + 1] >> (bitcount % BIGNUM_INT_BITS)) == 0) bitcount--;
578     return bitcount + 1;
579 }
580
581 /*
582  * Return the byte length of a bignum when ssh1 encoded.
583  */
584 int ssh1_bignum_length(Bignum bn)
585 {
586     return 2 + (bignum_bitcount(bn) + 7) / 8;
587 }
588
589 /*
590  * Return the byte length of a bignum when ssh2 encoded.
591  */
592 int ssh2_bignum_length(Bignum bn)
593 {
594     return 4 + (bignum_bitcount(bn) + 8) / 8;
595 }
596
597 /*
598  * Return a byte from a bignum; 0 is least significant, etc.
599  */
600 int bignum_byte(Bignum bn, int i)
601 {
602     if (i >= BIGNUM_INT_BYTES * bn[0])
603         return 0;                      /* beyond the end */
604     else
605         return (bn[i / BIGNUM_INT_BYTES + 1] >>
606                 ((i % BIGNUM_INT_BYTES)*8)) & 0xFF;
607 }
608
609 /*
610  * Return a bit from a bignum; 0 is least significant, etc.
611  */
612 int bignum_bit(Bignum bn, int i)
613 {
614     if (i >= BIGNUM_INT_BITS * bn[0])
615         return 0;                      /* beyond the end */
616     else
617         return (bn[i / BIGNUM_INT_BITS + 1] >> (i % BIGNUM_INT_BITS)) & 1;
618 }
619
620 /*
621  * Set a bit in a bignum; 0 is least significant, etc.
622  */
623 void bignum_set_bit(Bignum bn, int bitnum, int value)
624 {
625     if (bitnum >= BIGNUM_INT_BITS * bn[0])
626         abort();                       /* beyond the end */
627     else {
628         int v = bitnum / BIGNUM_INT_BITS + 1;
629         int mask = 1 << (bitnum % BIGNUM_INT_BITS);
630         if (value)
631             bn[v] |= mask;
632         else
633             bn[v] &= ~mask;
634     }
635 }
636
637 /*
638  * Write a ssh1-format bignum into a buffer. It is assumed the
639  * buffer is big enough. Returns the number of bytes used.
640  */
641 int ssh1_write_bignum(void *data, Bignum bn)
642 {
643     unsigned char *p = data;
644     int len = ssh1_bignum_length(bn);
645     int i;
646     int bitc = bignum_bitcount(bn);
647
648     *p++ = (bitc >> 8) & 0xFF;
649     *p++ = (bitc) & 0xFF;
650     for (i = len - 2; i--;)
651         *p++ = bignum_byte(bn, i);
652     return len;
653 }
654
655 /*
656  * Compare two bignums. Returns like strcmp.
657  */
658 int bignum_cmp(Bignum a, Bignum b)
659 {
660     int amax = a[0], bmax = b[0];
661     int i = (amax > bmax ? amax : bmax);
662     while (i) {
663         BignumInt aval = (i > amax ? 0 : a[i]);
664         BignumInt bval = (i > bmax ? 0 : b[i]);
665         if (aval < bval)
666             return -1;
667         if (aval > bval)
668             return +1;
669         i--;
670     }
671     return 0;
672 }
673
674 /*
675  * Right-shift one bignum to form another.
676  */
677 Bignum bignum_rshift(Bignum a, int shift)
678 {
679     Bignum ret;
680     int i, shiftw, shiftb, shiftbb, bits;
681     BignumInt ai, ai1;
682
683     bits = bignum_bitcount(a) - shift;
684     ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
685
686     if (ret) {
687         shiftw = shift / BIGNUM_INT_BITS;
688         shiftb = shift % BIGNUM_INT_BITS;
689         shiftbb = BIGNUM_INT_BITS - shiftb;
690
691         ai1 = a[shiftw + 1];
692         for (i = 1; i <= ret[0]; i++) {
693             ai = ai1;
694             ai1 = (i + shiftw + 1 <= a[0] ? a[i + shiftw + 1] : 0);
695             ret[i] = ((ai >> shiftb) | (ai1 << shiftbb)) & BIGNUM_INT_MASK;
696         }
697     }
698
699     return ret;
700 }
701
702 /*
703  * Non-modular multiplication and addition.
704  */
705 Bignum bigmuladd(Bignum a, Bignum b, Bignum addend)
706 {
707     int alen = a[0], blen = b[0];
708     int mlen = (alen > blen ? alen : blen);
709     int rlen, i, maxspot;
710     BignumInt *workspace;
711     Bignum ret;
712
713     /* mlen space for a, mlen space for b, 2*mlen for result */
714     workspace = snewn(mlen * 4, BignumInt);
715     for (i = 0; i < mlen; i++) {
716         workspace[0 * mlen + i] = (mlen - i <= a[0] ? a[mlen - i] : 0);
717         workspace[1 * mlen + i] = (mlen - i <= b[0] ? b[mlen - i] : 0);
718     }
719
720     internal_mul(workspace + 0 * mlen, workspace + 1 * mlen,
721                  workspace + 2 * mlen, mlen);
722
723     /* now just copy the result back */
724     rlen = alen + blen + 1;
725     if (addend && rlen <= addend[0])
726         rlen = addend[0] + 1;
727     ret = newbn(rlen);
728     maxspot = 0;
729     for (i = 1; i <= ret[0]; i++) {
730         ret[i] = (i <= 2 * mlen ? workspace[4 * mlen - i] : 0);
731         if (ret[i] != 0)
732             maxspot = i;
733     }
734     ret[0] = maxspot;
735
736     /* now add in the addend, if any */
737     if (addend) {
738         BignumDblInt carry = 0;
739         for (i = 1; i <= rlen; i++) {
740             carry += (i <= ret[0] ? ret[i] : 0);
741             carry += (i <= addend[0] ? addend[i] : 0);
742             ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
743             carry >>= BIGNUM_INT_BITS;
744             if (ret[i] != 0 && i > maxspot)
745                 maxspot = i;
746         }
747     }
748     ret[0] = maxspot;
749
750     sfree(workspace);
751     return ret;
752 }
753
754 /*
755  * Non-modular multiplication.
756  */
757 Bignum bigmul(Bignum a, Bignum b)
758 {
759     return bigmuladd(a, b, NULL);
760 }
761
762 /*
763  * Create a bignum which is the bitmask covering another one. That
764  * is, the smallest integer which is >= N and is also one less than
765  * a power of two.
766  */
767 Bignum bignum_bitmask(Bignum n)
768 {
769     Bignum ret = copybn(n);
770     int i;
771     BignumInt j;
772
773     i = ret[0];
774     while (n[i] == 0 && i > 0)
775         i--;
776     if (i <= 0)
777         return ret;                    /* input was zero */
778     j = 1;
779     while (j < n[i])
780         j = 2 * j + 1;
781     ret[i] = j;
782     while (--i > 0)
783         ret[i] = BIGNUM_INT_MASK;
784     return ret;
785 }
786
787 /*
788  * Convert a (max 32-bit) long into a bignum.
789  */
790 Bignum bignum_from_long(unsigned long nn)
791 {
792     Bignum ret;
793     BignumDblInt n = nn;
794
795     ret = newbn(3);
796     ret[1] = (BignumInt)(n & BIGNUM_INT_MASK);
797     ret[2] = (BignumInt)((n >> BIGNUM_INT_BITS) & BIGNUM_INT_MASK);
798     ret[3] = 0;
799     ret[0] = (ret[2]  ? 2 : 1);
800     return ret;
801 }
802
803 /*
804  * Add a long to a bignum.
805  */
806 Bignum bignum_add_long(Bignum number, unsigned long addendx)
807 {
808     Bignum ret = newbn(number[0] + 1);
809     int i, maxspot = 0;
810     BignumDblInt carry = 0, addend = addendx;
811
812     for (i = 1; i <= ret[0]; i++) {
813         carry += addend & BIGNUM_INT_MASK;
814         carry += (i <= number[0] ? number[i] : 0);
815         addend >>= BIGNUM_INT_BITS;
816         ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
817         carry >>= BIGNUM_INT_BITS;
818         if (ret[i] != 0)
819             maxspot = i;
820     }
821     ret[0] = maxspot;
822     return ret;
823 }
824
825 /*
826  * Compute the residue of a bignum, modulo a (max 16-bit) short.
827  */
828 unsigned short bignum_mod_short(Bignum number, unsigned short modulus)
829 {
830     BignumDblInt mod, r;
831     int i;
832
833     r = 0;
834     mod = modulus;
835     for (i = number[0]; i > 0; i--)
836         r = (r * (BIGNUM_TOP_BIT % mod) * 2 + number[i] % mod) % mod;
837     return (unsigned short) r;
838 }
839
840 #ifdef DEBUG
841 void diagbn(char *prefix, Bignum md)
842 {
843     int i, nibbles, morenibbles;
844     static const char hex[] = "0123456789ABCDEF";
845
846     debug(("%s0x", prefix ? prefix : ""));
847
848     nibbles = (3 + bignum_bitcount(md)) / 4;
849     if (nibbles < 1)
850         nibbles = 1;
851     morenibbles = 4 * md[0] - nibbles;
852     for (i = 0; i < morenibbles; i++)
853         debug(("-"));
854     for (i = nibbles; i--;)
855         debug(("%c",
856                hex[(bignum_byte(md, i / 2) >> (4 * (i % 2))) & 0xF]));
857
858     if (prefix)
859         debug(("\n"));
860 }
861 #endif
862
863 /*
864  * Simple division.
865  */
866 Bignum bigdiv(Bignum a, Bignum b)
867 {
868     Bignum q = newbn(a[0]);
869     bigdivmod(a, b, NULL, q);
870     return q;
871 }
872
873 /*
874  * Simple remainder.
875  */
876 Bignum bigmod(Bignum a, Bignum b)
877 {
878     Bignum r = newbn(b[0]);
879     bigdivmod(a, b, r, NULL);
880     return r;
881 }
882
883 /*
884  * Greatest common divisor.
885  */
886 Bignum biggcd(Bignum av, Bignum bv)
887 {
888     Bignum a = copybn(av);
889     Bignum b = copybn(bv);
890
891     while (bignum_cmp(b, Zero) != 0) {
892         Bignum t = newbn(b[0]);
893         bigdivmod(a, b, t, NULL);
894         while (t[0] > 1 && t[t[0]] == 0)
895             t[0]--;
896         freebn(a);
897         a = b;
898         b = t;
899     }
900
901     freebn(b);
902     return a;
903 }
904
905 /*
906  * Modular inverse, using Euclid's extended algorithm.
907  */
908 Bignum modinv(Bignum number, Bignum modulus)
909 {
910     Bignum a = copybn(modulus);
911     Bignum b = copybn(number);
912     Bignum xp = copybn(Zero);
913     Bignum x = copybn(One);
914     int sign = +1;
915
916     while (bignum_cmp(b, One) != 0) {
917         Bignum t = newbn(b[0]);
918         Bignum q = newbn(a[0]);
919         bigdivmod(a, b, t, q);
920         while (t[0] > 1 && t[t[0]] == 0)
921             t[0]--;
922         freebn(a);
923         a = b;
924         b = t;
925         t = xp;
926         xp = x;
927         x = bigmuladd(q, xp, t);
928         sign = -sign;
929         freebn(t);
930         freebn(q);
931     }
932
933     freebn(b);
934     freebn(a);
935     freebn(xp);
936
937     /* now we know that sign * x == 1, and that x < modulus */
938     if (sign < 0) {
939         /* set a new x to be modulus - x */
940         Bignum newx = newbn(modulus[0]);
941         BignumInt carry = 0;
942         int maxspot = 1;
943         int i;
944
945         for (i = 1; i <= newx[0]; i++) {
946             BignumInt aword = (i <= modulus[0] ? modulus[i] : 0);
947             BignumInt bword = (i <= x[0] ? x[i] : 0);
948             newx[i] = aword - bword - carry;
949             bword = ~bword;
950             carry = carry ? (newx[i] >= bword) : (newx[i] > bword);
951             if (newx[i] != 0)
952                 maxspot = i;
953         }
954         newx[0] = maxspot;
955         freebn(x);
956         x = newx;
957     }
958
959     /* and return. */
960     return x;
961 }
962
963 /*
964  * Render a bignum into decimal. Return a malloced string holding
965  * the decimal representation.
966  */
967 char *bignum_decimal(Bignum x)
968 {
969     int ndigits, ndigit;
970     int i, iszero;
971     BignumDblInt carry;
972     char *ret;
973     BignumInt *workspace;
974
975     /*
976      * First, estimate the number of digits. Since log(10)/log(2)
977      * is just greater than 93/28 (the joys of continued fraction
978      * approximations...) we know that for every 93 bits, we need
979      * at most 28 digits. This will tell us how much to malloc.
980      *
981      * Formally: if x has i bits, that means x is strictly less
982      * than 2^i. Since 2 is less than 10^(28/93), this is less than
983      * 10^(28i/93). We need an integer power of ten, so we must
984      * round up (rounding down might make it less than x again).
985      * Therefore if we multiply the bit count by 28/93, rounding
986      * up, we will have enough digits.
987      */
988     i = bignum_bitcount(x);
989     ndigits = (28 * i + 92) / 93;      /* multiply by 28/93 and round up */
990     ndigits++;                         /* allow for trailing \0 */
991     ret = snewn(ndigits, char);
992
993     /*
994      * Now allocate some workspace to hold the binary form as we
995      * repeatedly divide it by ten. Initialise this to the
996      * big-endian form of the number.
997      */
998     workspace = snewn(x[0], BignumInt);
999     for (i = 0; i < x[0]; i++)
1000         workspace[i] = x[x[0] - i];
1001
1002     /*
1003      * Next, write the decimal number starting with the last digit.
1004      * We use ordinary short division, dividing 10 into the
1005      * workspace.
1006      */
1007     ndigit = ndigits - 1;
1008     ret[ndigit] = '\0';
1009     do {
1010         iszero = 1;
1011         carry = 0;
1012         for (i = 0; i < x[0]; i++) {
1013             carry = (carry << BIGNUM_INT_BITS) + workspace[i];
1014             workspace[i] = (BignumInt) (carry / 10);
1015             if (workspace[i])
1016                 iszero = 0;
1017             carry %= 10;
1018         }
1019         ret[--ndigit] = (char) (carry + '0');
1020     } while (!iszero);
1021
1022     /*
1023      * There's a chance we've fallen short of the start of the
1024      * string. Correct if so.
1025      */
1026     if (ndigit > 0)
1027         memmove(ret, ret + ndigit, ndigits - ndigit);
1028
1029     /*
1030      * Done.
1031      */
1032     sfree(workspace);
1033     return ret;
1034 }