]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Const-correctness in public-key functions.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  * NOTE: Only curves on prime field are handled by the maths functions
5  */
6
7 /*
8  * References:
9  *
10  * Elliptic curves in SSH are specified in RFC 5656:
11  *   http://tools.ietf.org/html/rfc5656
12  *
13  * That specification delegates details of public key formatting and a
14  * lot of underlying mechanism to SEC 1:
15  *   http://www.secg.org/sec1-v2.pdf
16  */
17
18 #include <stdlib.h>
19 #include <assert.h>
20
21 #include "ssh.h"
22
23 /* ----------------------------------------------------------------------
24  * Elliptic curve definitions
25  */
26
27 static int initialise_curve(struct ec_curve *curve, int bits, unsigned char *p,
28                             unsigned char *a, unsigned char *b,
29                             unsigned char *n, unsigned char *Gx,
30                             unsigned char *Gy)
31 {
32     int length = bits / 8;
33     if (bits % 8) ++length;
34
35     curve->fieldBits = bits;
36     curve->p = bignum_from_bytes(p, length);
37     if (!curve->p) goto error;
38
39     /* Curve co-efficients */
40     curve->a = bignum_from_bytes(a, length);
41     if (!curve->a) goto error;
42     curve->b = bignum_from_bytes(b, length);
43     if (!curve->b) goto error;
44
45     /* Group order and generator */
46     curve->n = bignum_from_bytes(n, length);
47     if (!curve->n) goto error;
48     curve->G.x = bignum_from_bytes(Gx, length);
49     if (!curve->G.x) goto error;
50     curve->G.y = bignum_from_bytes(Gy, length);
51     if (!curve->G.y) goto error;
52     curve->G.curve = curve;
53     curve->G.infinity = 0;
54
55     return 1;
56   error:
57     if (curve->p) freebn(curve->p);
58     if (curve->a) freebn(curve->a);
59     if (curve->b) freebn(curve->b);
60     if (curve->n) freebn(curve->n);
61     if (curve->G.x) freebn(curve->G.x);
62     return 0;
63 }
64
65 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
66 int nistp256_oid_len = 8;
67 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
68 int nistp384_oid_len = 5;
69 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
70 int nistp521_oid_len = 5;
71
72 struct ec_curve *ec_p256(void)
73 {
74     static struct ec_curve curve = { 0 };
75     static unsigned char initialised = 0;
76
77     if (!initialised)
78     {
79         unsigned char p[] = {
80             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
81             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
82             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
83             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
84         };
85         unsigned char a[] = {
86             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
87             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
88             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
89             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
90         };
91         unsigned char b[] = {
92             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
93             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
94             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
95             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
96         };
97         unsigned char n[] = {
98             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
99             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
100             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
101             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
102         };
103         unsigned char Gx[] = {
104             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
105             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
106             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
107             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
108         };
109         unsigned char Gy[] = {
110             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
111             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
112             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
113             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
114         };
115
116         if (!initialise_curve(&curve, 256, p, a, b, n, Gx, Gy)) {
117             return NULL;
118         }
119
120         /* Now initialised, no need to do it again */
121         initialised = 1;
122     }
123
124     return &curve;
125 }
126
127 struct ec_curve *ec_p384(void)
128 {
129     static struct ec_curve curve = { 0 };
130     static unsigned char initialised = 0;
131
132     if (!initialised)
133     {
134         unsigned char p[] = {
135             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
137             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
138             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
139             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
140             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
141         };
142         unsigned char a[] = {
143             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
144             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
145             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
147             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
148             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
149         };
150         unsigned char b[] = {
151             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
152             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
153             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
154             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
155             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
156             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
157         };
158         unsigned char n[] = {
159             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
160             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
161             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
162             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
163             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
164             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
165         };
166         unsigned char Gx[] = {
167             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
168             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
169             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
170             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
171             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
172             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
173         };
174         unsigned char Gy[] = {
175             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
176             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
177             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
178             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
179             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
180             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
181         };
182
183         if (!initialise_curve(&curve, 384, p, a, b, n, Gx, Gy)) {
184             return NULL;
185         }
186
187         /* Now initialised, no need to do it again */
188         initialised = 1;
189     }
190
191     return &curve;
192 }
193
194 struct ec_curve *ec_p521(void)
195 {
196     static struct ec_curve curve = { 0 };
197     static unsigned char initialised = 0;
198
199     if (!initialised)
200     {
201         unsigned char p[] = {
202             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
203             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
204             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
209             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
210             0xff, 0xff
211         };
212         unsigned char a[] = {
213             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
214             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
215             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
216             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
217             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
218             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
219             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
220             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
221             0xff, 0xfc
222         };
223         unsigned char b[] = {
224             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
225             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
226             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
227             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
228             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
229             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
230             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
231             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
232             0x3f, 0x00
233         };
234         unsigned char n[] = {
235             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
236             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
237             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
238             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
239             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
240             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
241             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
242             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
243             0x64, 0x09
244         };
245         unsigned char Gx[] = {
246             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
247             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
248             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
249             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
250             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
251             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
252             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
253             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
254             0xbd, 0x66
255         };
256         unsigned char Gy[] = {
257             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
258             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
259             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
260             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
261             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
262             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
263             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
264             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
265             0x66, 0x50
266         };
267
268         if (!initialise_curve(&curve, 521, p, a, b, n, Gx, Gy)) {
269             return NULL;
270         }
271
272         /* Now initialised, no need to do it again */
273         initialised = 1;
274     }
275
276     return &curve;
277 }
278
279 static struct ec_curve *ec_name_to_curve(const char *name, int len) {
280     if (len == 8 && !memcmp(name, "nistp", 5)) {
281         name += 5;
282         if (!memcmp(name, "256", 3)) {
283             return ec_p256();
284         } else if (!memcmp(name, "384", 3)) {
285             return ec_p384();
286         } else if (!memcmp(name, "521", 3)) {
287             return ec_p521();
288         }
289     }
290
291     return NULL;
292 }
293
294 static int ec_curve_to_name(const struct ec_curve *curve, unsigned char *name, int len) {
295     /* Return length of string */
296     if (name == NULL) return 8;
297
298     /* Not enough space for the name */
299     if (len < 8) return 0;
300
301     /* Put the name in the buffer */
302     switch (curve->fieldBits) {
303       case 256:
304         memcpy(name+5, "256", 3);
305         break;
306       case 384:
307         memcpy(name+5, "384", 3);
308         break;
309       case 521:
310         memcpy(name+5, "521", 3);
311         break;
312       default:
313         return 0;
314     }
315
316     memcpy(name, "nistp", 5);
317     return 8;
318 }
319
320 /* Return 1 if a is -3 % p, otherwise return 0
321  * This is used because there are some maths optimisations */
322 static int ec_aminus3(const struct ec_curve *curve)
323 {
324     int ret;
325     Bignum _p;
326
327     _p = bignum_add_long(curve->a, 3);
328     if (!_p) return 0;
329
330     ret = !bignum_cmp(curve->p, _p);
331     freebn(_p);
332     return ret;
333 }
334
335 /* ----------------------------------------------------------------------
336  * Elliptic curve field maths
337  */
338
339 static Bignum ecf_add(const Bignum a, const Bignum b,
340                       const struct ec_curve *curve)
341 {
342     Bignum a1, b1, ab, ret;
343
344     a1 = bigmod(a, curve->p);
345     if (!a1) return NULL;
346     b1 = bigmod(b, curve->p);
347     if (!b1)
348     {
349         freebn(a1);
350         return NULL;
351     }
352
353     ab = bigadd(a1, b1);
354     freebn(a1);
355     freebn(b1);
356     if (!ab) return NULL;
357
358     ret = bigmod(ab, curve->p);
359     freebn(ab);
360
361     return ret;
362 }
363
364 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
365 {
366     return modmul(a, a, curve->p);
367 }
368
369 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
370 {
371     Bignum ret, tmp;
372
373     /* Double */
374     tmp = bignum_lshift(a, 1);
375     if (!tmp) return NULL;
376
377     /* Add itself (i.e. treble) */
378     ret = bigadd(tmp, a);
379     freebn(tmp);
380
381     /* Normalise */
382     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
383     {
384         tmp = bigsub(ret, curve->p);
385         freebn(ret);
386         ret = tmp;
387     }
388
389     return ret;
390 }
391
392 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
393 {
394     Bignum ret = bignum_lshift(a, 1);
395     if (!ret) return NULL;
396     if (bignum_cmp(ret, curve->p) >= 0)
397     {
398         Bignum tmp = bigsub(ret, curve->p);
399         freebn(ret);
400         return tmp;
401     }
402     else
403     {
404         return ret;
405     }
406 }
407
408 /* ----------------------------------------------------------------------
409  * Memory functions
410  */
411
412 void ec_point_free(struct ec_point *point)
413 {
414     if (point == NULL) return;
415     point->curve = 0;
416     if (point->x) freebn(point->x);
417     if (point->y) freebn(point->y);
418     if (point->z) freebn(point->z);
419     point->infinity = 0;
420     sfree(point);
421 }
422
423 static struct ec_point *ec_point_new(const struct ec_curve *curve,
424                                      const Bignum x, const Bignum y, const Bignum z,
425                                      unsigned char infinity)
426 {
427     struct ec_point *point = snewn(1, struct ec_point);
428     point->curve = curve;
429     point->x = x;
430     point->y = y;
431     point->z = z;
432     point->infinity = infinity ? 1 : 0;
433     return point;
434 }
435
436 static struct ec_point *ec_point_copy(const struct ec_point *a)
437 {
438     if (a == NULL) return NULL;
439     return ec_point_new(a->curve,
440                         a->x ? copybn(a->x) : NULL,
441                         a->y ? copybn(a->y) : NULL,
442                         a->z ? copybn(a->z) : NULL,
443                         a->infinity);
444 }
445
446 static int ec_point_verify(const struct ec_point *a)
447 {
448     if (a->infinity)
449     {
450         return 1;
451     }
452     else
453     {
454         /* Verify y^2 = x^3 + ax + b */
455         int ret = 0;
456
457         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
458
459         Bignum Three = bignum_from_long(3);
460         if (!Three) return 0;
461
462         lhs = modmul(a->y, a->y, a->curve->p);
463         if (!lhs) goto error;
464
465         /* This uses montgomery multiplication to optimise */
466         x3 = modpow(a->x, Three, a->curve->p);
467         freebn(Three);
468         if (!x3) goto error;
469         ax = modmul(a->curve->a, a->x, a->curve->p);
470         if (!ax) goto error;
471         x3ax = bigadd(x3, ax);
472         if (!x3ax) goto error;
473         freebn(x3); x3 = NULL;
474         freebn(ax); ax = NULL;
475         x3axm = bigmod(x3ax, a->curve->p);
476         if (!x3axm) goto error;
477         freebn(x3ax); x3ax = NULL;
478         x3axb = bigadd(x3axm, a->curve->b);
479         if (!x3axb) goto error;
480         freebn(x3axm); x3axm = NULL;
481         rhs = bigmod(x3axb, a->curve->p);
482         if (!rhs) goto error;
483         freebn(x3axb);
484
485         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
486         freebn(lhs);
487         freebn(rhs);
488
489         return ret;
490
491       error:
492         if (x3) freebn(x3);
493         if (ax) freebn(ax);
494         if (x3ax) freebn(x3ax);
495         if (x3axm) freebn(x3axm);
496         if (x3axb) freebn(x3axb);
497         if (lhs) freebn(lhs);
498         return 0;
499     }
500 }
501
502 /* ----------------------------------------------------------------------
503  * Elliptic curve point maths
504  */
505
506 /* Returns 1 on success and 0 on memory error */
507 static int ecp_normalise(struct ec_point *a)
508 {
509     Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
510
511     /* In Jacobian Coordinates the triple (X, Y, Z) represents
512        the affine point (X / Z^2, Y / Z^3) */
513     if (!a) {
514         /* No point */
515         return 0;
516     }
517     if (a->infinity) {
518         /* Point is at infinity - i.e. normalised */
519         return 1;
520     } else if (!a->x || !a->y) {
521         /* No point defined */
522         return 0;
523     } else if (!a->z) {
524         /* Already normalised */
525         return 1;
526     }
527
528     Z2 = ecf_square(a->z, a->curve);
529     if (!Z2) {
530         return 0;
531     }
532     Z2inv = modinv(Z2, a->curve->p);
533     if (!Z2inv) {
534         freebn(Z2);
535         return 0;
536     }
537     tx = modmul(a->x, Z2inv, a->curve->p);
538     freebn(Z2inv);
539     if (!tx) {
540         freebn(Z2);
541         return 0;
542     }
543
544     Z3 = modmul(Z2, a->z, a->curve->p);
545     freebn(Z2);
546     if (!Z3) {
547         freebn(tx);
548         return 0;
549     }
550     Z3inv = modinv(Z3, a->curve->p);
551     freebn(Z3);
552     if (!Z3inv) {
553         freebn(tx);
554         return 0;
555     }
556     ty = modmul(a->y, Z3inv, a->curve->p);
557     freebn(Z3inv);
558     if (!ty) {
559         freebn(tx);
560         return 0;
561     }
562
563     freebn(a->x);
564     a->x = tx;
565     freebn(a->y);
566     a->y = ty;
567     freebn(a->z);
568     a->z = NULL;
569     return 1;
570 }
571
572 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
573 {
574     Bignum S, M, outx, outy, outz;
575
576     if (a->infinity || bignum_cmp(a->y, Zero) == 0)
577     {
578         /* Identity */
579         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
580     }
581
582     /* S = 4*X*Y^2 */
583     {
584         Bignum Y2, XY2, _2XY2;
585
586         Y2 = ecf_square(a->y, a->curve);
587         if (!Y2) {
588             return NULL;
589         }
590         XY2 = modmul(a->x, Y2, a->curve->p);
591         freebn(Y2);
592         if (!XY2) {
593             return NULL;
594         }
595
596         _2XY2 = ecf_double(XY2, a->curve);
597         freebn(XY2);
598         if (!_2XY2) {
599             return NULL;
600         }
601         S = ecf_double(_2XY2, a->curve);
602         freebn(_2XY2);
603         if (!S) {
604             return NULL;
605         }
606     }
607
608     /* Faster calculation if a = -3 */
609     if (aminus3) {
610         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
611         Bignum Z2, XpZ2, XmZ2, second;
612
613         if (a->z == NULL) {
614             Z2 = copybn(One);
615         } else {
616             Z2 = ecf_square(a->z, a->curve);
617         }
618         if (!Z2) {
619             freebn(S);
620             return NULL;
621         }
622
623         XpZ2 = ecf_add(a->x, Z2, a->curve);
624         if (!XpZ2) {
625             freebn(S);
626             freebn(Z2);
627             return NULL;
628         }
629         XmZ2 = modsub(a->x, Z2, a->curve->p);
630         freebn(Z2);
631         if (!XmZ2) {
632             freebn(S);
633             freebn(XpZ2);
634             return NULL;
635         }
636
637         second = modmul(XpZ2, XmZ2, a->curve->p);
638         freebn(XpZ2);
639         freebn(XmZ2);
640         if (!second) {
641             freebn(S);
642             return NULL;
643         }
644
645         M = ecf_treble(second, a->curve);
646         freebn(second);
647         if (!M) {
648             freebn(S);
649             return NULL;
650         }
651     } else {
652         /* M = 3*X^2 + a*Z^4 */
653         Bignum _3X2, X2, aZ4;
654
655         if (a->z == NULL) {
656             aZ4 = copybn(a->curve->a);
657         } else {
658             Bignum Z2, Z4;
659
660             Z2 = ecf_square(a->z, a->curve);
661             if (!Z2) {
662                 freebn(S);
663                 return NULL;
664             }
665             Z4 = ecf_square(Z2, a->curve);
666             freebn(Z2);
667             if (!Z4) {
668                 freebn(S);
669                 return NULL;
670             }
671             aZ4 = modmul(a->curve->a, Z4, a->curve->p);
672             freebn(Z4);
673         }
674         if (!aZ4) {
675             freebn(S);
676             return NULL;
677         }
678
679         X2 = modmul(a->x, a->x, a->curve->p);
680         if (!X2) {
681             freebn(S);
682             freebn(aZ4);
683             return NULL;
684         }
685         _3X2 = ecf_treble(X2, a->curve);
686         freebn(X2);
687         if (!_3X2) {
688             freebn(S);
689             freebn(aZ4);
690             return NULL;
691         }
692         M = ecf_add(_3X2, aZ4, a->curve);
693         freebn(_3X2);
694         freebn(aZ4);
695         if (!M) {
696             freebn(S);
697             return NULL;
698         }
699     }
700
701     /* X' = M^2 - 2*S */
702     {
703         Bignum M2, _2S;
704
705         M2 = ecf_square(M, a->curve);
706         if (!M2) {
707             freebn(S);
708             freebn(M);
709             return NULL;
710         }
711
712         _2S = ecf_double(S, a->curve);
713         if (!_2S) {
714             freebn(M2);
715             freebn(S);
716             freebn(M);
717             return NULL;
718         }
719
720         outx = modsub(M2, _2S, a->curve->p);
721         freebn(M2);
722         freebn(_2S);
723         if (!outx) {
724             freebn(S);
725             freebn(M);
726             return NULL;
727         }
728     }
729
730     /* Y' = M*(S - X') - 8*Y^4 */
731     {
732         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
733
734         SX = modsub(S, outx, a->curve->p);
735         freebn(S);
736         if (!SX) {
737             freebn(M);
738             freebn(outx);
739             return NULL;
740         }
741         MSX = modmul(M, SX, a->curve->p);
742         freebn(SX);
743         freebn(M);
744         if (!MSX) {
745             freebn(outx);
746             return NULL;
747         }
748         Y2 = ecf_square(a->y, a->curve);
749         if (!Y2) {
750             freebn(outx);
751             freebn(MSX);
752             return NULL;
753         }
754         Y4 = ecf_square(Y2, a->curve);
755         freebn(Y2);
756         if (!Y4) {
757             freebn(outx);
758             freebn(MSX);
759             return NULL;
760         }
761         Eight = bignum_from_long(8);
762         if (!Eight) {
763             freebn(outx);
764             freebn(MSX);
765             freebn(Y4);
766             return NULL;
767         }
768         _8Y4 = modmul(Eight, Y4, a->curve->p);
769         freebn(Eight);
770         freebn(Y4);
771         if (!_8Y4) {
772             freebn(outx);
773             freebn(MSX);
774             return NULL;
775         }
776         outy = modsub(MSX, _8Y4, a->curve->p);
777         freebn(MSX);
778         freebn(_8Y4);
779         if (!outy) {
780             freebn(outx);
781             return NULL;
782         }
783     }
784
785     /* Z' = 2*Y*Z */
786     {
787         Bignum YZ;
788
789         if (a->z == NULL) {
790             YZ = copybn(a->y);
791         } else {
792             YZ = modmul(a->y, a->z, a->curve->p);
793         }
794         if (!YZ) {
795             freebn(outx);
796             freebn(outy);
797             return NULL;
798         }
799
800         outz = ecf_double(YZ, a->curve);
801         freebn(YZ);
802         if (!outz) {
803             freebn(outx);
804             freebn(outy);
805             return NULL;
806         }
807     }
808
809     return ec_point_new(a->curve, outx, outy, outz, 0);
810 }
811
812 static struct ec_point *ecp_add(const struct ec_point *a,
813                                 const struct ec_point *b,
814                                 const int aminus3)
815 {
816     Bignum U1, U2, S1, S2, outx, outy, outz;
817
818     /* Check if multiplying by infinity */
819     if (a->infinity) return ec_point_copy(b);
820     if (b->infinity) return ec_point_copy(a);
821
822     /* U1 = X1*Z2^2 */
823     /* S1 = Y1*Z2^3 */
824     if (b->z) {
825         Bignum Z2, Z3;
826
827         Z2 = ecf_square(b->z, a->curve);
828         if (!Z2) {
829             return NULL;
830         }
831         U1 = modmul(a->x, Z2, a->curve->p);
832         if (!U1) {
833             freebn(Z2);
834             return NULL;
835         }
836         Z3 = modmul(Z2, b->z, a->curve->p);
837         freebn(Z2);
838         if (!Z3) {
839             freebn(U1);
840             return NULL;
841         }
842         S1 = modmul(a->y, Z3, a->curve->p);
843         freebn(Z3);
844         if (!S1) {
845             freebn(U1);
846             return NULL;
847         }
848     } else {
849         U1 = copybn(a->x);
850         if (!U1) {
851             return NULL;
852         }
853         S1 = copybn(a->y);
854         if (!S1) {
855             freebn(U1);
856             return NULL;
857         }
858     }
859
860     /* U2 = X2*Z1^2 */
861     /* S2 = Y2*Z1^3 */
862     if (a->z) {
863         Bignum Z2, Z3;
864
865         Z2 = ecf_square(a->z, b->curve);
866         if (!Z2) {
867             freebn(U1);
868             freebn(S1);
869             return NULL;
870         }
871         U2 = modmul(b->x, Z2, b->curve->p);
872         if (!U2) {
873             freebn(U1);
874             freebn(S1);
875             freebn(Z2);
876             return NULL;
877         }
878         Z3 = modmul(Z2, a->z, b->curve->p);
879         freebn(Z2);
880         if (!Z3) {
881             freebn(U1);
882             freebn(S1);
883             freebn(U2);
884             return NULL;
885         }
886         S2 = modmul(b->y, Z3, b->curve->p);
887         freebn(Z3);
888         if (!S2) {
889             freebn(U1);
890             freebn(S1);
891             freebn(U2);
892             return NULL;
893         }
894     } else {
895         U2 = copybn(b->x);
896         if (!U2) {
897             freebn(U1);
898             freebn(S1);
899             return NULL;
900         }
901         S2 = copybn(b->y);
902         if (!S2) {
903             freebn(U1);
904             freebn(S1);
905             freebn(U2);
906             return NULL;
907         }
908     }
909
910     /* Check if multiplying by self */
911     if (bignum_cmp(U1, U2) == 0)
912     {
913         freebn(U1);
914         freebn(U2);
915         if (bignum_cmp(S1, S2) == 0)
916         {
917             freebn(S1);
918             freebn(S2);
919             return ecp_double(a, aminus3);
920         }
921         else
922         {
923             freebn(S1);
924             freebn(S2);
925             /* Infinity */
926             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
927         }
928     }
929
930     {
931         Bignum H, R, UH2, H3;
932
933         /* H = U2 - U1 */
934         H = modsub(U2, U1, a->curve->p);
935         freebn(U2);
936         if (!H) {
937             freebn(U1);
938             freebn(S1);
939             freebn(S2);
940             return NULL;
941         }
942
943         /* R = S2 - S1 */
944         R = modsub(S2, S1, a->curve->p);
945         freebn(S2);
946         if (!R) {
947             freebn(H);
948             freebn(S1);
949             freebn(U1);
950             return NULL;
951         }
952
953         /* X3 = R^2 - H^3 - 2*U1*H^2 */
954         {
955             Bignum R2, H2, _2UH2, first;
956
957             H2 = ecf_square(H, a->curve);
958             if (!H2) {
959                 freebn(U1);
960                 freebn(S1);
961                 freebn(H);
962                 freebn(R);
963                 return NULL;
964             }
965             UH2 = modmul(U1, H2, a->curve->p);
966             freebn(U1);
967             if (!UH2) {
968                 freebn(H2);
969                 freebn(S1);
970                 freebn(H);
971                 freebn(R);
972                 return NULL;
973             }
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             if (!H3) {
977                 freebn(UH2);
978                 freebn(S1);
979                 freebn(H);
980                 freebn(R);
981                 return NULL;
982             }
983             R2 = ecf_square(R, a->curve);
984             if (!R2) {
985                 freebn(H3);
986                 freebn(UH2);
987                 freebn(S1);
988                 freebn(H);
989                 freebn(R);
990                 return NULL;
991             }
992             _2UH2 = ecf_double(UH2, a->curve);
993             if (!_2UH2) {
994                 freebn(R2);
995                 freebn(H3);
996                 freebn(UH2);
997                 freebn(S1);
998                 freebn(H);
999                 freebn(R);
1000                 return NULL;
1001             }
1002             first = modsub(R2, H3, a->curve->p);
1003             freebn(R2);
1004             if (!first) {
1005                 freebn(H3);
1006                 freebn(_2UH2);
1007                 freebn(UH2);
1008                 freebn(S1);
1009                 freebn(H);
1010                 freebn(R);
1011                 return NULL;
1012             }
1013             outx = modsub(first, _2UH2, a->curve->p);
1014             freebn(first);
1015             freebn(_2UH2);
1016             if (!outx) {
1017                 freebn(H3);
1018                 freebn(UH2);
1019                 freebn(S1);
1020                 freebn(H);
1021                 freebn(R);
1022                 return NULL;
1023             }
1024         }
1025
1026         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1027         {
1028             Bignum RUH2mX, UH2mX, SH3;
1029
1030             UH2mX = modsub(UH2, outx, a->curve->p);
1031             freebn(UH2);
1032             if (!UH2mX) {
1033                 freebn(outx);
1034                 freebn(H3);
1035                 freebn(S1);
1036                 freebn(H);
1037                 freebn(R);
1038                 return NULL;
1039             }
1040             RUH2mX = modmul(R, UH2mX, a->curve->p);
1041             freebn(UH2mX);
1042             freebn(R);
1043             if (!RUH2mX) {
1044                 freebn(outx);
1045                 freebn(H3);
1046                 freebn(S1);
1047                 freebn(H);
1048                 return NULL;
1049             }
1050             SH3 = modmul(S1, H3, a->curve->p);
1051             freebn(S1);
1052             freebn(H3);
1053             if (!SH3) {
1054                 freebn(RUH2mX);
1055                 freebn(outx);
1056                 freebn(H);
1057                 return NULL;
1058             }
1059
1060             outy = modsub(RUH2mX, SH3, a->curve->p);
1061             freebn(RUH2mX);
1062             freebn(SH3);
1063             if (!outy) {
1064                 freebn(outx);
1065                 freebn(H);
1066                 return NULL;
1067             }
1068         }
1069
1070         /* Z3 = H*Z1*Z2 */
1071         if (a->z && b->z) {
1072             Bignum ZZ;
1073
1074             ZZ = modmul(a->z, b->z, a->curve->p);
1075             if (!ZZ) {
1076                 freebn(outx);
1077                 freebn(outy);
1078                 freebn(H);
1079                 return NULL;
1080             }
1081             outz = modmul(H, ZZ, a->curve->p);
1082             freebn(H);
1083             freebn(ZZ);
1084             if (!outz) {
1085                 freebn(outx);
1086                 freebn(outy);
1087                 return NULL;
1088             }
1089         } else if (a->z) {
1090             outz = modmul(H, a->z, a->curve->p);
1091             freebn(H);
1092             if (!outz) {
1093                 freebn(outx);
1094                 freebn(outy);
1095                 return NULL;
1096             }
1097         } else if (b->z) {
1098             outz = modmul(H, b->z, a->curve->p);
1099             freebn(H);
1100             if (!outz) {
1101                 freebn(outx);
1102                 freebn(outy);
1103                 return NULL;
1104             }
1105         } else {
1106             outz = H;
1107         }
1108     }
1109
1110     return ec_point_new(a->curve, outx, outy, outz, 0);
1111 }
1112
1113 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1114 {
1115     struct ec_point *A, *ret;
1116     int bits, i;
1117
1118     A = ec_point_copy(a);
1119     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1120
1121     bits = bignum_bitcount(b);
1122     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1123     {
1124         if (bignum_bit(b, i))
1125         {
1126             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1127             ec_point_free(ret);
1128             ret = tmp;
1129         }
1130         if (i+1 != bits)
1131         {
1132             struct ec_point *tmp = ecp_double(A, aminus3);
1133             ec_point_free(A);
1134             A = tmp;
1135         }
1136     }
1137
1138     if (!A) {
1139         ec_point_free(ret);
1140         ret = NULL;
1141     } else {
1142         ec_point_free(A);
1143     }
1144
1145     return ret;
1146 }
1147
1148 /* Not static because it is used by sshecdsag.c to generate a new key */
1149 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1150 {
1151     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1152
1153     if (!ecp_normalise(ret)) {
1154         ec_point_free(ret);
1155         return NULL;
1156     }
1157
1158     return ret;
1159 }
1160
1161 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1162                                    const struct ec_point *point)
1163 {
1164     struct ec_point *aG, *bP, *ret;
1165     int aminus3 = ec_aminus3(point->curve);
1166
1167     aG = ecp_mul_(&point->curve->G, a, aminus3);
1168     if (!aG) return NULL;
1169     bP = ecp_mul_(point, b, aminus3);
1170     if (!bP) {
1171         ec_point_free(aG);
1172         return NULL;
1173     }
1174
1175     ret = ecp_add(aG, bP, aminus3);
1176
1177     ec_point_free(aG);
1178     ec_point_free(bP);
1179
1180     if (!ecp_normalise(ret)) {
1181         ec_point_free(ret);
1182         return NULL;
1183     }
1184
1185     return ret;
1186 }
1187
1188 /* ----------------------------------------------------------------------
1189  * Basic sign and verify routines
1190  */
1191
1192 static int _ecdsa_verify(const struct ec_point *publicKey,
1193                          const unsigned char *data, const int dataLen,
1194                          const Bignum r, const Bignum s)
1195 {
1196     int z_bits, n_bits;
1197     Bignum z;
1198     int valid = 0;
1199
1200     /* Sanity checks */
1201     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->n) >= 0
1202         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->n) >= 0)
1203     {
1204         return 0;
1205     }
1206
1207     /* z = left most bitlen(curve->n) of data */
1208     z = bignum_from_bytes(data, dataLen);
1209     if (!z) return 0;
1210     n_bits = bignum_bitcount(publicKey->curve->n);
1211     z_bits = bignum_bitcount(z);
1212     if (z_bits > n_bits)
1213     {
1214         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1215         freebn(z);
1216         z = tmp;
1217         if (!z) return 0;
1218     }
1219
1220     /* Ensure z in range of n */
1221     {
1222         Bignum tmp = bigmod(z, publicKey->curve->n);
1223         freebn(z);
1224         z = tmp;
1225         if (!z) return 0;
1226     }
1227
1228     /* Calculate signature */
1229     {
1230         Bignum w, x, u1, u2;
1231         struct ec_point *tmp;
1232
1233         w = modinv(s, publicKey->curve->n);
1234         if (!w) {
1235             freebn(z);
1236             return 0;
1237         }
1238         u1 = modmul(z, w, publicKey->curve->n);
1239         if (!u1) {
1240             freebn(z);
1241             freebn(w);
1242             return 0;
1243         }
1244         u2 = modmul(r, w, publicKey->curve->n);
1245         freebn(w);
1246         if (!u2) {
1247             freebn(z);
1248             freebn(u1);
1249             return 0;
1250         }
1251
1252         tmp = ecp_summul(u1, u2, publicKey);
1253         freebn(u1);
1254         freebn(u2);
1255         if (!tmp) {
1256             freebn(z);
1257             return 0;
1258         }
1259
1260         x = bigmod(tmp->x, publicKey->curve->n);
1261         ec_point_free(tmp);
1262         if (!x) {
1263             freebn(z);
1264             return 0;
1265         }
1266
1267         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1268         freebn(x);
1269     }
1270
1271     freebn(z);
1272
1273     return valid;
1274 }
1275
1276 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1277                         const unsigned char *data, const int dataLen,
1278                         Bignum *r, Bignum *s)
1279 {
1280     unsigned char digest[20];
1281     int z_bits, n_bits;
1282     Bignum z, k;
1283     struct ec_point *kG;
1284
1285     *r = NULL;
1286     *s = NULL;
1287
1288     /* z = left most bitlen(curve->n) of data */
1289     z = bignum_from_bytes(data, dataLen);
1290     if (!z) return;
1291     n_bits = bignum_bitcount(curve->n);
1292     z_bits = bignum_bitcount(z);
1293     if (z_bits > n_bits)
1294     {
1295         Bignum tmp;
1296         tmp = bignum_rshift(z, z_bits - n_bits);
1297         freebn(z);
1298         z = tmp;
1299         if (!z) return;
1300     }
1301
1302     /* Generate k between 1 and curve->n, using the same deterministic
1303      * k generation system we use for conventional DSA. */
1304     SHA_Simple(data, dataLen, digest);
1305     k = dss_gen_k("ECDSA deterministic k generator", curve->n, privateKey,
1306                   digest, sizeof(digest));
1307     if (!k) return;
1308
1309     kG = ecp_mul(&curve->G, k);
1310     if (!kG) {
1311         freebn(z);
1312         freebn(k);
1313         return;
1314     }
1315
1316     /* r = kG.x mod n */
1317     *r = bigmod(kG->x, curve->n);
1318     ec_point_free(kG);
1319     if (!*r) {
1320         freebn(z);
1321         freebn(k);
1322         return;
1323     }
1324
1325     /* s = (z + r * priv)/k mod n */
1326     {
1327         Bignum rPriv, zMod, first, firstMod, kInv;
1328         rPriv = modmul(*r, privateKey, curve->n);
1329         if (!rPriv) {
1330             freebn(*r);
1331             freebn(z);
1332             freebn(k);
1333             return;
1334         }
1335         zMod = bigmod(z, curve->n);
1336         freebn(z);
1337         if (!zMod) {
1338             freebn(rPriv);
1339             freebn(*r);
1340             freebn(k);
1341             return;
1342         }
1343         first = bigadd(rPriv, zMod);
1344         freebn(rPriv);
1345         freebn(zMod);
1346         if (!first) {
1347             freebn(*r);
1348             freebn(k);
1349             return;
1350         }
1351         firstMod = bigmod(first, curve->n);
1352         freebn(first);
1353         if (!firstMod) {
1354             freebn(*r);
1355             freebn(k);
1356             return;
1357         }
1358         kInv = modinv(k, curve->n);
1359         freebn(k);
1360         if (!kInv) {
1361             freebn(firstMod);
1362             freebn(*r);
1363             return;
1364         }
1365         *s = modmul(firstMod, kInv, curve->n);
1366         freebn(firstMod);
1367         freebn(kInv);
1368         if (!*s) {
1369             freebn(*r);
1370             return;
1371         }
1372     }
1373 }
1374
1375 /* ----------------------------------------------------------------------
1376  * Misc functions
1377  */
1378
1379 static void getstring(const char **data, int *datalen,
1380                       const char **p, int *length)
1381 {
1382     *p = NULL;
1383     if (*datalen < 4)
1384         return;
1385     *length = toint(GET_32BIT(*data));
1386     if (*length < 0)
1387         return;
1388     *datalen -= 4;
1389     *data += 4;
1390     if (*datalen < *length)
1391         return;
1392     *p = *data;
1393     *data += *length;
1394     *datalen -= *length;
1395 }
1396
1397 static Bignum getmp(const char **data, int *datalen)
1398 {
1399     const char *p;
1400     int length;
1401
1402     getstring(data, datalen, &p, &length);
1403     if (!p)
1404         return NULL;
1405     if (p[0] & 0x80)
1406         return NULL;                   /* negative mp */
1407     return bignum_from_bytes((unsigned char *)p, length);
1408 }
1409
1410 static int decodepoint(const char *p, int length, struct ec_point *point)
1411 {
1412     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1413         return 0;
1414     /* Skip compression flag */
1415     ++p;
1416     --length;
1417     /* The two values must be equal length */
1418     if (length % 2 != 0) {
1419         point->x = NULL;
1420         point->y = NULL;
1421         point->z = NULL;
1422         return 0;
1423     }
1424     length = length / 2;
1425     point->x = bignum_from_bytes((unsigned char *)p, length);
1426     if (!point->x) return 0;
1427     p += length;
1428     point->y = bignum_from_bytes((unsigned char *)p, length);
1429     if (!point->y) {
1430         freebn(point->x);
1431         point->x = NULL;
1432         return 0;
1433     }
1434     point->z = NULL;
1435
1436     /* Verify the point is on the curve */
1437     if (!ec_point_verify(point)) {
1438         freebn(point->x);
1439         point->x = NULL;
1440         freebn(point->y);
1441         point->y = NULL;
1442         return 0;
1443     }
1444
1445     return 1;
1446 }
1447
1448 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1449 {
1450     const char *p;
1451     int length;
1452
1453     getstring(data, datalen, &p, &length);
1454     if (!p) return 0;
1455     return decodepoint(p, length, point);
1456 }
1457
1458 /* ----------------------------------------------------------------------
1459  * Exposed ECDSA interface
1460  */
1461
1462 static void ecdsa_freekey(void *key)
1463 {
1464     struct ec_key *ec = (struct ec_key *) key;
1465     if (!ec) return;
1466
1467     if (ec->publicKey.x)
1468         freebn(ec->publicKey.x);
1469     if (ec->publicKey.y)
1470         freebn(ec->publicKey.y);
1471     if (ec->publicKey.z)
1472         freebn(ec->publicKey.z);
1473     if (ec->privateKey)
1474         freebn(ec->privateKey);
1475     sfree(ec);
1476 }
1477
1478 static void *ecdsa_newkey(const char *data, int len)
1479 {
1480     const char *p;
1481     int slen;
1482     struct ec_key *ec;
1483     struct ec_curve *curve;
1484
1485     getstring(&data, &len, &p, &slen);
1486
1487     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1488         return NULL;
1489     }
1490     curve = ec_name_to_curve(p+11, slen-11);
1491     if (!curve) return NULL;
1492
1493     getstring(&data, &len, &p, &slen);
1494
1495     if (curve != ec_name_to_curve(p, slen)) return NULL;
1496
1497     ec = snew(struct ec_key);
1498
1499     ec->publicKey.curve = curve;
1500     ec->publicKey.infinity = 0;
1501     ec->publicKey.x = NULL;
1502     ec->publicKey.y = NULL;
1503     ec->publicKey.z = NULL;
1504     if (!getmppoint(&data, &len, &ec->publicKey)) {
1505         ecdsa_freekey(ec);
1506         return NULL;
1507     }
1508     ec->privateKey = NULL;
1509
1510     if (!ec->publicKey.x || !ec->publicKey.y ||
1511         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1512         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1513     {
1514         ecdsa_freekey(ec);
1515         ec = NULL;
1516     }
1517
1518     return ec;
1519 }
1520
1521 static char *ecdsa_fmtkey(void *key)
1522 {
1523     struct ec_key *ec = (struct ec_key *) key;
1524     char *p;
1525     int len, i, pos, nibbles;
1526     static const char hex[] = "0123456789abcdef";
1527     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1528         return NULL;
1529
1530     pos = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1531     if (pos == 0) return NULL;
1532
1533     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1534     len += pos; /* Curve name */
1535     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1536     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1537     p = snewn(len, char);
1538
1539     pos = ec_curve_to_name(ec->publicKey.curve, (unsigned char*)p, pos);
1540     pos += sprintf(p + pos, ",0x");
1541     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1542     if (nibbles < 1)
1543         nibbles = 1;
1544     for (i = nibbles; i--;) {
1545         p[pos++] =
1546             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1547     }
1548     pos += sprintf(p + pos, ",0x");
1549     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1550     if (nibbles < 1)
1551         nibbles = 1;
1552     for (i = nibbles; i--;) {
1553         p[pos++] =
1554             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1555     }
1556     p[pos] = '\0';
1557     return p;
1558 }
1559
1560 static unsigned char *ecdsa_public_blob(void *key, int *len)
1561 {
1562     struct ec_key *ec = (struct ec_key *) key;
1563     int pointlen, bloblen, namelen;
1564     int i;
1565     unsigned char *blob, *p;
1566
1567     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1568     if (namelen == 0) return NULL;
1569
1570     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1571
1572     /*
1573      * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1574      */
1575     bloblen = 4 + 11 + namelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1576     blob = snewn(bloblen, unsigned char);
1577
1578     p = blob;
1579     PUT_32BIT(p, 11 + namelen);
1580     p += 4;
1581     memcpy(p, "ecdsa-sha2-", 11);
1582     p += 11;
1583     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1584     PUT_32BIT(p, namelen);
1585     p += 4;
1586     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1587     PUT_32BIT(p, (2 * pointlen) + 1);
1588     p += 4;
1589     *p++ = 0x04;
1590     for (i = pointlen; i--;)
1591         *p++ = bignum_byte(ec->publicKey.x, i);
1592     for (i = pointlen; i--;)
1593         *p++ = bignum_byte(ec->publicKey.y, i);
1594
1595     assert(p == blob + bloblen);
1596     *len = bloblen;
1597
1598     return blob;
1599 }
1600
1601 static unsigned char *ecdsa_private_blob(void *key, int *len)
1602 {
1603     struct ec_key *ec = (struct ec_key *) key;
1604     int keylen, bloblen;
1605     int i;
1606     unsigned char *blob, *p;
1607
1608     if (!ec->privateKey) return NULL;
1609
1610     keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1611
1612     /*
1613      * mpint privateKey. Total 4 + keylen.
1614      */
1615     bloblen = 4 + keylen;
1616     blob = snewn(bloblen, unsigned char);
1617
1618     p = blob;
1619     PUT_32BIT(p, keylen);
1620     p += 4;
1621     for (i = keylen; i--;)
1622         *p++ = bignum_byte(ec->privateKey, i);
1623
1624     assert(p == blob + bloblen);
1625     *len = bloblen;
1626     return blob;
1627 }
1628
1629 static void *ecdsa_createkey(const unsigned char *pub_blob, int pub_len,
1630                              const unsigned char *priv_blob, int priv_len)
1631 {
1632     struct ec_key *ec;
1633     struct ec_point *publicKey;
1634     const char *pb = (const char *) priv_blob;
1635
1636     ec = (struct ec_key*)ecdsa_newkey((const char *) pub_blob, pub_len);
1637     if (!ec) {
1638         return NULL;
1639     }
1640
1641     ec->privateKey = getmp(&pb, &priv_len);
1642     if (!ec->privateKey) {
1643         ecdsa_freekey(ec);
1644         return NULL;
1645     }
1646
1647     /* Check that private key generates public key */
1648     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1649
1650     if (!publicKey ||
1651         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1652         bignum_cmp(publicKey->y, ec->publicKey.y))
1653     {
1654         ecdsa_freekey(ec);
1655         ec = NULL;
1656     }
1657     ec_point_free(publicKey);
1658
1659     return ec;
1660 }
1661
1662 static void *ecdsa_openssh_createkey(const unsigned char **blob, int *len)
1663 {
1664     const char **b = (const char **) blob;
1665     const char *p;
1666     int slen;
1667     struct ec_key *ec;
1668     struct ec_curve *curve;
1669     struct ec_point *publicKey;
1670
1671     getstring(b, len, &p, &slen);
1672
1673     if (!p) {
1674         return NULL;
1675     }
1676     curve = ec_name_to_curve(p, slen);
1677     if (!curve) return NULL;
1678
1679     ec = snew(struct ec_key);
1680
1681     ec->publicKey.curve = curve;
1682     ec->publicKey.infinity = 0;
1683     ec->publicKey.x = NULL;
1684     ec->publicKey.y = NULL;
1685     ec->publicKey.z = NULL;
1686     if (!getmppoint(b, len, &ec->publicKey)) {
1687         ecdsa_freekey(ec);
1688         return NULL;
1689     }
1690     ec->privateKey = NULL;
1691
1692     if (!ec->publicKey.x || !ec->publicKey.y ||
1693         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1694         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1695     {
1696         ecdsa_freekey(ec);
1697         return NULL;
1698     }
1699
1700     ec->privateKey = getmp(b, len);
1701     if (ec->privateKey == NULL)
1702     {
1703         ecdsa_freekey(ec);
1704         return NULL;
1705     }
1706
1707     /* Now check that the private key makes the public key */
1708     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1709     if (!publicKey)
1710     {
1711         ecdsa_freekey(ec);
1712         return NULL;
1713     }
1714
1715     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
1716         bignum_cmp(ec->publicKey.y, publicKey->y))
1717     {
1718         /* Private key doesn't make the public key on the given curve */
1719         ecdsa_freekey(ec);
1720         ec_point_free(publicKey);
1721         return NULL;
1722     }
1723
1724     ec_point_free(publicKey);
1725
1726     return ec;
1727 }
1728
1729 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
1730 {
1731     struct ec_key *ec = (struct ec_key *) key;
1732
1733     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1734
1735     int namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1736
1737     int bloblen =
1738         4 + namelen /* <LEN> nistpXXX */
1739         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
1740         + ssh2_bignum_length(ec->privateKey);
1741
1742     int i;
1743
1744     if (bloblen > len)
1745         return bloblen;
1746
1747     bloblen = 0;
1748
1749     PUT_32BIT(blob+bloblen, namelen);
1750     bloblen += 4;
1751
1752     bloblen += ec_curve_to_name(ec->publicKey.curve, blob+bloblen, namelen);
1753
1754     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
1755     bloblen += 4;
1756     blob[bloblen++] = 0x04;
1757     for (i = pointlen; i--; )
1758         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
1759     for (i = pointlen; i--; )
1760         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
1761
1762     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1763     PUT_32BIT(blob+bloblen, pointlen);
1764     bloblen += 4;
1765     for (i = pointlen; i--; )
1766         blob[bloblen++] = bignum_byte(ec->privateKey, i);
1767
1768     return bloblen;
1769 }
1770
1771 static int ecdsa_pubkey_bits(const void *blob, int len)
1772 {
1773     struct ec_key *ec;
1774     int ret;
1775
1776     ec = (struct ec_key*)ecdsa_newkey((const char *) blob, len);
1777     if (!ec)
1778         return -1;
1779     ret = ec->publicKey.curve->fieldBits;
1780     ecdsa_freekey(ec);
1781
1782     return ret;
1783 }
1784
1785 static char *ecdsa_fingerprint(void *key)
1786 {
1787     struct ec_key *ec = (struct ec_key *) key;
1788     struct MD5Context md5c;
1789     unsigned char digest[16], lenbuf[4];
1790     char *ret;
1791     unsigned char *name;
1792     int pointlen, namelen, i, j;
1793
1794     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1795     name = snewn(namelen, unsigned char);
1796     ec_curve_to_name(ec->publicKey.curve, name, namelen);
1797
1798     MD5Init(&md5c);
1799
1800     PUT_32BIT(lenbuf, namelen + 11);
1801     MD5Update(&md5c, lenbuf, 4);
1802     MD5Update(&md5c, (const unsigned char *)"ecdsa-sha2-", 11);
1803     MD5Update(&md5c, name, namelen);
1804
1805     PUT_32BIT(lenbuf, namelen);
1806     MD5Update(&md5c, lenbuf, 4);
1807     MD5Update(&md5c, name, namelen);
1808
1809     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1810     PUT_32BIT(lenbuf, 1 + (pointlen * 2));
1811     MD5Update(&md5c, lenbuf, 4);
1812     MD5Update(&md5c, (const unsigned char *)"\x04", 1);
1813     for (i = pointlen; i--; ) {
1814         unsigned char c = bignum_byte(ec->publicKey.x, i);
1815         MD5Update(&md5c, &c, 1);
1816     }
1817     for (i = pointlen; i--; ) {
1818         unsigned char c = bignum_byte(ec->publicKey.y, i);
1819         MD5Update(&md5c, &c, 1);
1820     }
1821
1822     MD5Final(digest, &md5c);
1823
1824     ret = snewn(11 + namelen + 1 + (16 * 3), char);
1825
1826     i = 11;
1827     memcpy(ret, "ecdsa-sha2-", 11);
1828     memcpy(ret+i, name, namelen);
1829     i += namelen;
1830     sfree(name);
1831     ret[i++] = ' ';
1832     for (j = 0; j < 16; j++)
1833         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
1834
1835     return ret;
1836 }
1837
1838 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
1839                            const char *data, int datalen)
1840 {
1841     struct ec_key *ec = (struct ec_key *) key;
1842     const char *p;
1843     int slen;
1844     unsigned char digest[512 / 8];
1845     int digestLen;
1846     Bignum r, s;
1847     int ret;
1848
1849     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1850         return 0;
1851
1852     /* Check the signature curve matches the key curve */
1853     getstring(&sig, &siglen, &p, &slen);
1854     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1855         return 0;
1856     }
1857     if (ec->publicKey.curve != ec_name_to_curve(p+11, slen-11)) {
1858         return 0;
1859     }
1860
1861     getstring(&sig, &siglen, &p, &slen);
1862     r = getmp(&p, &slen);
1863     if (!r) return 0;
1864     s = getmp(&p, &slen);
1865     if (!s) {
1866         freebn(r);
1867         return 0;
1868     }
1869
1870     /* Perform correct hash function depending on curve size */
1871     if (ec->publicKey.curve->fieldBits <= 256) {
1872         SHA256_Simple(data, datalen, digest);
1873         digestLen = 256 / 8;
1874     } else if (ec->publicKey.curve->fieldBits <= 384) {
1875         SHA384_Simple(data, datalen, digest);
1876         digestLen = 384 / 8;
1877     } else {
1878         SHA512_Simple(data, datalen, digest);
1879         digestLen = 512 / 8;
1880     }
1881
1882     /* Verify the signature */
1883     if (!_ecdsa_verify(&ec->publicKey, digest, digestLen, r, s)) {
1884         ret = 0;
1885     } else {
1886         ret = 1;
1887     }
1888
1889     freebn(r);
1890     freebn(s);
1891
1892     return ret;
1893 }
1894
1895 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
1896                                  int *siglen)
1897 {
1898     struct ec_key *ec = (struct ec_key *) key;
1899     unsigned char digest[512 / 8];
1900     int digestLen;
1901     Bignum r = NULL, s = NULL;
1902     unsigned char *buf, *p;
1903     int rlen, slen, namelen;
1904     int i;
1905
1906     if (!ec->privateKey || !ec->publicKey.curve) {
1907         return NULL;
1908     }
1909
1910     /* Perform correct hash function depending on curve size */
1911     if (ec->publicKey.curve->fieldBits <= 256) {
1912         SHA256_Simple(data, datalen, digest);
1913         digestLen = 256 / 8;
1914     } else if (ec->publicKey.curve->fieldBits <= 384) {
1915         SHA384_Simple(data, datalen, digest);
1916         digestLen = 384 / 8;
1917     } else {
1918         SHA512_Simple(data, datalen, digest);
1919         digestLen = 512 / 8;
1920     }
1921
1922     /* Do the signature */
1923     _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
1924     if (!r || !s) {
1925         if (r) freebn(r);
1926         if (s) freebn(s);
1927         return NULL;
1928     }
1929
1930     rlen = (bignum_bitcount(r) + 8) / 8;
1931     slen = (bignum_bitcount(s) + 8) / 8;
1932
1933     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1934
1935     /* Format the output */
1936     *siglen = 8+11+namelen+rlen+slen+8;
1937     buf = snewn(*siglen, unsigned char);
1938     p = buf;
1939     PUT_32BIT(p, 11+namelen);
1940     p += 4;
1941     memcpy(p, "ecdsa-sha2-", 11);
1942     p += 11;
1943     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1944     PUT_32BIT(p, rlen + slen + 8);
1945     p += 4;
1946     PUT_32BIT(p, rlen);
1947     p += 4;
1948     for (i = rlen; i--;)
1949         *p++ = bignum_byte(r, i);
1950     PUT_32BIT(p, slen);
1951     p += 4;
1952     for (i = slen; i--;)
1953         *p++ = bignum_byte(s, i);
1954
1955     freebn(r);
1956     freebn(s);
1957
1958     return buf;
1959 }
1960
1961 const struct ssh_signkey ssh_ecdsa_nistp256 = {
1962     ecdsa_newkey,
1963     ecdsa_freekey,
1964     ecdsa_fmtkey,
1965     ecdsa_public_blob,
1966     ecdsa_private_blob,
1967     ecdsa_createkey,
1968     ecdsa_openssh_createkey,
1969     ecdsa_openssh_fmtkey,
1970     3 /* curve name, point, private exponent */,
1971     ecdsa_pubkey_bits,
1972     ecdsa_fingerprint,
1973     ecdsa_verifysig,
1974     ecdsa_sign,
1975     "ecdsa-sha2-nistp256",
1976     "ecdsa-sha2-nistp256",
1977 };
1978
1979 const struct ssh_signkey ssh_ecdsa_nistp384 = {
1980     ecdsa_newkey,
1981     ecdsa_freekey,
1982     ecdsa_fmtkey,
1983     ecdsa_public_blob,
1984     ecdsa_private_blob,
1985     ecdsa_createkey,
1986     ecdsa_openssh_createkey,
1987     ecdsa_openssh_fmtkey,
1988     3 /* curve name, point, private exponent */,
1989     ecdsa_pubkey_bits,
1990     ecdsa_fingerprint,
1991     ecdsa_verifysig,
1992     ecdsa_sign,
1993     "ecdsa-sha2-nistp384",
1994     "ecdsa-sha2-nistp384",
1995 };
1996
1997 const struct ssh_signkey ssh_ecdsa_nistp521 = {
1998     ecdsa_newkey,
1999     ecdsa_freekey,
2000     ecdsa_fmtkey,
2001     ecdsa_public_blob,
2002     ecdsa_private_blob,
2003     ecdsa_createkey,
2004     ecdsa_openssh_createkey,
2005     ecdsa_openssh_fmtkey,
2006     3 /* curve name, point, private exponent */,
2007     ecdsa_pubkey_bits,
2008     ecdsa_fingerprint,
2009     ecdsa_verifysig,
2010     ecdsa_sign,
2011     "ecdsa-sha2-nistp521",
2012     "ecdsa-sha2-nistp521",
2013 };
2014
2015 /* ----------------------------------------------------------------------
2016  * Exposed ECDH interface
2017  */
2018
2019 static Bignum ecdh_calculate(const Bignum private,
2020                              const struct ec_point *public)
2021 {
2022     struct ec_point *p;
2023     Bignum ret;
2024     p = ecp_mul(public, private);
2025     if (!p) return NULL;
2026     ret = p->x;
2027     p->x = NULL;
2028     ec_point_free(p);
2029     return ret;
2030 }
2031
2032 void *ssh_ecdhkex_newkey(struct ec_curve *curve)
2033 {
2034     struct ec_key *key = snew(struct ec_key);
2035     struct ec_point *publicKey;
2036     key->publicKey.curve = curve;
2037     key->privateKey = bignum_random_in_range(One, key->publicKey.curve->n);
2038     if (!key->privateKey) {
2039         sfree(key);
2040         return NULL;
2041     }
2042     publicKey = ecp_mul(&key->publicKey.curve->G, key->privateKey);
2043     if (!publicKey) {
2044         freebn(key->privateKey);
2045         sfree(key);
2046         return NULL;
2047     }
2048     key->publicKey.x = publicKey->x;
2049     key->publicKey.y = publicKey->y;
2050     key->publicKey.z = NULL;
2051     sfree(publicKey);
2052     return key;
2053 }
2054
2055 char *ssh_ecdhkex_getpublic(void *key, int *len)
2056 {
2057     struct ec_key *ec = (struct ec_key*)key;
2058     char *point, *p;
2059     int i;
2060     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2061
2062     *len = 1 + pointlen * 2;
2063     point = (char*)snewn(*len, char);
2064
2065     p = point;
2066     *p++ = 0x04;
2067     for (i = pointlen; i--;)
2068         *p++ = bignum_byte(ec->publicKey.x, i);
2069     for (i = pointlen; i--;)
2070         *p++ = bignum_byte(ec->publicKey.y, i);
2071
2072     return point;
2073 }
2074
2075 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2076 {
2077     struct ec_key *ec = (struct ec_key*) key;
2078     struct ec_point remote;
2079
2080     remote.curve = ec->publicKey.curve;
2081     remote.infinity = 0;
2082     if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2083         return NULL;
2084     }
2085
2086     return ecdh_calculate(ec->privateKey, &remote);
2087 }
2088
2089 void ssh_ecdhkex_freekey(void *key)
2090 {
2091     ecdsa_freekey(key);
2092 }
2093
2094 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2095     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
2096 };
2097
2098 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2099     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
2100 };
2101
2102 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2103     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
2104 };
2105
2106 static const struct ssh_kex *const ec_kex_list[] = {
2107     &ssh_ec_kex_nistp256,
2108     &ssh_ec_kex_nistp384,
2109     &ssh_ec_kex_nistp521
2110 };
2111
2112 const struct ssh_kexes ssh_ecdh_kex = {
2113     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2114     ec_kex_list
2115 };