]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Fixes to memory management in the elliptic curve code.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  * NOTE: Only curves on prime field are handled by the maths functions
5  */
6
7 /*
8  * References:
9  *
10  * Elliptic curves in SSH are specified in RFC 5656:
11  *   http://tools.ietf.org/html/rfc5656
12  *
13  * That specification delegates details of public key formatting and a
14  * lot of underlying mechanism to SEC 1:
15  *   http://www.secg.org/sec1-v2.pdf
16  */
17
18 #include <stdlib.h>
19 #include <assert.h>
20
21 #include "ssh.h"
22
23 /* ----------------------------------------------------------------------
24  * Elliptic curve definitions
25  */
26
27 static int initialise_curve(struct ec_curve *curve, int bits, unsigned char *p,
28                             unsigned char *a, unsigned char *b,
29                             unsigned char *n, unsigned char *Gx,
30                             unsigned char *Gy)
31 {
32     int length = bits / 8;
33     if (bits % 8) ++length;
34
35     curve->fieldBits = bits;
36     curve->p = bignum_from_bytes(p, length);
37     if (!curve->p) goto error;
38
39     /* Curve co-efficients */
40     curve->a = bignum_from_bytes(a, length);
41     if (!curve->a) goto error;
42     curve->b = bignum_from_bytes(b, length);
43     if (!curve->b) goto error;
44
45     /* Group order and generator */
46     curve->n = bignum_from_bytes(n, length);
47     if (!curve->n) goto error;
48     curve->G.x = bignum_from_bytes(Gx, length);
49     if (!curve->G.x) goto error;
50     curve->G.y = bignum_from_bytes(Gy, length);
51     if (!curve->G.y) goto error;
52     curve->G.curve = curve;
53     curve->G.infinity = 0;
54
55     return 1;
56   error:
57     if (curve->p) freebn(curve->p);
58     if (curve->a) freebn(curve->a);
59     if (curve->b) freebn(curve->b);
60     if (curve->n) freebn(curve->n);
61     if (curve->G.x) freebn(curve->G.x);
62     return 0;
63 }
64
65 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
66 int nistp256_oid_len = 8;
67 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
68 int nistp384_oid_len = 5;
69 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
70 int nistp521_oid_len = 5;
71
72 struct ec_curve *ec_p256(void)
73 {
74     static struct ec_curve curve = { 0 };
75     static unsigned char initialised = 0;
76
77     if (!initialised)
78     {
79         unsigned char p[] = {
80             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
81             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
82             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
83             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
84         };
85         unsigned char a[] = {
86             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
87             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
88             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
89             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
90         };
91         unsigned char b[] = {
92             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
93             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
94             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
95             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
96         };
97         unsigned char n[] = {
98             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
99             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
100             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
101             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
102         };
103         unsigned char Gx[] = {
104             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
105             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
106             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
107             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
108         };
109         unsigned char Gy[] = {
110             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
111             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
112             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
113             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
114         };
115
116         if (!initialise_curve(&curve, 256, p, a, b, n, Gx, Gy)) {
117             return NULL;
118         }
119
120         /* Now initialised, no need to do it again */
121         initialised = 1;
122     }
123
124     return &curve;
125 }
126
127 struct ec_curve *ec_p384(void)
128 {
129     static struct ec_curve curve = { 0 };
130     static unsigned char initialised = 0;
131
132     if (!initialised)
133     {
134         unsigned char p[] = {
135             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
137             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
138             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
139             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
140             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
141         };
142         unsigned char a[] = {
143             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
144             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
145             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
147             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
148             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
149         };
150         unsigned char b[] = {
151             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
152             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
153             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
154             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
155             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
156             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
157         };
158         unsigned char n[] = {
159             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
160             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
161             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
162             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
163             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
164             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
165         };
166         unsigned char Gx[] = {
167             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
168             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
169             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
170             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
171             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
172             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
173         };
174         unsigned char Gy[] = {
175             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
176             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
177             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
178             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
179             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
180             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
181         };
182
183         if (!initialise_curve(&curve, 384, p, a, b, n, Gx, Gy)) {
184             return NULL;
185         }
186
187         /* Now initialised, no need to do it again */
188         initialised = 1;
189     }
190
191     return &curve;
192 }
193
194 struct ec_curve *ec_p521(void)
195 {
196     static struct ec_curve curve = { 0 };
197     static unsigned char initialised = 0;
198
199     if (!initialised)
200     {
201         unsigned char p[] = {
202             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
203             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
204             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
209             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
210             0xff, 0xff
211         };
212         unsigned char a[] = {
213             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
214             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
215             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
216             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
217             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
218             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
219             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
220             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
221             0xff, 0xfc
222         };
223         unsigned char b[] = {
224             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
225             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
226             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
227             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
228             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
229             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
230             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
231             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
232             0x3f, 0x00
233         };
234         unsigned char n[] = {
235             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
236             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
237             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
238             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
239             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
240             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
241             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
242             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
243             0x64, 0x09
244         };
245         unsigned char Gx[] = {
246             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
247             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
248             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
249             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
250             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
251             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
252             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
253             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
254             0xbd, 0x66
255         };
256         unsigned char Gy[] = {
257             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
258             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
259             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
260             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
261             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
262             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
263             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
264             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
265             0x66, 0x50
266         };
267
268         if (!initialise_curve(&curve, 521, p, a, b, n, Gx, Gy)) {
269             return NULL;
270         }
271
272         /* Now initialised, no need to do it again */
273         initialised = 1;
274     }
275
276     return &curve;
277 }
278
279 static struct ec_curve *ec_name_to_curve(char *name, int len) {
280     if (len == 8 && !memcmp(name, "nistp", 5)) {
281         name += 5;
282         if (!memcmp(name, "256", 3)) {
283             return ec_p256();
284         } else if (!memcmp(name, "384", 3)) {
285             return ec_p384();
286         } else if (!memcmp(name, "521", 3)) {
287             return ec_p521();
288         }
289     }
290
291     return NULL;
292 }
293
294 static int ec_curve_to_name(const struct ec_curve *curve, unsigned char *name, int len) {
295     /* Return length of string */
296     if (name == NULL) return 8;
297
298     /* Not enough space for the name */
299     if (len < 8) return 0;
300
301     /* Put the name in the buffer */
302     switch (curve->fieldBits) {
303       case 256:
304         memcpy(name+5, "256", 3);
305         break;
306       case 384:
307         memcpy(name+5, "384", 3);
308         break;
309       case 521:
310         memcpy(name+5, "521", 3);
311         break;
312       default:
313         return 0;
314     }
315
316     memcpy(name, "nistp", 5);
317     return 8;
318 }
319
320 /* Return 1 if a is -3 % p, otherwise return 0
321  * This is used because there are some maths optimisations */
322 static int ec_aminus3(const struct ec_curve *curve)
323 {
324     int ret;
325     Bignum _p;
326
327     _p = bignum_add_long(curve->a, 3);
328     if (!_p) return 0;
329
330     ret = !bignum_cmp(curve->p, _p);
331     freebn(_p);
332     return ret;
333 }
334
335 /* ----------------------------------------------------------------------
336  * Elliptic curve field maths
337  */
338
339 static Bignum ecf_add(const Bignum a, const Bignum b,
340                       const struct ec_curve *curve)
341 {
342     Bignum a1, b1, ab, ret;
343
344     a1 = bigmod(a, curve->p);
345     if (!a1) return NULL;
346     b1 = bigmod(b, curve->p);
347     if (!b1)
348     {
349         freebn(a1);
350         return NULL;
351     }
352
353     ab = bigadd(a1, b1);
354     freebn(a1);
355     freebn(b1);
356     if (!ab) return NULL;
357
358     ret = bigmod(ab, curve->p);
359     freebn(ab);
360
361     return ret;
362 }
363
364 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
365 {
366     return modmul(a, a, curve->p);
367 }
368
369 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
370 {
371     Bignum ret, tmp;
372
373     /* Double */
374     tmp = bignum_lshift(a, 1);
375     if (!tmp) return NULL;
376
377     /* Add itself (i.e. treble) */
378     ret = bigadd(tmp, a);
379     freebn(tmp);
380
381     /* Normalise */
382     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
383     {
384         tmp = bigsub(ret, curve->p);
385         freebn(ret);
386         ret = tmp;
387     }
388
389     return ret;
390 }
391
392 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
393 {
394     Bignum ret = bignum_lshift(a, 1);
395     if (!ret) return NULL;
396     if (bignum_cmp(ret, curve->p) >= 0)
397     {
398         Bignum tmp = bigsub(ret, curve->p);
399         freebn(ret);
400         return tmp;
401     }
402     else
403     {
404         return ret;
405     }
406 }
407
408 /* ----------------------------------------------------------------------
409  * Memory functions
410  */
411
412 void ec_point_free(struct ec_point *point)
413 {
414     if (point == NULL) return;
415     point->curve = 0;
416     if (point->x) freebn(point->x);
417     if (point->y) freebn(point->y);
418     if (point->z) freebn(point->z);
419     point->infinity = 0;
420     sfree(point);
421 }
422
423 static struct ec_point *ec_point_new(const struct ec_curve *curve,
424                                      const Bignum x, const Bignum y, const Bignum z,
425                                      unsigned char infinity)
426 {
427     struct ec_point *point = snewn(1, struct ec_point);
428     point->curve = curve;
429     point->x = x;
430     point->y = y;
431     point->z = z;
432     point->infinity = infinity ? 1 : 0;
433     return point;
434 }
435
436 static struct ec_point *ec_point_copy(const struct ec_point *a)
437 {
438     if (a == NULL) return NULL;
439     return ec_point_new(a->curve,
440                         a->x ? copybn(a->x) : NULL,
441                         a->y ? copybn(a->y) : NULL,
442                         a->z ? copybn(a->z) : NULL,
443                         a->infinity);
444 }
445
446 static int ec_point_verify(const struct ec_point *a)
447 {
448     if (a->infinity)
449     {
450         return 1;
451     }
452     else
453     {
454         /* Verify y^2 = x^3 + ax + b */
455         int ret = 0;
456
457         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
458
459         Bignum Three = bignum_from_long(3);
460         if (!Three) return 0;
461
462         lhs = modmul(a->y, a->y, a->curve->p);
463         if (!lhs) goto error;
464
465         /* This uses montgomery multiplication to optimise */
466         x3 = modpow(a->x, Three, a->curve->p);
467         freebn(Three);
468         if (!x3) goto error;
469         ax = modmul(a->curve->a, a->x, a->curve->p);
470         if (!ax) goto error;
471         x3ax = bigadd(x3, ax);
472         if (!x3ax) goto error;
473         freebn(x3); x3 = NULL;
474         freebn(ax); ax = NULL;
475         x3axm = bigmod(x3ax, a->curve->p);
476         if (!x3axm) goto error;
477         freebn(x3ax); x3ax = NULL;
478         x3axb = bigadd(x3axm, a->curve->b);
479         if (!x3axb) goto error;
480         freebn(x3axm); x3axm = NULL;
481         rhs = bigmod(x3axb, a->curve->p);
482         if (!rhs) goto error;
483         freebn(x3axb);
484
485         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
486         freebn(lhs);
487         freebn(rhs);
488
489         return ret;
490
491       error:
492         if (x3) freebn(x3);
493         if (ax) freebn(ax);
494         if (x3ax) freebn(x3ax);
495         if (x3axm) freebn(x3axm);
496         if (x3axb) freebn(x3axb);
497         if (lhs) freebn(lhs);
498         return 0;
499     }
500 }
501
502 /* ----------------------------------------------------------------------
503  * Elliptic curve point maths
504  */
505
506 /* Returns 1 on success and 0 on memory error */
507 static int ecp_normalise(struct ec_point *a)
508 {
509     Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
510
511     /* In Jacobian Coordinates the triple (X, Y, Z) represents
512        the affine point (X / Z^2, Y / Z^3) */
513     if (!a) {
514         /* No point */
515         return 0;
516     }
517     if (a->infinity) {
518         /* Point is at infinity - i.e. normalised */
519         return 1;
520     } else if (!a->x || !a->y) {
521         /* No point defined */
522         return 0;
523     } else if (!a->z) {
524         /* Already normalised */
525         return 1;
526     }
527
528     Z2 = ecf_square(a->z, a->curve);
529     if (!Z2) {
530         return 0;
531     }
532     Z2inv = modinv(Z2, a->curve->p);
533     if (!Z2inv) {
534         freebn(Z2);
535         return 0;
536     }
537     tx = modmul(a->x, Z2inv, a->curve->p);
538     freebn(Z2inv);
539     if (!tx) {
540         freebn(Z2);
541         return 0;
542     }
543
544     Z3 = modmul(Z2, a->z, a->curve->p);
545     freebn(Z2);
546     if (!Z3) {
547         freebn(tx);
548         return 0;
549     }
550     Z3inv = modinv(Z3, a->curve->p);
551     freebn(Z3);
552     if (!Z3inv) {
553         freebn(tx);
554         return 0;
555     }
556     ty = modmul(a->y, Z3inv, a->curve->p);
557     freebn(Z3inv);
558     if (!ty) {
559         freebn(tx);
560         return 0;
561     }
562
563     freebn(a->x);
564     a->x = tx;
565     freebn(a->y);
566     a->y = ty;
567     freebn(a->z);
568     a->z = NULL;
569     return 1;
570 }
571
572 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
573 {
574     Bignum S, M, outx, outy, outz;
575
576     if (a->infinity || bignum_cmp(a->y, Zero) == 0)
577     {
578         /* Identity */
579         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
580     }
581
582     /* S = 4*X*Y^2 */
583     {
584         Bignum Y2, XY2, _2XY2;
585
586         Y2 = ecf_square(a->y, a->curve);
587         if (!Y2) {
588             return NULL;
589         }
590         XY2 = modmul(a->x, Y2, a->curve->p);
591         freebn(Y2);
592         if (!XY2) {
593             return NULL;
594         }
595
596         _2XY2 = ecf_double(XY2, a->curve);
597         freebn(XY2);
598         if (!_2XY2) {
599             return NULL;
600         }
601         S = ecf_double(_2XY2, a->curve);
602         freebn(_2XY2);
603         if (!S) {
604             return NULL;
605         }
606     }
607
608     /* Faster calculation if a = -3 */
609     if (aminus3) {
610         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
611         Bignum Z2, XpZ2, XmZ2, second;
612
613         if (a->z == NULL) {
614             Z2 = copybn(One);
615         } else {
616             Z2 = ecf_square(a->z, a->curve);
617         }
618         if (!Z2) {
619             freebn(S);
620             return NULL;
621         }
622
623         XpZ2 = ecf_add(a->x, Z2, a->curve);
624         if (!XpZ2) {
625             freebn(S);
626             freebn(Z2);
627             return NULL;
628         }
629         XmZ2 = modsub(a->x, Z2, a->curve->p);
630         freebn(Z2);
631         if (!XmZ2) {
632             freebn(S);
633             freebn(XpZ2);
634             return NULL;
635         }
636
637         second = modmul(XpZ2, XmZ2, a->curve->p);
638         freebn(XpZ2);
639         freebn(XmZ2);
640         if (!second) {
641             freebn(S);
642             return NULL;
643         }
644
645         M = ecf_treble(second, a->curve);
646         freebn(second);
647         if (!M) {
648             freebn(S);
649             return NULL;
650         }
651     } else {
652         /* M = 3*X^2 + a*Z^4 */
653         Bignum _3X2, X2, aZ4;
654
655         if (a->z == NULL) {
656             aZ4 = copybn(a->curve->a);
657         } else {
658             Bignum Z2, Z4;
659
660             Z2 = ecf_square(a->z, a->curve);
661             if (!Z2) {
662                 freebn(S);
663                 return NULL;
664             }
665             Z4 = ecf_square(Z2, a->curve);
666             freebn(Z2);
667             if (!Z4) {
668                 freebn(S);
669                 return NULL;
670             }
671             aZ4 = modmul(a->curve->a, Z4, a->curve->p);
672             freebn(Z4);
673         }
674         if (!aZ4) {
675             freebn(S);
676             return NULL;
677         }
678
679         X2 = modmul(a->x, a->x, a->curve->p);
680         if (!X2) {
681             freebn(S);
682             freebn(aZ4);
683             return NULL;
684         }
685         _3X2 = ecf_treble(X2, a->curve);
686         freebn(X2);
687         if (!_3X2) {
688             freebn(S);
689             freebn(aZ4);
690             return NULL;
691         }
692         M = ecf_add(_3X2, aZ4, a->curve);
693         freebn(_3X2);
694         freebn(aZ4);
695         if (!M) {
696             freebn(S);
697             return NULL;
698         }
699     }
700
701     /* X' = M^2 - 2*S */
702     {
703         Bignum M2, _2S;
704
705         M2 = ecf_square(M, a->curve);
706         if (!M2) {
707             freebn(S);
708             freebn(M);
709             return NULL;
710         }
711
712         _2S = ecf_double(S, a->curve);
713         if (!_2S) {
714             freebn(M2);
715             freebn(S);
716             freebn(M);
717             return NULL;
718         }
719
720         outx = modsub(M2, _2S, a->curve->p);
721         freebn(M2);
722         freebn(_2S);
723         if (!outx) {
724             freebn(S);
725             freebn(M);
726             return NULL;
727         }
728     }
729
730     /* Y' = M*(S - X') - 8*Y^4 */
731     {
732         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
733
734         SX = modsub(S, outx, a->curve->p);
735         freebn(S);
736         if (!SX) {
737             freebn(M);
738             freebn(outx);
739             return NULL;
740         }
741         MSX = modmul(M, SX, a->curve->p);
742         freebn(SX);
743         freebn(M);
744         if (!MSX) {
745             freebn(outx);
746             return NULL;
747         }
748         Y2 = ecf_square(a->y, a->curve);
749         if (!Y2) {
750             freebn(outx);
751             freebn(MSX);
752             return NULL;
753         }
754         Y4 = ecf_square(Y2, a->curve);
755         freebn(Y2);
756         if (!Y4) {
757             freebn(outx);
758             freebn(MSX);
759             return NULL;
760         }
761         Eight = bignum_from_long(8);
762         if (!Eight) {
763             freebn(outx);
764             freebn(MSX);
765             freebn(Y4);
766             return NULL;
767         }
768         _8Y4 = modmul(Eight, Y4, a->curve->p);
769         freebn(Eight);
770         freebn(Y4);
771         if (!_8Y4) {
772             freebn(outx);
773             freebn(MSX);
774             return NULL;
775         }
776         outy = modsub(MSX, _8Y4, a->curve->p);
777         freebn(MSX);
778         freebn(_8Y4);
779         if (!outy) {
780             freebn(outx);
781             return NULL;
782         }
783     }
784
785     /* Z' = 2*Y*Z */
786     {
787         Bignum YZ;
788
789         if (a->z == NULL) {
790             YZ = copybn(a->y);
791         } else {
792             YZ = modmul(a->y, a->z, a->curve->p);
793         }
794         if (!YZ) {
795             freebn(outx);
796             freebn(outy);
797             return NULL;
798         }
799
800         outz = ecf_double(YZ, a->curve);
801         freebn(YZ);
802         if (!outz) {
803             freebn(outx);
804             freebn(outy);
805             return NULL;
806         }
807     }
808
809     return ec_point_new(a->curve, outx, outy, outz, 0);
810 }
811
812 static struct ec_point *ecp_add(const struct ec_point *a,
813                                 const struct ec_point *b,
814                                 const int aminus3)
815 {
816     Bignum U1, U2, S1, S2, outx, outy, outz;
817
818     /* Check if multiplying by infinity */
819     if (a->infinity) return ec_point_copy(b);
820     if (b->infinity) return ec_point_copy(a);
821
822     /* U1 = X1*Z2^2 */
823     /* S1 = Y1*Z2^3 */
824     if (b->z) {
825         Bignum Z2, Z3;
826
827         Z2 = ecf_square(b->z, a->curve);
828         if (!Z2) {
829             return NULL;
830         }
831         U1 = modmul(a->x, Z2, a->curve->p);
832         if (!U1) {
833             freebn(Z2);
834             return NULL;
835         }
836         Z3 = modmul(Z2, b->z, a->curve->p);
837         freebn(Z2);
838         if (!Z3) {
839             freebn(U1);
840             return NULL;
841         }
842         S1 = modmul(a->y, Z3, a->curve->p);
843         freebn(Z3);
844         if (!S1) {
845             freebn(U1);
846             return NULL;
847         }
848     } else {
849         U1 = copybn(a->x);
850         if (!U1) {
851             return NULL;
852         }
853         S1 = copybn(a->y);
854         if (!S1) {
855             freebn(U1);
856             return NULL;
857         }
858     }
859
860     /* U2 = X2*Z1^2 */
861     /* S2 = Y2*Z1^3 */
862     if (a->z) {
863         Bignum Z2, Z3;
864
865         Z2 = ecf_square(a->z, b->curve);
866         if (!Z2) {
867             freebn(U1);
868             freebn(S1);
869             return NULL;
870         }
871         U2 = modmul(b->x, Z2, b->curve->p);
872         if (!U2) {
873             freebn(U1);
874             freebn(S1);
875             freebn(Z2);
876             return NULL;
877         }
878         Z3 = modmul(Z2, a->z, b->curve->p);
879         freebn(Z2);
880         if (!Z3) {
881             freebn(U1);
882             freebn(S1);
883             freebn(U2);
884             return NULL;
885         }
886         S2 = modmul(b->y, Z3, b->curve->p);
887         freebn(Z3);
888         if (!S2) {
889             freebn(U1);
890             freebn(S1);
891             freebn(U2);
892             return NULL;
893         }
894     } else {
895         U2 = copybn(b->x);
896         if (!U2) {
897             freebn(U1);
898             freebn(S1);
899             return NULL;
900         }
901         S2 = copybn(b->y);
902         if (!S2) {
903             freebn(U1);
904             freebn(S1);
905             freebn(U2);
906             return NULL;
907         }
908     }
909
910     /* Check if multiplying by self */
911     if (bignum_cmp(U1, U2) == 0)
912     {
913         freebn(U1);
914         freebn(U2);
915         if (bignum_cmp(S1, S2) == 0)
916         {
917             freebn(S1);
918             freebn(S2);
919             return ecp_double(a, aminus3);
920         }
921         else
922         {
923             freebn(S1);
924             freebn(S2);
925             /* Infinity */
926             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
927         }
928     }
929
930     {
931         Bignum H, R, UH2, H3;
932
933         /* H = U2 - U1 */
934         H = modsub(U2, U1, a->curve->p);
935         freebn(U2);
936         if (!H) {
937             freebn(U1);
938             freebn(S1);
939             freebn(S2);
940             return NULL;
941         }
942
943         /* R = S2 - S1 */
944         R = modsub(S2, S1, a->curve->p);
945         freebn(S2);
946         if (!R) {
947             freebn(H);
948             freebn(S1);
949             freebn(U1);
950             return NULL;
951         }
952
953         /* X3 = R^2 - H^3 - 2*U1*H^2 */
954         {
955             Bignum R2, H2, _2UH2, first;
956
957             H2 = ecf_square(H, a->curve);
958             if (!H2) {
959                 freebn(U1);
960                 freebn(S1);
961                 freebn(H);
962                 freebn(R);
963                 return NULL;
964             }
965             UH2 = modmul(U1, H2, a->curve->p);
966             freebn(U1);
967             if (!UH2) {
968                 freebn(H2);
969                 freebn(S1);
970                 freebn(H);
971                 freebn(R);
972                 return NULL;
973             }
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             if (!H3) {
977                 freebn(UH2);
978                 freebn(S1);
979                 freebn(H);
980                 freebn(R);
981                 return NULL;
982             }
983             R2 = ecf_square(R, a->curve);
984             if (!R2) {
985                 freebn(H3);
986                 freebn(UH2);
987                 freebn(S1);
988                 freebn(H);
989                 freebn(R);
990                 return NULL;
991             }
992             _2UH2 = ecf_double(UH2, a->curve);
993             if (!_2UH2) {
994                 freebn(R2);
995                 freebn(H3);
996                 freebn(UH2);
997                 freebn(S1);
998                 freebn(H);
999                 freebn(R);
1000                 return NULL;
1001             }
1002             first = modsub(R2, H3, a->curve->p);
1003             freebn(R2);
1004             if (!first) {
1005                 freebn(H3);
1006                 freebn(_2UH2);
1007                 freebn(UH2);
1008                 freebn(S1);
1009                 freebn(H);
1010                 freebn(R);
1011                 return NULL;
1012             }
1013             outx = modsub(first, _2UH2, a->curve->p);
1014             freebn(first);
1015             freebn(_2UH2);
1016             if (!outx) {
1017                 freebn(H3);
1018                 freebn(UH2);
1019                 freebn(S1);
1020                 freebn(H);
1021                 freebn(R);
1022                 return NULL;
1023             }
1024         }
1025
1026         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1027         {
1028             Bignum RUH2mX, UH2mX, SH3;
1029
1030             UH2mX = modsub(UH2, outx, a->curve->p);
1031             freebn(UH2);
1032             if (!UH2mX) {
1033                 freebn(outx);
1034                 freebn(H3);
1035                 freebn(S1);
1036                 freebn(H);
1037                 freebn(R);
1038                 return NULL;
1039             }
1040             RUH2mX = modmul(R, UH2mX, a->curve->p);
1041             freebn(UH2mX);
1042             freebn(R);
1043             if (!RUH2mX) {
1044                 freebn(outx);
1045                 freebn(H3);
1046                 freebn(S1);
1047                 freebn(H);
1048                 return NULL;
1049             }
1050             SH3 = modmul(S1, H3, a->curve->p);
1051             freebn(S1);
1052             freebn(H3);
1053             if (!SH3) {
1054                 freebn(RUH2mX);
1055                 freebn(outx);
1056                 freebn(H);
1057                 return NULL;
1058             }
1059
1060             outy = modsub(RUH2mX, SH3, a->curve->p);
1061             freebn(RUH2mX);
1062             freebn(SH3);
1063             if (!outy) {
1064                 freebn(outx);
1065                 freebn(H);
1066                 return NULL;
1067             }
1068         }
1069
1070         /* Z3 = H*Z1*Z2 */
1071         if (a->z && b->z) {
1072             Bignum ZZ;
1073
1074             ZZ = modmul(a->z, b->z, a->curve->p);
1075             if (!ZZ) {
1076                 freebn(outx);
1077                 freebn(outy);
1078                 freebn(H);
1079                 return NULL;
1080             }
1081             outz = modmul(H, ZZ, a->curve->p);
1082             freebn(H);
1083             freebn(ZZ);
1084             if (!outz) {
1085                 freebn(outx);
1086                 freebn(outy);
1087                 return NULL;
1088             }
1089         } else if (a->z) {
1090             outz = modmul(H, a->z, a->curve->p);
1091             freebn(H);
1092             if (!outz) {
1093                 freebn(outx);
1094                 freebn(outy);
1095                 return NULL;
1096             }
1097         } else if (b->z) {
1098             outz = modmul(H, b->z, a->curve->p);
1099             freebn(H);
1100             if (!outz) {
1101                 freebn(outx);
1102                 freebn(outy);
1103                 return NULL;
1104             }
1105         } else {
1106             outz = H;
1107         }
1108     }
1109
1110     return ec_point_new(a->curve, outx, outy, outz, 0);
1111 }
1112
1113 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1114 {
1115     struct ec_point *A, *ret;
1116     int bits, i;
1117
1118     A = ec_point_copy(a);
1119     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1120
1121     bits = bignum_bitcount(b);
1122     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1123     {
1124         if (bignum_bit(b, i))
1125         {
1126             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1127             ec_point_free(ret);
1128             ret = tmp;
1129         }
1130         if (i+1 != bits)
1131         {
1132             struct ec_point *tmp = ecp_double(A, aminus3);
1133             ec_point_free(A);
1134             A = tmp;
1135         }
1136     }
1137
1138     if (!A) {
1139         ec_point_free(ret);
1140         ret = NULL;
1141     } else {
1142         ec_point_free(A);
1143     }
1144
1145     return ret;
1146 }
1147
1148 /* Not static because it is used by sshecdsag.c to generate a new key */
1149 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1150 {
1151     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1152
1153     if (!ecp_normalise(ret)) {
1154         ec_point_free(ret);
1155         return NULL;
1156     }
1157
1158     return ret;
1159 }
1160
1161 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1162                                    const struct ec_point *point)
1163 {
1164     struct ec_point *aG, *bP, *ret;
1165     int aminus3 = ec_aminus3(point->curve);
1166
1167     aG = ecp_mul_(&point->curve->G, a, aminus3);
1168     if (!aG) return NULL;
1169     bP = ecp_mul_(point, b, aminus3);
1170     if (!bP) {
1171         ec_point_free(aG);
1172         return NULL;
1173     }
1174
1175     ret = ecp_add(aG, bP, aminus3);
1176
1177     ec_point_free(aG);
1178     ec_point_free(bP);
1179
1180     if (!ecp_normalise(ret)) {
1181         ec_point_free(ret);
1182         return NULL;
1183     }
1184
1185     return ret;
1186 }
1187
1188 /* ----------------------------------------------------------------------
1189  * Basic sign and verify routines
1190  */
1191
1192 static int _ecdsa_verify(const struct ec_point *publicKey,
1193                          const unsigned char *data, const int dataLen,
1194                          const Bignum r, const Bignum s)
1195 {
1196     int z_bits, n_bits;
1197     Bignum z;
1198     int valid = 0;
1199
1200     /* Sanity checks */
1201     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->n) >= 0
1202         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->n) >= 0)
1203     {
1204         return 0;
1205     }
1206
1207     /* z = left most bitlen(curve->n) of data */
1208     z = bignum_from_bytes(data, dataLen);
1209     if (!z) return 0;
1210     n_bits = bignum_bitcount(publicKey->curve->n);
1211     z_bits = bignum_bitcount(z);
1212     if (z_bits > n_bits)
1213     {
1214         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1215         freebn(z);
1216         z = tmp;
1217         if (!z) return 0;
1218     }
1219
1220     /* Ensure z in range of n */
1221     {
1222         Bignum tmp = bigmod(z, publicKey->curve->n);
1223         freebn(z);
1224         z = tmp;
1225         if (!z) return 0;
1226     }
1227
1228     /* Calculate signature */
1229     {
1230         Bignum w, x, u1, u2;
1231         struct ec_point *tmp;
1232
1233         w = modinv(s, publicKey->curve->n);
1234         if (!w) {
1235             freebn(z);
1236             return 0;
1237         }
1238         u1 = modmul(z, w, publicKey->curve->n);
1239         if (!u1) {
1240             freebn(z);
1241             freebn(w);
1242             return 0;
1243         }
1244         u2 = modmul(r, w, publicKey->curve->n);
1245         freebn(w);
1246         if (!u2) {
1247             freebn(z);
1248             freebn(u1);
1249             return 0;
1250         }
1251
1252         tmp = ecp_summul(u1, u2, publicKey);
1253         freebn(u1);
1254         freebn(u2);
1255         if (!tmp) {
1256             freebn(z);
1257             return 0;
1258         }
1259
1260         x = bigmod(tmp->x, publicKey->curve->n);
1261         ec_point_free(tmp);
1262         if (!x) {
1263             freebn(z);
1264             return 0;
1265         }
1266
1267         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1268         freebn(x);
1269     }
1270
1271     freebn(z);
1272
1273     return valid;
1274 }
1275
1276 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1277                         const unsigned char *data, const int dataLen,
1278                         Bignum *r, Bignum *s)
1279 {
1280     unsigned char digest[20];
1281     int z_bits, n_bits;
1282     Bignum z, k;
1283     struct ec_point *kG;
1284
1285     *r = NULL;
1286     *s = NULL;
1287
1288     /* z = left most bitlen(curve->n) of data */
1289     z = bignum_from_bytes(data, dataLen);
1290     if (!z) return;
1291     n_bits = bignum_bitcount(curve->n);
1292     z_bits = bignum_bitcount(z);
1293     if (z_bits > n_bits)
1294     {
1295         Bignum tmp;
1296         tmp = bignum_rshift(z, z_bits - n_bits);
1297         freebn(z);
1298         z = tmp;
1299         if (!z) return;
1300     }
1301
1302     /* Generate k between 1 and curve->n, using the same deterministic
1303      * k generation system we use for conventional DSA. */
1304     SHA_Simple(data, dataLen, digest);
1305     k = dss_gen_k("ECDSA deterministic k generator", curve->n, privateKey,
1306                   digest, sizeof(digest));
1307     if (!k) return;
1308
1309     kG = ecp_mul(&curve->G, k);
1310     if (!kG) {
1311         freebn(z);
1312         freebn(k);
1313         return;
1314     }
1315
1316     /* r = kG.x mod n */
1317     *r = bigmod(kG->x, curve->n);
1318     ec_point_free(kG);
1319     if (!*r) {
1320         freebn(z);
1321         freebn(k);
1322         return;
1323     }
1324
1325     /* s = (z + r * priv)/k mod n */
1326     {
1327         Bignum rPriv, zMod, first, firstMod, kInv;
1328         rPriv = modmul(*r, privateKey, curve->n);
1329         if (!rPriv) {
1330             freebn(*r);
1331             freebn(z);
1332             freebn(k);
1333             return;
1334         }
1335         zMod = bigmod(z, curve->n);
1336         freebn(z);
1337         if (!zMod) {
1338             freebn(rPriv);
1339             freebn(*r);
1340             freebn(k);
1341             return;
1342         }
1343         first = bigadd(rPriv, zMod);
1344         freebn(rPriv);
1345         freebn(zMod);
1346         if (!first) {
1347             freebn(*r);
1348             freebn(k);
1349             return;
1350         }
1351         firstMod = bigmod(first, curve->n);
1352         freebn(first);
1353         if (!firstMod) {
1354             freebn(*r);
1355             freebn(k);
1356             return;
1357         }
1358         kInv = modinv(k, curve->n);
1359         freebn(k);
1360         if (!kInv) {
1361             freebn(firstMod);
1362             freebn(*r);
1363             return;
1364         }
1365         *s = modmul(firstMod, kInv, curve->n);
1366         freebn(firstMod);
1367         freebn(kInv);
1368         if (!*s) {
1369             freebn(*r);
1370             return;
1371         }
1372     }
1373 }
1374
1375 /* ----------------------------------------------------------------------
1376  * Misc functions
1377  */
1378
1379 static void getstring(char **data, int *datalen, char **p, int *length)
1380 {
1381     *p = NULL;
1382     if (*datalen < 4)
1383         return;
1384     *length = toint(GET_32BIT(*data));
1385     if (*length < 0)
1386         return;
1387     *datalen -= 4;
1388     *data += 4;
1389     if (*datalen < *length)
1390         return;
1391     *p = *data;
1392     *data += *length;
1393     *datalen -= *length;
1394 }
1395
1396 static Bignum getmp(char **data, int *datalen)
1397 {
1398     char *p;
1399     int length;
1400
1401     getstring(data, datalen, &p, &length);
1402     if (!p)
1403         return NULL;
1404     if (p[0] & 0x80)
1405         return NULL;                   /* negative mp */
1406     return bignum_from_bytes((unsigned char *)p, length);
1407 }
1408
1409 static int decodepoint(char *p, int length, struct ec_point *point)
1410 {
1411     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1412         return 0;
1413     /* Skip compression flag */
1414     ++p;
1415     --length;
1416     /* The two values must be equal length */
1417     if (length % 2 != 0) {
1418         point->x = NULL;
1419         point->y = NULL;
1420         point->z = NULL;
1421         return 0;
1422     }
1423     length = length / 2;
1424     point->x = bignum_from_bytes((unsigned char *)p, length);
1425     if (!point->x) return 0;
1426     p += length;
1427     point->y = bignum_from_bytes((unsigned char *)p, length);
1428     if (!point->y) {
1429         freebn(point->x);
1430         point->x = NULL;
1431         return 0;
1432     }
1433     point->z = NULL;
1434
1435     /* Verify the point is on the curve */
1436     if (!ec_point_verify(point)) {
1437         freebn(point->x);
1438         point->x = NULL;
1439         freebn(point->y);
1440         point->y = NULL;
1441         return 0;
1442     }
1443
1444     return 1;
1445 }
1446
1447 static int getmppoint(char **data, int *datalen, struct ec_point *point)
1448 {
1449     char *p;
1450     int length;
1451
1452     getstring(data, datalen, &p, &length);
1453     if (!p) return 0;
1454     return decodepoint(p, length, point);
1455 }
1456
1457 /* ----------------------------------------------------------------------
1458  * Exposed ECDSA interface
1459  */
1460
1461 static void ecdsa_freekey(void *key)
1462 {
1463     struct ec_key *ec = (struct ec_key *) key;
1464     if (!ec) return;
1465
1466     if (ec->publicKey.x)
1467         freebn(ec->publicKey.x);
1468     if (ec->publicKey.y)
1469         freebn(ec->publicKey.y);
1470     if (ec->publicKey.z)
1471         freebn(ec->publicKey.z);
1472     if (ec->privateKey)
1473         freebn(ec->privateKey);
1474     sfree(ec);
1475 }
1476
1477 static void *ecdsa_newkey(char *data, int len)
1478 {
1479     char *p;
1480     int slen;
1481     struct ec_key *ec;
1482     struct ec_curve *curve;
1483
1484     getstring(&data, &len, &p, &slen);
1485
1486     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1487         return NULL;
1488     }
1489     curve = ec_name_to_curve(p+11, slen-11);
1490     if (!curve) return NULL;
1491
1492     getstring(&data, &len, &p, &slen);
1493
1494     if (curve != ec_name_to_curve(p, slen)) return NULL;
1495
1496     ec = snew(struct ec_key);
1497
1498     ec->publicKey.curve = curve;
1499     ec->publicKey.infinity = 0;
1500     ec->publicKey.x = NULL;
1501     ec->publicKey.y = NULL;
1502     ec->publicKey.z = NULL;
1503     if (!getmppoint(&data, &len, &ec->publicKey)) {
1504         ecdsa_freekey(ec);
1505         return NULL;
1506     }
1507     ec->privateKey = NULL;
1508
1509     if (!ec->publicKey.x || !ec->publicKey.y ||
1510         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1511         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1512     {
1513         ecdsa_freekey(ec);
1514         ec = NULL;
1515     }
1516
1517     return ec;
1518 }
1519
1520 static char *ecdsa_fmtkey(void *key)
1521 {
1522     struct ec_key *ec = (struct ec_key *) key;
1523     char *p;
1524     int len, i, pos, nibbles;
1525     static const char hex[] = "0123456789abcdef";
1526     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1527         return NULL;
1528
1529     pos = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1530     if (pos == 0) return NULL;
1531
1532     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1533     len += pos; /* Curve name */
1534     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1535     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1536     p = snewn(len, char);
1537
1538     pos = ec_curve_to_name(ec->publicKey.curve, (unsigned char*)p, pos);
1539     pos += sprintf(p + pos, ",0x");
1540     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1541     if (nibbles < 1)
1542         nibbles = 1;
1543     for (i = nibbles; i--;) {
1544         p[pos++] =
1545             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1546     }
1547     pos += sprintf(p + pos, ",0x");
1548     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1549     if (nibbles < 1)
1550         nibbles = 1;
1551     for (i = nibbles; i--;) {
1552         p[pos++] =
1553             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1554     }
1555     p[pos] = '\0';
1556     return p;
1557 }
1558
1559 static unsigned char *ecdsa_public_blob(void *key, int *len)
1560 {
1561     struct ec_key *ec = (struct ec_key *) key;
1562     int pointlen, bloblen, namelen;
1563     int i;
1564     unsigned char *blob, *p;
1565
1566     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1567     if (namelen == 0) return NULL;
1568
1569     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1570
1571     /*
1572      * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1573      */
1574     bloblen = 4 + 11 + namelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1575     blob = snewn(bloblen, unsigned char);
1576
1577     p = blob;
1578     PUT_32BIT(p, 11 + namelen);
1579     p += 4;
1580     memcpy(p, "ecdsa-sha2-", 11);
1581     p += 11;
1582     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1583     PUT_32BIT(p, namelen);
1584     p += 4;
1585     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1586     PUT_32BIT(p, (2 * pointlen) + 1);
1587     p += 4;
1588     *p++ = 0x04;
1589     for (i = pointlen; i--;)
1590         *p++ = bignum_byte(ec->publicKey.x, i);
1591     for (i = pointlen; i--;)
1592         *p++ = bignum_byte(ec->publicKey.y, i);
1593
1594     assert(p == blob + bloblen);
1595     *len = bloblen;
1596
1597     return blob;
1598 }
1599
1600 static unsigned char *ecdsa_private_blob(void *key, int *len)
1601 {
1602     struct ec_key *ec = (struct ec_key *) key;
1603     int keylen, bloblen;
1604     int i;
1605     unsigned char *blob, *p;
1606
1607     if (!ec->privateKey) return NULL;
1608
1609     keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1610
1611     /*
1612      * mpint privateKey. Total 4 + keylen.
1613      */
1614     bloblen = 4 + keylen;
1615     blob = snewn(bloblen, unsigned char);
1616
1617     p = blob;
1618     PUT_32BIT(p, keylen);
1619     p += 4;
1620     for (i = keylen; i--;)
1621         *p++ = bignum_byte(ec->privateKey, i);
1622
1623     assert(p == blob + bloblen);
1624     *len = bloblen;
1625     return blob;
1626 }
1627
1628 static void *ecdsa_createkey(unsigned char *pub_blob, int pub_len,
1629                              unsigned char *priv_blob, int priv_len)
1630 {
1631     struct ec_key *ec;
1632     struct ec_point *publicKey;
1633     char *pb = (char *) priv_blob;
1634
1635     ec = (struct ec_key*)ecdsa_newkey((char *) pub_blob, pub_len);
1636     if (!ec) {
1637         return NULL;
1638     }
1639
1640     ec->privateKey = getmp(&pb, &priv_len);
1641     if (!ec->privateKey) {
1642         ecdsa_freekey(ec);
1643         return NULL;
1644     }
1645
1646     /* Check that private key generates public key */
1647     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1648
1649     if (!publicKey ||
1650         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1651         bignum_cmp(publicKey->y, ec->publicKey.y))
1652     {
1653         ecdsa_freekey(ec);
1654         ec = NULL;
1655     }
1656     ec_point_free(publicKey);
1657
1658     return ec;
1659 }
1660
1661 static void *ecdsa_openssh_createkey(unsigned char **blob, int *len)
1662 {
1663     char **b = (char **) blob;
1664     char *p;
1665     int slen;
1666     struct ec_key *ec;
1667     struct ec_curve *curve;
1668     struct ec_point *publicKey;
1669
1670     getstring(b, len, &p, &slen);
1671
1672     if (!p) {
1673         return NULL;
1674     }
1675     curve = ec_name_to_curve(p, slen);
1676     if (!curve) return NULL;
1677
1678     ec = snew(struct ec_key);
1679
1680     ec->publicKey.curve = curve;
1681     ec->publicKey.infinity = 0;
1682     ec->publicKey.x = NULL;
1683     ec->publicKey.y = NULL;
1684     ec->publicKey.z = NULL;
1685     if (!getmppoint(b, len, &ec->publicKey)) {
1686         ecdsa_freekey(ec);
1687         return NULL;
1688     }
1689     ec->privateKey = NULL;
1690
1691     if (!ec->publicKey.x || !ec->publicKey.y ||
1692         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1693         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1694     {
1695         ecdsa_freekey(ec);
1696         return NULL;
1697     }
1698
1699     ec->privateKey = getmp(b, len);
1700     if (ec->privateKey == NULL)
1701     {
1702         ecdsa_freekey(ec);
1703         return NULL;
1704     }
1705
1706     /* Now check that the private key makes the public key */
1707     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1708     if (!publicKey)
1709     {
1710         ecdsa_freekey(ec);
1711         return NULL;
1712     }
1713
1714     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
1715         bignum_cmp(ec->publicKey.y, publicKey->y))
1716     {
1717         /* Private key doesn't make the public key on the given curve */
1718         ecdsa_freekey(ec);
1719         ec_point_free(publicKey);
1720         return NULL;
1721     }
1722
1723     ec_point_free(publicKey);
1724
1725     return ec;
1726 }
1727
1728 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
1729 {
1730     struct ec_key *ec = (struct ec_key *) key;
1731
1732     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1733
1734     int namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1735
1736     int bloblen =
1737         4 + namelen /* <LEN> nistpXXX */
1738         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
1739         + ssh2_bignum_length(ec->privateKey);
1740
1741     int i;
1742
1743     if (bloblen > len)
1744         return bloblen;
1745
1746     bloblen = 0;
1747
1748     PUT_32BIT(blob+bloblen, namelen);
1749     bloblen += 4;
1750
1751     bloblen += ec_curve_to_name(ec->publicKey.curve, blob+bloblen, namelen);
1752
1753     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
1754     bloblen += 4;
1755     blob[bloblen++] = 0x04;
1756     for (i = pointlen; i--; )
1757         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
1758     for (i = pointlen; i--; )
1759         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
1760
1761     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1762     PUT_32BIT(blob+bloblen, pointlen);
1763     bloblen += 4;
1764     for (i = pointlen; i--; )
1765         blob[bloblen++] = bignum_byte(ec->privateKey, i);
1766
1767     return bloblen;
1768 }
1769
1770 static int ecdsa_pubkey_bits(void *blob, int len)
1771 {
1772     struct ec_key *ec;
1773     int ret;
1774
1775     ec = (struct ec_key*)ecdsa_newkey((char *) blob, len);
1776     if (!ec)
1777         return -1;
1778     ret = ec->publicKey.curve->fieldBits;
1779     ecdsa_freekey(ec);
1780
1781     return ret;
1782 }
1783
1784 static char *ecdsa_fingerprint(void *key)
1785 {
1786     struct ec_key *ec = (struct ec_key *) key;
1787     struct MD5Context md5c;
1788     unsigned char digest[16], lenbuf[4];
1789     char *ret;
1790     unsigned char *name;
1791     int pointlen, namelen, i, j;
1792
1793     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1794     name = snewn(namelen, unsigned char);
1795     ec_curve_to_name(ec->publicKey.curve, name, namelen);
1796
1797     MD5Init(&md5c);
1798
1799     PUT_32BIT(lenbuf, namelen + 11);
1800     MD5Update(&md5c, lenbuf, 4);
1801     MD5Update(&md5c, (const unsigned char *)"ecdsa-sha2-", 11);
1802     MD5Update(&md5c, name, namelen);
1803
1804     PUT_32BIT(lenbuf, namelen);
1805     MD5Update(&md5c, lenbuf, 4);
1806     MD5Update(&md5c, name, namelen);
1807
1808     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1809     PUT_32BIT(lenbuf, 1 + (pointlen * 2));
1810     MD5Update(&md5c, lenbuf, 4);
1811     MD5Update(&md5c, (const unsigned char *)"\x04", 1);
1812     for (i = pointlen; i--; ) {
1813         unsigned char c = bignum_byte(ec->publicKey.x, i);
1814         MD5Update(&md5c, &c, 1);
1815     }
1816     for (i = pointlen; i--; ) {
1817         unsigned char c = bignum_byte(ec->publicKey.y, i);
1818         MD5Update(&md5c, &c, 1);
1819     }
1820
1821     MD5Final(digest, &md5c);
1822
1823     ret = snewn(11 + namelen + 1 + (16 * 3), char);
1824
1825     i = 11;
1826     memcpy(ret, "ecdsa-sha2-", 11);
1827     memcpy(ret+i, name, namelen);
1828     i += namelen;
1829     sfree(name);
1830     ret[i++] = ' ';
1831     for (j = 0; j < 16; j++)
1832         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
1833
1834     return ret;
1835 }
1836
1837 static int ecdsa_verifysig(void *key, char *sig, int siglen,
1838                            char *data, int datalen)
1839 {
1840     struct ec_key *ec = (struct ec_key *) key;
1841     char *p;
1842     int slen;
1843     unsigned char digest[512 / 8];
1844     int digestLen;
1845     Bignum r, s;
1846     int ret;
1847
1848     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1849         return 0;
1850
1851     /* Check the signature curve matches the key curve */
1852     getstring(&sig, &siglen, &p, &slen);
1853     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1854         return 0;
1855     }
1856     if (ec->publicKey.curve != ec_name_to_curve(p+11, slen-11)) {
1857         return 0;
1858     }
1859
1860     getstring(&sig, &siglen, &p, &slen);
1861     r = getmp(&p, &slen);
1862     if (!r) return 0;
1863     s = getmp(&p, &slen);
1864     if (!s) {
1865         freebn(r);
1866         return 0;
1867     }
1868
1869     /* Perform correct hash function depending on curve size */
1870     if (ec->publicKey.curve->fieldBits <= 256) {
1871         SHA256_Simple(data, datalen, digest);
1872         digestLen = 256 / 8;
1873     } else if (ec->publicKey.curve->fieldBits <= 384) {
1874         SHA384_Simple(data, datalen, digest);
1875         digestLen = 384 / 8;
1876     } else {
1877         SHA512_Simple(data, datalen, digest);
1878         digestLen = 512 / 8;
1879     }
1880
1881     /* Verify the signature */
1882     if (!_ecdsa_verify(&ec->publicKey, digest, digestLen, r, s)) {
1883         ret = 0;
1884     } else {
1885         ret = 1;
1886     }
1887
1888     freebn(r);
1889     freebn(s);
1890
1891     return ret;
1892 }
1893
1894 static unsigned char *ecdsa_sign(void *key, char *data, int datalen,
1895                                  int *siglen)
1896 {
1897     struct ec_key *ec = (struct ec_key *) key;
1898     unsigned char digest[512 / 8];
1899     int digestLen;
1900     Bignum r = NULL, s = NULL;
1901     unsigned char *buf, *p;
1902     int rlen, slen, namelen;
1903     int i;
1904
1905     if (!ec->privateKey || !ec->publicKey.curve) {
1906         return NULL;
1907     }
1908
1909     /* Perform correct hash function depending on curve size */
1910     if (ec->publicKey.curve->fieldBits <= 256) {
1911         SHA256_Simple(data, datalen, digest);
1912         digestLen = 256 / 8;
1913     } else if (ec->publicKey.curve->fieldBits <= 384) {
1914         SHA384_Simple(data, datalen, digest);
1915         digestLen = 384 / 8;
1916     } else {
1917         SHA512_Simple(data, datalen, digest);
1918         digestLen = 512 / 8;
1919     }
1920
1921     /* Do the signature */
1922     _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
1923     if (!r || !s) {
1924         if (r) freebn(r);
1925         if (s) freebn(s);
1926         return NULL;
1927     }
1928
1929     rlen = (bignum_bitcount(r) + 8) / 8;
1930     slen = (bignum_bitcount(s) + 8) / 8;
1931
1932     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1933
1934     /* Format the output */
1935     *siglen = 8+11+namelen+rlen+slen+8;
1936     buf = snewn(*siglen, unsigned char);
1937     p = buf;
1938     PUT_32BIT(p, 11+namelen);
1939     p += 4;
1940     memcpy(p, "ecdsa-sha2-", 11);
1941     p += 11;
1942     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1943     PUT_32BIT(p, rlen + slen + 8);
1944     p += 4;
1945     PUT_32BIT(p, rlen);
1946     p += 4;
1947     for (i = rlen; i--;)
1948         *p++ = bignum_byte(r, i);
1949     PUT_32BIT(p, slen);
1950     p += 4;
1951     for (i = slen; i--;)
1952         *p++ = bignum_byte(s, i);
1953
1954     freebn(r);
1955     freebn(s);
1956
1957     return buf;
1958 }
1959
1960 const struct ssh_signkey ssh_ecdsa_nistp256 = {
1961     ecdsa_newkey,
1962     ecdsa_freekey,
1963     ecdsa_fmtkey,
1964     ecdsa_public_blob,
1965     ecdsa_private_blob,
1966     ecdsa_createkey,
1967     ecdsa_openssh_createkey,
1968     ecdsa_openssh_fmtkey,
1969     ecdsa_pubkey_bits,
1970     ecdsa_fingerprint,
1971     ecdsa_verifysig,
1972     ecdsa_sign,
1973     "ecdsa-sha2-nistp256",
1974     "ecdsa-sha2-nistp256",
1975 };
1976
1977 const struct ssh_signkey ssh_ecdsa_nistp384 = {
1978     ecdsa_newkey,
1979     ecdsa_freekey,
1980     ecdsa_fmtkey,
1981     ecdsa_public_blob,
1982     ecdsa_private_blob,
1983     ecdsa_createkey,
1984     ecdsa_openssh_createkey,
1985     ecdsa_openssh_fmtkey,
1986     ecdsa_pubkey_bits,
1987     ecdsa_fingerprint,
1988     ecdsa_verifysig,
1989     ecdsa_sign,
1990     "ecdsa-sha2-nistp384",
1991     "ecdsa-sha2-nistp384",
1992 };
1993
1994 const struct ssh_signkey ssh_ecdsa_nistp521 = {
1995     ecdsa_newkey,
1996     ecdsa_freekey,
1997     ecdsa_fmtkey,
1998     ecdsa_public_blob,
1999     ecdsa_private_blob,
2000     ecdsa_createkey,
2001     ecdsa_openssh_createkey,
2002     ecdsa_openssh_fmtkey,
2003     ecdsa_pubkey_bits,
2004     ecdsa_fingerprint,
2005     ecdsa_verifysig,
2006     ecdsa_sign,
2007     "ecdsa-sha2-nistp521",
2008     "ecdsa-sha2-nistp521",
2009 };
2010
2011 /* ----------------------------------------------------------------------
2012  * Exposed ECDH interface
2013  */
2014
2015 static Bignum ecdh_calculate(const Bignum private,
2016                              const struct ec_point *public)
2017 {
2018     struct ec_point *p;
2019     Bignum ret;
2020     p = ecp_mul(public, private);
2021     if (!p) return NULL;
2022     ret = p->x;
2023     p->x = NULL;
2024     ec_point_free(p);
2025     return ret;
2026 }
2027
2028 void *ssh_ecdhkex_newkey(struct ec_curve *curve)
2029 {
2030     struct ec_key *key = snew(struct ec_key);
2031     struct ec_point *publicKey;
2032     key->publicKey.curve = curve;
2033     key->privateKey = bignum_random_in_range(One, key->publicKey.curve->n);
2034     if (!key->privateKey) {
2035         sfree(key);
2036         return NULL;
2037     }
2038     publicKey = ecp_mul(&key->publicKey.curve->G, key->privateKey);
2039     if (!publicKey) {
2040         freebn(key->privateKey);
2041         sfree(key);
2042         return NULL;
2043     }
2044     key->publicKey.x = publicKey->x;
2045     key->publicKey.y = publicKey->y;
2046     key->publicKey.z = NULL;
2047     sfree(publicKey);
2048     return key;
2049 }
2050
2051 char *ssh_ecdhkex_getpublic(void *key, int *len)
2052 {
2053     struct ec_key *ec = (struct ec_key*)key;
2054     char *point, *p;
2055     int i;
2056     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2057
2058     *len = 1 + pointlen * 2;
2059     point = (char*)snewn(*len, char);
2060
2061     p = point;
2062     *p++ = 0x04;
2063     for (i = pointlen; i--;)
2064         *p++ = bignum_byte(ec->publicKey.x, i);
2065     for (i = pointlen; i--;)
2066         *p++ = bignum_byte(ec->publicKey.y, i);
2067
2068     return point;
2069 }
2070
2071 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2072 {
2073     struct ec_key *ec = (struct ec_key*) key;
2074     struct ec_point remote;
2075
2076     remote.curve = ec->publicKey.curve;
2077     remote.infinity = 0;
2078     if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2079         return NULL;
2080     }
2081
2082     return ecdh_calculate(ec->privateKey, &remote);
2083 }
2084
2085 void ssh_ecdhkex_freekey(void *key)
2086 {
2087     ecdsa_freekey(key);
2088 }
2089
2090 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2091     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
2092 };
2093
2094 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2095     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
2096 };
2097
2098 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2099     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
2100 };
2101
2102 static const struct ssh_kex *const ec_kex_list[] = {
2103     &ssh_ec_kex_nistp256,
2104     &ssh_ec_kex_nistp384,
2105     &ssh_ec_kex_nistp521
2106 };
2107
2108 const struct ssh_kexes ssh_ecdh_kex = {
2109     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2110     ec_kex_list
2111 };