]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshrsa.c
Initial checkin: beta 0.43
[PuTTY.git] / sshrsa.c
1 /*
2  * RSA implementation just sufficient for ssh client-side
3  * initialisation step
4  */
5
6 /*#include <windows.h>
7 #define RSADEBUG
8 #define DLVL 2
9 #include "stel.h"*/
10
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14
15 #include "ssh.h"
16
17 typedef unsigned short *Bignum;
18
19 static unsigned short Zero[1] = { 0 };
20
21 #if defined TESTMODE || defined RSADEBUG
22 #ifndef DLVL
23 #define DLVL 10000
24 #endif
25 #define debug(x) bndebug(#x,x)
26 static int level = 0;
27 static void bndebug(char *name, Bignum b) {
28     int i;
29     int w = 50-level-strlen(name)-5*b[0];
30     if (level >= DLVL)
31         return;
32     if (w < 0) w = 0;
33     dprintf("%*s%s%*s", level, "", name, w, "");
34     for (i=b[0]; i>0; i--)
35         dprintf(" %04x", b[i]);
36     dprintf("\n");
37 }
38 #define dmsg(x) do {if(level<DLVL){dprintf("%*s",level,"");printf x;}} while(0)
39 #define enter(x) do { dmsg(x); level += 4; } while(0)
40 #define leave(x) do { level -= 4; dmsg(x); } while(0)
41 #else
42 #define debug(x)
43 #define dmsg(x)
44 #define enter(x)
45 #define leave(x)
46 #endif
47
48 static Bignum newbn(int length) {
49     Bignum b = malloc((length+1)*sizeof(unsigned short));
50     if (!b)
51         abort();                       /* FIXME */
52     b[0] = length;
53     return b;
54 }
55
56 static void freebn(Bignum b) {
57     free(b);
58 }
59
60 static int msb(Bignum r) {
61     int i;
62     int j;
63     unsigned short n;
64
65     for (i=r[0]; i>0; i--)
66         if (r[i])
67             break;
68
69     j = (i-1)*16;
70     n = r[i];
71     if (n & 0xFF00) j += 8, n >>= 8;
72     if (n & 0x00F0) j += 4, n >>= 4;
73     if (n & 0x000C) j += 2, n >>= 2;
74     if (n & 0x0002) j += 1, n >>= 1;
75
76     return j;
77 }
78
79 static void add(Bignum r1, Bignum r2, Bignum result) {
80     int i;
81     long stuff = 0;
82
83     enter((">add\n"));
84     debug(r1);
85     debug(r2);
86
87     for (i = 1 ;; i++) {
88         if (i <= r1[0])
89             stuff += r1[i];
90         if (i <= r2[0])
91             stuff += r2[i];
92         if (i <= result[0])
93             result[i] = stuff & 0xFFFFU;
94         if (i > r1[0] && i > r2[0] && i >= result[0])
95             break;
96         stuff >>= 16;
97     }
98
99     debug(result);
100     leave(("<add\n"));
101 }
102
103 static void sub(Bignum r1, Bignum r2, Bignum result) {
104     int i;
105     long stuff = 0;
106
107     enter((">sub\n"));
108     debug(r1);
109     debug(r2);
110
111     for (i = 1 ;; i++) {
112         if (i <= r1[0])
113             stuff += r1[i];
114         if (i <= r2[0])
115             stuff -= r2[i];
116         if (i <= result[0])
117             result[i] = stuff & 0xFFFFU;
118         if (i > r1[0] && i > r2[0] && i >= result[0])
119             break;
120         stuff = stuff<0 ? -1 : 0;
121     }
122
123     debug(result);
124     leave(("<sub\n"));
125 }
126
127 static int ge(Bignum r1, Bignum r2) {
128     int i;
129
130     enter((">ge\n"));
131     debug(r1);
132     debug(r2);
133
134     if (r1[0] < r2[0])
135         i = r2[0];
136     else
137         i = r1[0];
138
139     while (i > 0) {
140         unsigned short n1 = (i > r1[0] ? 0 : r1[i]);
141         unsigned short n2 = (i > r2[0] ? 0 : r2[i]);
142
143         if (n1 > n2) {
144             dmsg(("greater\n"));
145             leave(("<ge\n"));
146             return 1;                  /* r1 > r2 */
147         } else if (n1 < n2) {
148             dmsg(("less\n"));
149             leave(("<ge\n"));
150             return 0;                  /* r1 < r2 */
151         }
152
153         i--;
154     }
155
156     dmsg(("equal\n"));
157     leave(("<ge\n"));
158     return 1;                          /* r1 = r2 */
159 }
160
161 static void modmult(Bignum r1, Bignum r2, Bignum modulus, Bignum result) {
162     Bignum temp = newbn(modulus[0]+1);
163     Bignum tmp2 = newbn(modulus[0]+1);
164     int i;
165     int bit, bits, digit, smallbit;
166
167     enter((">modmult\n"));
168     debug(r1);
169     debug(r2);
170     debug(modulus);
171
172     for (i=1; i<=result[0]; i++)
173         result[i] = 0;                 /* result := 0 */
174     for (i=1; i<=temp[0]; i++)
175         temp[i] = (i > r2[0] ? 0 : r2[i]);   /* temp := r2 */
176
177     bits = 1+msb(r1);
178
179     for (bit = 0; bit < bits; bit++) {
180         digit = 1 + bit / 16;
181         smallbit = bit % 16;
182
183         debug(temp);
184         if (digit <= r1[0] && (r1[digit] & (1<<smallbit))) {
185             dmsg(("bit %d\n", bit));
186             add(temp, result, tmp2);
187             if (ge(tmp2, modulus))
188                 sub(tmp2, modulus, result);
189             else
190                 add(tmp2, Zero, result);
191             debug(result);
192         }
193
194         add(temp, temp, tmp2);
195         if (ge(tmp2, modulus))
196             sub(tmp2, modulus, temp);
197         else
198             add(tmp2, Zero, temp);
199     }
200
201     freebn(temp);
202     freebn(tmp2);
203
204     debug(result);
205     leave(("<modmult\n"));
206 }
207
208 static void modpow(Bignum r1, Bignum r2, Bignum modulus, Bignum result) {
209     Bignum temp = newbn(modulus[0]+1);
210     Bignum tmp2 = newbn(modulus[0]+1);
211     int i;
212     int bit, bits, digit, smallbit;
213
214     enter((">modpow\n"));
215     debug(r1);
216     debug(r2);
217     debug(modulus);
218
219     for (i=1; i<=result[0]; i++)
220         result[i] = (i==1);            /* result := 1 */
221     for (i=1; i<=temp[0]; i++)
222         temp[i] = (i > r1[0] ? 0 : r1[i]);   /* temp := r1 */
223
224     bits = 1+msb(r2);
225
226     for (bit = 0; bit < bits; bit++) {
227         digit = 1 + bit / 16;
228         smallbit = bit % 16;
229
230         debug(temp);
231         if (digit <= r2[0] && (r2[digit] & (1<<smallbit))) {
232             dmsg(("bit %d\n", bit));
233             modmult(temp, result, modulus, tmp2);
234             add(tmp2, Zero, result);
235             debug(result);
236         }
237
238         modmult(temp, temp, modulus, tmp2);
239         add(tmp2, Zero, temp);
240     }
241
242     freebn(temp);
243     freebn(tmp2);
244
245     debug(result);
246     leave(("<modpow\n"));
247 }
248
249 int makekey(unsigned char *data, struct RSAKey *result,
250             unsigned char **keystr) {
251     unsigned char *p = data;
252     Bignum bn[2];
253     int i, j;
254     int w, b;
255
256     result->bits = 0;
257     for (i=0; i<4; i++)
258         result->bits = (result->bits << 8) + *p++;
259
260     for (j=0; j<2; j++) {
261
262         w = 0;
263         for (i=0; i<2; i++)
264             w = (w << 8) + *p++;
265
266         result->bytes = b = (w+7)/8;   /* bits -> bytes */
267         w = (w+15)/16;                 /* bits -> words */
268
269         bn[j] = newbn(w);
270
271         if (keystr) *keystr = p;       /* point at key string, second time */
272
273         for (i=1; i<=w; i++)
274             bn[j][i] = 0;
275         for (i=0; i<b; i++) {
276             unsigned char byte = *p++;
277             if ((b-i) & 1)
278                 bn[j][w-i/2] |= byte;
279             else
280                 bn[j][w-i/2] |= byte<<8;
281         }
282
283         debug(bn[j]);
284
285     }
286
287     result->exponent = bn[0];
288     result->modulus = bn[1];
289
290     return p - data;
291 }
292
293 void rsaencrypt(unsigned char *data, int length, struct RSAKey *key) {
294     Bignum b1, b2;
295     int w, i;
296     unsigned char *p;
297
298     debug(key->exponent);
299
300     memmove(data+key->bytes-length, data, length);
301     data[0] = 0;
302     data[1] = 2;
303
304     for (i = 2; i < key->bytes-length-1; i++) {
305         do {
306             data[i] = random_byte();
307         } while (data[i] == 0);
308     }
309     data[key->bytes-length-1] = 0;
310
311     w = (key->bytes+1)/2;
312
313     b1 = newbn(w);
314     b2 = newbn(w);
315
316     p = data;
317     for (i=1; i<=w; i++)
318         b1[i] = 0;
319     for (i=0; i<key->bytes; i++) {
320         unsigned char byte = *p++;
321         if ((key->bytes-i) & 1)
322             b1[w-i/2] |= byte;
323         else
324             b1[w-i/2] |= byte<<8;
325     }
326
327     debug(b1);
328
329     modpow(b1, key->exponent, key->modulus, b2);
330
331     debug(b2);
332
333     p = data;
334     for (i=0; i<key->bytes; i++) {
335         unsigned char b;
336         if (i & 1)
337             b = b2[w-i/2] & 0xFF;
338         else
339             b = b2[w-i/2] >> 8;
340         *p++ = b;
341     }
342
343     freebn(b1);
344     freebn(b2);
345 }
346
347 int rsastr_len(struct RSAKey *key) {
348     Bignum md, ex;
349
350     md = key->modulus;
351     ex = key->exponent;
352     return 4 * (ex[0]+md[0]) + 10;
353 }
354
355 void rsastr_fmt(char *str, struct RSAKey *key) {
356     Bignum md, ex;
357     int len = 0, i;
358
359     md = key->modulus;
360     ex = key->exponent;
361
362     for (i=1; i<=ex[0]; i++) {
363         sprintf(str+len, "%04x", ex[i]);
364         len += strlen(str+len);
365     }
366     str[len++] = '/';
367     for (i=1; i<=md[0]; i++) {
368         sprintf(str+len, "%04x", md[i]);
369         len += strlen(str+len);
370     }
371     str[len] = '\0';
372 }
373
374 #ifdef TESTMODE
375
376 #ifndef NODDY
377 #define p1 10007
378 #define p2 10069
379 #define p3 10177
380 #else
381 #define p1 3
382 #define p2 7
383 #define p3 13
384 #endif
385
386 unsigned short P1[2] = { 1, p1 };
387 unsigned short P2[2] = { 1, p2 };
388 unsigned short P3[2] = { 1, p3 };
389 unsigned short bigmod[5] = { 4, 0, 0, 0, 32768U };
390 unsigned short mod[5] = { 4, 0, 0, 0, 0 };
391 unsigned short a[5] = { 4, 0, 0, 0, 0 };
392 unsigned short b[5] = { 4, 0, 0, 0, 0 };
393 unsigned short c[5] = { 4, 0, 0, 0, 0 };
394 unsigned short One[2] = { 1, 1 };
395 unsigned short Two[2] = { 1, 2 };
396
397 int main(void) {
398     modmult(P1, P2, bigmod, a);   debug(a);
399     modmult(a, P3, bigmod, mod);  debug(mod);
400
401     sub(P1, One, a);              debug(a);
402     sub(P2, One, b);              debug(b);
403     modmult(a, b, bigmod, c);     debug(c);
404     sub(P3, One, a);              debug(a);
405     modmult(a, c, bigmod, b);     debug(b);
406
407     modpow(Two, b, mod, a);       debug(a);
408
409     return 0;
410 }
411
412 #endif