]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshbn.c
Remove some spurious #includes
[PuTTY.git] / sshbn.c
1 /*
2  * Bignum routines for RSA and DH and stuff.
3  */
4
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8
9 #include "ssh.h"
10
11 unsigned short bnZero[1] = { 0 };
12 unsigned short bnOne[2] = { 1, 1 };
13
14 Bignum Zero = bnZero, One = bnOne;
15
16 Bignum newbn(int length) {
17     Bignum b = malloc((length+1)*sizeof(unsigned short));
18     if (!b)
19         abort();                       /* FIXME */
20     memset(b, 0, (length+1)*sizeof(*b));
21     b[0] = length;
22     return b;
23 }
24
25 Bignum copybn(Bignum orig) {
26     Bignum b = malloc((orig[0]+1)*sizeof(unsigned short));
27     if (!b)
28         abort();                       /* FIXME */
29     memcpy(b, orig, (orig[0]+1)*sizeof(*b));
30     return b;
31 }
32
33 void freebn(Bignum b) {
34     /*
35      * Burn the evidence, just in case.
36      */
37     memset(b, 0, sizeof(b[0]) * (b[0] + 1));
38     free(b);
39 }
40
41 /*
42  * Compute c = a * b.
43  * Input is in the first len words of a and b.
44  * Result is returned in the first 2*len words of c.
45  */
46 static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
47                    int len)
48 {
49     int i, j;
50     unsigned long ai, t;
51
52     for (j = len - 1; j >= 0; j--)
53         c[j+len] = 0;
54
55     for (i = len - 1; i >= 0; i--) {
56         ai = a[i];
57         t = 0;
58         for (j = len - 1; j >= 0; j--) {
59             t += ai * (unsigned long) b[j];
60             t += (unsigned long) c[i+j+1];
61             c[i+j+1] = (unsigned short)t;
62             t = t >> 16;
63         }
64         c[i] = (unsigned short)t;
65     }
66 }
67
68 /*
69  * Compute a = a % m.
70  * Input in first len2 words of a and first len words of m.
71  * Output in first len2 words of a
72  * (of which first len2-len words will be zero).
73  * The MSW of m MUST have its high bit set.
74  */
75 static void bigmod(unsigned short *a, unsigned short *m,
76                    int len, int len2)
77 {
78     unsigned short m0, m1;
79     unsigned int h;
80     int i, k;
81
82     /* Special case for len == 1 */
83     if (len == 1) {
84         a[1] = (((long) a[0] << 16) + a[1]) % m[0];
85         a[0] = 0;
86         return;
87     }
88
89     m0 = m[0];
90     m1 = m[1];
91
92     for (i = 0; i <= len2-len; i++) {
93         unsigned long t;
94         unsigned int q, r, c;
95
96         if (i == 0) {
97             h = 0;
98         } else {
99             h = a[i-1];
100             a[i-1] = 0;
101         }
102
103         /* Find q = h:a[i] / m0 */
104         t = ((unsigned long) h << 16) + a[i];
105         q = t / m0;
106         r = t % m0;
107
108         /* Refine our estimate of q by looking at
109          h:a[i]:a[i+1] / m0:m1 */
110         t = (long) m1 * (long) q;
111         if (t > ((unsigned long) r << 16) + a[i+1]) {
112             q--;
113             t -= m1;
114             r = (r + m0) & 0xffff; /* overflow? */
115             if (r >= (unsigned long)m0 &&
116                 t > ((unsigned long) r << 16) + a[i+1])
117                 q--;
118         }
119
120         /* Substract q * m from a[i...] */
121         c = 0;
122         for (k = len - 1; k >= 0; k--) {
123             t = (long) q * (long) m[k];
124             t += c;
125             c = t >> 16;
126             if ((unsigned short) t > a[i+k]) c++;
127             a[i+k] -= (unsigned short) t;
128         }
129
130         /* Add back m in case of borrow */
131         if (c != h) {
132             t = 0;
133             for (k = len - 1; k >= 0; k--) {
134                 t += m[k];
135                 t += a[i+k];
136                 a[i+k] = (unsigned short)t;
137                 t = t >> 16;
138             }
139         }
140     }
141 }
142
143 /*
144  * Compute (base ^ exp) % mod.
145  * The base MUST be smaller than the modulus.
146  * The most significant word of mod MUST be non-zero.
147  * We assume that the result array is the same size as the mod array.
148  */
149 void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
150 {
151     unsigned short *a, *b, *n, *m;
152     int mshift;
153     int mlen, i, j;
154
155     /* Allocate m of size mlen, copy mod to m */
156     /* We use big endian internally */
157     mlen = mod[0];
158     m = malloc(mlen * sizeof(unsigned short));
159     for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
160
161     /* Shift m left to make msb bit set */
162     for (mshift = 0; mshift < 15; mshift++)
163         if ((m[0] << mshift) & 0x8000) break;
164     if (mshift) {
165         for (i = 0; i < mlen - 1; i++)
166             m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
167         m[mlen-1] = m[mlen-1] << mshift;
168     }
169
170     /* Allocate n of size mlen, copy base to n */
171     n = malloc(mlen * sizeof(unsigned short));
172     i = mlen - base[0];
173     for (j = 0; j < i; j++) n[j] = 0;
174     for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j];
175
176     /* Allocate a and b of size 2*mlen. Set a = 1 */
177     a = malloc(2 * mlen * sizeof(unsigned short));
178     b = malloc(2 * mlen * sizeof(unsigned short));
179     for (i = 0; i < 2*mlen; i++) a[i] = 0;
180     a[2*mlen-1] = 1;
181
182     /* Skip leading zero bits of exp. */
183     i = 0; j = 15;
184     while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
185         j--;
186         if (j < 0) { i++; j = 15; }
187     }
188
189     /* Main computation */
190     while (i < exp[0]) {
191         while (j >= 0) {
192             bigmul(a + mlen, a + mlen, b, mlen);
193             bigmod(b, m, mlen, mlen*2);
194             if ((exp[exp[0] - i] & (1 << j)) != 0) {
195                 bigmul(b + mlen, n, a, mlen);
196                 bigmod(a, m, mlen, mlen*2);
197             } else {
198                 unsigned short *t;
199                 t = a;  a = b;  b = t;
200             }
201             j--;
202         }
203         i++; j = 15;
204     }
205
206     /* Fixup result in case the modulus was shifted */
207     if (mshift) {
208         for (i = mlen - 1; i < 2*mlen - 1; i++)
209             a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
210         a[2*mlen-1] = a[2*mlen-1] << mshift;
211         bigmod(a, m, mlen, mlen*2);
212         for (i = 2*mlen - 1; i >= mlen; i--)
213             a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
214     }
215
216     /* Copy result to buffer */
217     for (i = 0; i < mlen; i++)
218         result[result[0] - i] = a[i+mlen];
219
220     /* Free temporary arrays */
221     for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a);
222     for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b);
223     for (i = 0; i < mlen; i++) m[i] = 0; free(m);
224     for (i = 0; i < mlen; i++) n[i] = 0; free(n);
225 }
226
227 /*
228  * Compute (p * q) % mod.
229  * The most significant word of mod MUST be non-zero.
230  * We assume that the result array is the same size as the mod array.
231  */
232 void modmul(Bignum p, Bignum q, Bignum mod, Bignum result)
233 {
234     unsigned short *a, *n, *m, *o;
235     int mshift;
236     int pqlen, mlen, i, j;
237
238     /* Allocate m of size mlen, copy mod to m */
239     /* We use big endian internally */
240     mlen = mod[0];
241     m = malloc(mlen * sizeof(unsigned short));
242     for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
243
244     /* Shift m left to make msb bit set */
245     for (mshift = 0; mshift < 15; mshift++)
246         if ((m[0] << mshift) & 0x8000) break;
247     if (mshift) {
248         for (i = 0; i < mlen - 1; i++)
249             m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
250         m[mlen-1] = m[mlen-1] << mshift;
251     }
252
253     pqlen = (p[0] > q[0] ? p[0] : q[0]);
254
255     /* Allocate n of size pqlen, copy p to n */
256     n = malloc(pqlen * sizeof(unsigned short));
257     i = pqlen - p[0];
258     for (j = 0; j < i; j++) n[j] = 0;
259     for (j = 0; j < p[0]; j++) n[i+j] = p[p[0] - j];
260
261     /* Allocate o of size pqlen, copy q to o */
262     o = malloc(pqlen * sizeof(unsigned short));
263     i = pqlen - q[0];
264     for (j = 0; j < i; j++) o[j] = 0;
265     for (j = 0; j < q[0]; j++) o[i+j] = q[q[0] - j];
266
267     /* Allocate a of size 2*pqlen for result */
268     a = malloc(2 * pqlen * sizeof(unsigned short));
269
270     /* Main computation */
271     bigmul(n, o, a, pqlen);
272     bigmod(a, m, mlen, 2*pqlen);
273
274     /* Fixup result in case the modulus was shifted */
275     if (mshift) {
276         for (i = 2*pqlen - mlen - 1; i < 2*pqlen - 1; i++)
277             a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
278         a[2*pqlen-1] = a[2*pqlen-1] << mshift;
279         bigmod(a, m, mlen, pqlen*2);
280         for (i = 2*pqlen - 1; i >= 2*pqlen - mlen; i--)
281             a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
282     }
283
284     /* Copy result to buffer */
285     for (i = 0; i < mlen; i++)
286         result[result[0] - i] = a[i+2*pqlen-mlen];
287
288     /* Free temporary arrays */
289     for (i = 0; i < 2*pqlen; i++) a[i] = 0; free(a);
290     for (i = 0; i < mlen; i++) m[i] = 0; free(m);
291     for (i = 0; i < pqlen; i++) n[i] = 0; free(n);
292     for (i = 0; i < pqlen; i++) o[i] = 0; free(o);
293 }
294
295 /*
296  * Decrement a number.
297  */
298 void decbn(Bignum bn) {
299     int i = 1;
300     while (i < bn[0] && bn[i] == 0)
301         bn[i++] = 0xFFFF;
302     bn[i]--;
303 }
304
305 /*
306  * Read an ssh1-format bignum from a data buffer. Return the number
307  * of bytes consumed.
308  */
309 int ssh1_read_bignum(unsigned char *data, Bignum *result) {
310     unsigned char *p = data;
311     Bignum bn;
312     int i;
313     int w, b;
314
315     w = 0;
316     for (i=0; i<2; i++)
317         w = (w << 8) + *p++;
318
319     b = (w+7)/8;                       /* bits -> bytes */
320     w = (w+15)/16;                     /* bits -> words */
321
322     if (!result)                       /* just return length */
323         return b + 2;
324
325     bn = newbn(w);
326
327     for (i=1; i<=w; i++)
328         bn[i] = 0;
329     for (i=b; i-- ;) {
330         unsigned char byte = *p++;
331         if (i & 1)
332             bn[1+i/2] |= byte<<8;
333         else
334             bn[1+i/2] |= byte;
335     }
336
337     *result = bn;
338
339     return p - data;
340 }
341
342 /*
343  * Return the bit count of a bignum, for ssh1 encoding.
344  */
345 int ssh1_bignum_bitcount(Bignum bn) {
346     int bitcount = bn[0] * 16 - 1;
347
348     while (bitcount >= 0 && (bn[bitcount/16+1] >> (bitcount % 16)) == 0)
349         bitcount--;
350     return bitcount + 1;
351 }
352
353 /*
354  * Return the byte length of a bignum when ssh1 encoded.
355  */
356 int ssh1_bignum_length(Bignum bn) {
357     return 2 + (ssh1_bignum_bitcount(bn)+7)/8;
358 }
359
360 /*
361  * Return a byte from a bignum; 0 is least significant, etc.
362  */
363 int bignum_byte(Bignum bn, int i) {
364     if (i >= 2*bn[0])
365         return 0;                      /* beyond the end */
366     else if (i & 1)
367         return (bn[i/2+1] >> 8) & 0xFF;
368     else
369         return (bn[i/2+1]     ) & 0xFF;
370 }
371
372 /*
373  * Write a ssh1-format bignum into a buffer. It is assumed the
374  * buffer is big enough. Returns the number of bytes used.
375  */
376 int ssh1_write_bignum(void *data, Bignum bn) {
377     unsigned char *p = data;
378     int len = ssh1_bignum_length(bn);
379     int i;
380     int bitc = ssh1_bignum_bitcount(bn);
381
382     *p++ = (bitc >> 8) & 0xFF;
383     *p++ = (bitc     ) & 0xFF;
384     for (i = len-2; i-- ;)
385         *p++ = bignum_byte(bn, i);
386     return len;
387 }