]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshbn.c
RSA key authentication in ssh1 works; SSH2 is nearly there
[PuTTY.git] / sshbn.c
1 /*
2  * Bignum routines for RSA and DH and stuff.
3  */
4
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8
9 #include <stdio.h> /* FIXME */
10 #include <stdarg.h> /* FIXME */
11 #include <windows.h> /* FIXME */
12 #include "putty.h" /* FIXME */
13
14 #include "ssh.h"
15
16 unsigned short bnZero[1] = { 0 };
17 unsigned short bnOne[2] = { 1, 1 };
18
19 Bignum Zero = bnZero, One = bnOne;
20
21 Bignum newbn(int length) {
22     Bignum b = malloc((length+1)*sizeof(unsigned short));
23     if (!b)
24         abort();                       /* FIXME */
25     memset(b, 0, (length+1)*sizeof(*b));
26     b[0] = length;
27     return b;
28 }
29
30 Bignum copybn(Bignum orig) {
31     Bignum b = malloc((orig[0]+1)*sizeof(unsigned short));
32     if (!b)
33         abort();                       /* FIXME */
34     memcpy(b, orig, (orig[0]+1)*sizeof(*b));
35     return b;
36 }
37
38 void freebn(Bignum b) {
39     /*
40      * Burn the evidence, just in case.
41      */
42     memset(b, 0, sizeof(b[0]) * (b[0] + 1));
43     free(b);
44 }
45
46 /*
47  * Compute c = a * b.
48  * Input is in the first len words of a and b.
49  * Result is returned in the first 2*len words of c.
50  */
51 static void bigmul(unsigned short *a, unsigned short *b, unsigned short *c,
52                    int len)
53 {
54     int i, j;
55     unsigned long ai, t;
56
57     for (j = len - 1; j >= 0; j--)
58         c[j+len] = 0;
59
60     for (i = len - 1; i >= 0; i--) {
61         ai = a[i];
62         t = 0;
63         for (j = len - 1; j >= 0; j--) {
64             t += ai * (unsigned long) b[j];
65             t += (unsigned long) c[i+j+1];
66             c[i+j+1] = (unsigned short)t;
67             t = t >> 16;
68         }
69         c[i] = (unsigned short)t;
70     }
71 }
72
73 /*
74  * Compute a = a % m.
75  * Input in first len2 words of a and first len words of m.
76  * Output in first len2 words of a
77  * (of which first len2-len words will be zero).
78  * The MSW of m MUST have its high bit set.
79  */
80 static void bigmod(unsigned short *a, unsigned short *m,
81                    int len, int len2)
82 {
83     unsigned short m0, m1;
84     unsigned int h;
85     int i, k;
86
87     /* Special case for len == 1 */
88     if (len == 1) {
89         a[1] = (((long) a[0] << 16) + a[1]) % m[0];
90         a[0] = 0;
91         return;
92     }
93
94     m0 = m[0];
95     m1 = m[1];
96
97     for (i = 0; i <= len2-len; i++) {
98         unsigned long t;
99         unsigned int q, r, c;
100
101         if (i == 0) {
102             h = 0;
103         } else {
104             h = a[i-1];
105             a[i-1] = 0;
106         }
107
108         /* Find q = h:a[i] / m0 */
109         t = ((unsigned long) h << 16) + a[i];
110         q = t / m0;
111         r = t % m0;
112
113         /* Refine our estimate of q by looking at
114          h:a[i]:a[i+1] / m0:m1 */
115         t = (long) m1 * (long) q;
116         if (t > ((unsigned long) r << 16) + a[i+1]) {
117             q--;
118             t -= m1;
119             r = (r + m0) & 0xffff; /* overflow? */
120             if (r >= (unsigned long)m0 &&
121                 t > ((unsigned long) r << 16) + a[i+1])
122                 q--;
123         }
124
125         /* Substract q * m from a[i...] */
126         c = 0;
127         for (k = len - 1; k >= 0; k--) {
128             t = (long) q * (long) m[k];
129             t += c;
130             c = t >> 16;
131             if ((unsigned short) t > a[i+k]) c++;
132             a[i+k] -= (unsigned short) t;
133         }
134
135         /* Add back m in case of borrow */
136         if (c != h) {
137             t = 0;
138             for (k = len - 1; k >= 0; k--) {
139                 t += m[k];
140                 t += a[i+k];
141                 a[i+k] = (unsigned short)t;
142                 t = t >> 16;
143             }
144         }
145     }
146 }
147
148 /*
149  * Compute (base ^ exp) % mod.
150  * The base MUST be smaller than the modulus.
151  * The most significant word of mod MUST be non-zero.
152  * We assume that the result array is the same size as the mod array.
153  */
154 void modpow(Bignum base, Bignum exp, Bignum mod, Bignum result)
155 {
156     unsigned short *a, *b, *n, *m;
157     int mshift;
158     int mlen, i, j;
159
160     /* Allocate m of size mlen, copy mod to m */
161     /* We use big endian internally */
162     mlen = mod[0];
163     m = malloc(mlen * sizeof(unsigned short));
164     for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
165
166     /* Shift m left to make msb bit set */
167     for (mshift = 0; mshift < 15; mshift++)
168         if ((m[0] << mshift) & 0x8000) break;
169     if (mshift) {
170         for (i = 0; i < mlen - 1; i++)
171             m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
172         m[mlen-1] = m[mlen-1] << mshift;
173     }
174
175     /* Allocate n of size mlen, copy base to n */
176     n = malloc(mlen * sizeof(unsigned short));
177     i = mlen - base[0];
178     for (j = 0; j < i; j++) n[j] = 0;
179     for (j = 0; j < base[0]; j++) n[i+j] = base[base[0] - j];
180
181     /* Allocate a and b of size 2*mlen. Set a = 1 */
182     a = malloc(2 * mlen * sizeof(unsigned short));
183     b = malloc(2 * mlen * sizeof(unsigned short));
184     for (i = 0; i < 2*mlen; i++) a[i] = 0;
185     a[2*mlen-1] = 1;
186
187     /* Skip leading zero bits of exp. */
188     i = 0; j = 15;
189     while (i < exp[0] && (exp[exp[0] - i] & (1 << j)) == 0) {
190         j--;
191         if (j < 0) { i++; j = 15; }
192     }
193
194     /* Main computation */
195     while (i < exp[0]) {
196         while (j >= 0) {
197             bigmul(a + mlen, a + mlen, b, mlen);
198             bigmod(b, m, mlen, mlen*2);
199             if ((exp[exp[0] - i] & (1 << j)) != 0) {
200                 bigmul(b + mlen, n, a, mlen);
201                 bigmod(a, m, mlen, mlen*2);
202             } else {
203                 unsigned short *t;
204                 t = a;  a = b;  b = t;
205             }
206             j--;
207         }
208         i++; j = 15;
209     }
210
211     /* Fixup result in case the modulus was shifted */
212     if (mshift) {
213         for (i = mlen - 1; i < 2*mlen - 1; i++)
214             a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
215         a[2*mlen-1] = a[2*mlen-1] << mshift;
216         bigmod(a, m, mlen, mlen*2);
217         for (i = 2*mlen - 1; i >= mlen; i--)
218             a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
219     }
220
221     /* Copy result to buffer */
222     for (i = 0; i < mlen; i++)
223         result[result[0] - i] = a[i+mlen];
224
225     /* Free temporary arrays */
226     for (i = 0; i < 2*mlen; i++) a[i] = 0; free(a);
227     for (i = 0; i < 2*mlen; i++) b[i] = 0; free(b);
228     for (i = 0; i < mlen; i++) m[i] = 0; free(m);
229     for (i = 0; i < mlen; i++) n[i] = 0; free(n);
230 }
231
232 /*
233  * Compute (p * q) % mod.
234  * The most significant word of mod MUST be non-zero.
235  * We assume that the result array is the same size as the mod array.
236  */
237 void modmul(Bignum p, Bignum q, Bignum mod, Bignum result)
238 {
239     unsigned short *a, *n, *m, *o;
240     int mshift;
241     int pqlen, mlen, i, j;
242
243     /* Allocate m of size mlen, copy mod to m */
244     /* We use big endian internally */
245     mlen = mod[0];
246     m = malloc(mlen * sizeof(unsigned short));
247     for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j];
248
249     /* Shift m left to make msb bit set */
250     for (mshift = 0; mshift < 15; mshift++)
251         if ((m[0] << mshift) & 0x8000) break;
252     if (mshift) {
253         for (i = 0; i < mlen - 1; i++)
254             m[i] = (m[i] << mshift) | (m[i+1] >> (16-mshift));
255         m[mlen-1] = m[mlen-1] << mshift;
256     }
257
258     pqlen = (p[0] > q[0] ? p[0] : q[0]);
259
260     /* Allocate n of size pqlen, copy p to n */
261     n = malloc(pqlen * sizeof(unsigned short));
262     i = pqlen - p[0];
263     for (j = 0; j < i; j++) n[j] = 0;
264     for (j = 0; j < p[0]; j++) n[i+j] = p[p[0] - j];
265
266     /* Allocate o of size pqlen, copy q to o */
267     o = malloc(pqlen * sizeof(unsigned short));
268     i = pqlen - q[0];
269     for (j = 0; j < i; j++) o[j] = 0;
270     for (j = 0; j < q[0]; j++) o[i+j] = q[q[0] - j];
271
272     /* Allocate a of size 2*pqlen for result */
273     a = malloc(2 * pqlen * sizeof(unsigned short));
274
275     /* Main computation */
276     bigmul(n, o, a, pqlen);
277     bigmod(a, m, mlen, 2*pqlen);
278
279     /* Fixup result in case the modulus was shifted */
280     if (mshift) {
281         for (i = 2*pqlen - mlen - 1; i < 2*pqlen - 1; i++)
282             a[i] = (a[i] << mshift) | (a[i+1] >> (16-mshift));
283         a[2*pqlen-1] = a[2*pqlen-1] << mshift;
284         bigmod(a, m, mlen, pqlen*2);
285         for (i = 2*pqlen - 1; i >= 2*pqlen - mlen; i--)
286             a[i] = (a[i] >> mshift) | (a[i-1] << (16-mshift));
287     }
288
289     /* Copy result to buffer */
290     for (i = 0; i < mlen; i++)
291         result[result[0] - i] = a[i+2*pqlen-mlen];
292
293     /* Free temporary arrays */
294     for (i = 0; i < 2*pqlen; i++) a[i] = 0; free(a);
295     for (i = 0; i < mlen; i++) m[i] = 0; free(m);
296     for (i = 0; i < pqlen; i++) n[i] = 0; free(n);
297     for (i = 0; i < pqlen; i++) o[i] = 0; free(o);
298 }
299
300 /*
301  * Decrement a number.
302  */
303 void decbn(Bignum bn) {
304     int i = 1;
305     while (i < bn[0] && bn[i] == 0)
306         bn[i++] = 0xFFFF;
307     bn[i]--;
308 }
309
310 /*
311  * Read an ssh1-format bignum from a data buffer. Return the number
312  * of bytes consumed.
313  */
314 int ssh1_read_bignum(unsigned char *data, Bignum *result) {
315     unsigned char *p = data;
316     Bignum bn;
317     int i;
318     int w, b;
319
320     w = 0;
321     for (i=0; i<2; i++)
322         w = (w << 8) + *p++;
323
324     b = (w+7)/8;                       /* bits -> bytes */
325     w = (w+15)/16;                     /* bits -> words */
326
327     bn = newbn(w);
328
329     for (i=1; i<=w; i++)
330         bn[i] = 0;
331     for (i=b; i-- ;) {
332         unsigned char byte = *p++;
333         if (i & 1)
334             bn[1+i/2] |= byte<<8;
335         else
336             bn[1+i/2] |= byte;
337     }
338
339     *result = bn;
340
341     return p - data;
342 }