]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Elliptic-curve cryptography support.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  * NOTE: Only curves on prime field are handled by the maths functions
5  */
6
7 /*
8  * References:
9  *
10  * Elliptic curves in SSH are specified in RFC 5656:
11  *   http://tools.ietf.org/html/rfc5656
12  *
13  * That specification delegates details of public key formatting and a
14  * lot of underlying mechanism to SEC 1:
15  *   http://www.secg.org/sec1-v2.pdf
16  */
17
18 #include <stdlib.h>
19 #include <assert.h>
20
21 #include "ssh.h"
22
23 /* ----------------------------------------------------------------------
24  * Elliptic curve definitions
25  */
26
27 static int initialise_curve(struct ec_curve *curve, int bits, unsigned char *p,
28                             unsigned char *a, unsigned char *b,
29                             unsigned char *n, unsigned char *Gx,
30                             unsigned char *Gy)
31 {
32     int length = bits / 8;
33     if (bits % 8) ++length;
34
35     curve->fieldBits = bits;
36     curve->p = bignum_from_bytes(p, length);
37     if (!curve->p) goto error;
38
39     /* Curve co-efficients */
40     curve->a = bignum_from_bytes(a, length);
41     if (!curve->a) goto error;
42     curve->b = bignum_from_bytes(b, length);
43     if (!curve->b) goto error;
44
45     /* Group order and generator */
46     curve->n = bignum_from_bytes(n, length);
47     if (!curve->n) goto error;
48     curve->G.x = bignum_from_bytes(Gx, length);
49     if (!curve->G.x) goto error;
50     curve->G.y = bignum_from_bytes(Gy, length);
51     if (!curve->G.y) goto error;
52     curve->G.curve = curve;
53     curve->G.infinity = 0;
54
55     return 1;
56   error:
57     if (curve->p) freebn(curve->p);
58     if (curve->a) freebn(curve->a);
59     if (curve->b) freebn(curve->b);
60     if (curve->n) freebn(curve->n);
61     if (curve->G.x) freebn(curve->G.x);
62     return 0;
63 }
64
65 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
66 int nistp256_oid_len = 8;
67 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
68 int nistp384_oid_len = 5;
69 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
70 int nistp521_oid_len = 5;
71
72 struct ec_curve *ec_p256(void)
73 {
74     static struct ec_curve curve = { 0 };
75     static unsigned char initialised = 0;
76
77     if (!initialised)
78     {
79         unsigned char p[] = {
80             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
81             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
82             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
83             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
84         };
85         unsigned char a[] = {
86             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
87             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
88             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
89             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
90         };
91         unsigned char b[] = {
92             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
93             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
94             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
95             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
96         };
97         unsigned char n[] = {
98             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
99             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
100             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
101             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
102         };
103         unsigned char Gx[] = {
104             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
105             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
106             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
107             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
108         };
109         unsigned char Gy[] = {
110             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
111             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
112             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
113             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
114         };
115
116         if (!initialise_curve(&curve, 256, p, a, b, n, Gx, Gy)) {
117             return NULL;
118         }
119
120         /* Now initialised, no need to do it again */
121         initialised = 1;
122     }
123
124     return &curve;
125 }
126
127 struct ec_curve *ec_p384(void)
128 {
129     static struct ec_curve curve = { 0 };
130     static unsigned char initialised = 0;
131
132     if (!initialised)
133     {
134         unsigned char p[] = {
135             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
137             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
138             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
139             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
140             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
141         };
142         unsigned char a[] = {
143             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
144             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
145             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
147             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
148             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
149         };
150         unsigned char b[] = {
151             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
152             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
153             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
154             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
155             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
156             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
157         };
158         unsigned char n[] = {
159             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
160             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
161             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
162             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
163             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
164             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
165         };
166         unsigned char Gx[] = {
167             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
168             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
169             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
170             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
171             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
172             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
173         };
174         unsigned char Gy[] = {
175             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
176             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
177             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
178             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
179             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
180             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
181         };
182
183         if (!initialise_curve(&curve, 384, p, a, b, n, Gx, Gy)) {
184             return NULL;
185         }
186
187         /* Now initialised, no need to do it again */
188         initialised = 1;
189     }
190
191     return &curve;
192 }
193
194 struct ec_curve *ec_p521(void)
195 {
196     static struct ec_curve curve = { 0 };
197     static unsigned char initialised = 0;
198
199     if (!initialised)
200     {
201         unsigned char p[] = {
202             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
203             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
204             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
209             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
210             0xff, 0xff
211         };
212         unsigned char a[] = {
213             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
214             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
215             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
216             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
217             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
218             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
219             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
220             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
221             0xff, 0xfc
222         };
223         unsigned char b[] = {
224             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
225             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
226             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
227             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
228             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
229             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
230             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
231             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
232             0x3f, 0x00
233         };
234         unsigned char n[] = {
235             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
236             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
237             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
238             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
239             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
240             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
241             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
242             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
243             0x64, 0x09
244         };
245         unsigned char Gx[] = {
246             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
247             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
248             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
249             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
250             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
251             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
252             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
253             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
254             0xbd, 0x66
255         };
256         unsigned char Gy[] = {
257             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
258             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
259             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
260             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
261             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
262             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
263             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
264             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
265             0x66, 0x50
266         };
267
268         if (!initialise_curve(&curve, 521, p, a, b, n, Gx, Gy)) {
269             return NULL;
270         }
271
272         /* Now initialised, no need to do it again */
273         initialised = 1;
274     }
275
276     return &curve;
277 }
278
279 static struct ec_curve *ec_name_to_curve(char *name, int len) {
280     if (len == 8 && !memcmp(name, "nistp", 5)) {
281         name += 5;
282         if (!memcmp(name, "256", 3)) {
283             return ec_p256();
284         } else if (!memcmp(name, "384", 3)) {
285             return ec_p384();
286         } else if (!memcmp(name, "521", 3)) {
287             return ec_p521();
288         }
289     }
290
291     return NULL;
292 }
293
294 static int ec_curve_to_name(const struct ec_curve *curve, unsigned char *name, int len) {
295     /* Return length of string */
296     if (name == NULL) return 8;
297
298     /* Not enough space for the name */
299     if (len < 8) return 0;
300
301     /* Put the name in the buffer */
302     switch (curve->fieldBits) {
303       case 256:
304         memcpy(name+5, "256", 3);
305         break;
306       case 384:
307         memcpy(name+5, "384", 3);
308         break;
309       case 521:
310         memcpy(name+5, "521", 3);
311         break;
312       default:
313         return 0;
314     }
315
316     memcpy(name, "nistp", 5);
317     return 8;
318 }
319
320 /* Return 1 if a is -3 % p, otherwise return 0
321  * This is used because there are some maths optimisations */
322 static int ec_aminus3(const struct ec_curve *curve)
323 {
324     int ret;
325     Bignum _p;
326
327     _p = bignum_add_long(curve->a, 3);
328     if (!_p) return 0;
329
330     ret = !bignum_cmp(curve->p, _p);
331     freebn(_p);
332     return ret;
333 }
334
335 /* ----------------------------------------------------------------------
336  * Elliptic curve field maths
337  */
338
339 static Bignum ecf_add(const Bignum a, const Bignum b,
340                       const struct ec_curve *curve)
341 {
342     Bignum a1, b1, ab, ret;
343
344     a1 = bigmod(a, curve->p);
345     if (!a1) return NULL;
346     b1 = bigmod(b, curve->p);
347     if (!b1)
348     {
349         freebn(a1);
350         return NULL;
351     }
352
353     ab = bigadd(a1, b1);
354     freebn(a1);
355     freebn(b1);
356     if (!ab) return NULL;
357
358     ret = bigmod(ab, curve->p);
359     freebn(ab);
360
361     return ret;
362 }
363
364 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
365 {
366     return modmul(a, a, curve->p);
367 }
368
369 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
370 {
371     Bignum ret, tmp;
372
373     /* Double */
374     tmp = bignum_lshift(a, 1);
375     if (!tmp) return NULL;
376
377     /* Add itself (i.e. treble) */
378     ret = bigadd(tmp, a);
379     freebn(tmp);
380
381     /* Normalise */
382     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
383     {
384         tmp = bigsub(ret, curve->p);
385         freebn(ret);
386         ret = tmp;
387     }
388
389     return ret;
390 }
391
392 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
393 {
394     Bignum ret = bignum_lshift(a, 1);
395     if (!ret) return NULL;
396     if (bignum_cmp(ret, curve->p) >= 0)
397     {
398         Bignum tmp = bigsub(ret, curve->p);
399         freebn(ret);
400         return tmp;
401     }
402     else
403     {
404         return ret;
405     }
406 }
407
408 /* ----------------------------------------------------------------------
409  * Memory functions
410  */
411
412 void ec_point_free(struct ec_point *point)
413 {
414     if (point == NULL) return;
415     point->curve = 0;
416     if (point->x) freebn(point->x);
417     if (point->y) freebn(point->y);
418     if (point->z) freebn(point->z);
419     point->infinity = 0;
420     sfree(point);
421 }
422
423 static struct ec_point *ec_point_new(const struct ec_curve *curve,
424                                      const Bignum x, const Bignum y, const Bignum z,
425                                      unsigned char infinity)
426 {
427     struct ec_point *point = snewn(1, struct ec_point);
428     point->curve = curve;
429     point->x = x;
430     point->y = y;
431     point->z = z;
432     point->infinity = infinity ? 1 : 0;
433     return point;
434 }
435
436 static struct ec_point *ec_point_copy(const struct ec_point *a)
437 {
438     if (a == NULL) return NULL;
439     return ec_point_new(a->curve,
440                         a->x ? copybn(a->x) : NULL,
441                         a->y ? copybn(a->y) : NULL,
442                         a->z ? copybn(a->z) : NULL,
443                         a->infinity);
444 }
445
446 static int ec_point_verify(const struct ec_point *a)
447 {
448     if (a->infinity)
449     {
450         return 1;
451     }
452     else
453     {
454         /* Verify y^2 = x^3 + ax + b */
455         int ret = 0;
456
457         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
458
459         Bignum Three = bignum_from_long(3);
460         if (!Three) return 0;
461
462         lhs = modmul(a->y, a->y, a->curve->p);
463         if (!lhs) goto error;
464
465         /* This uses montgomery multiplication to optimise */
466         x3 = modpow(a->x, Three, a->curve->p);
467         freebn(Three);
468         if (!x3) goto error;
469         ax = modmul(a->curve->a, a->x, a->curve->p);
470         if (!ax) goto error;
471         x3ax = bigadd(x3, ax);
472         if (!x3ax) goto error;
473         freebn(x3); x3 = NULL;
474         freebn(ax); ax = NULL;
475         x3axm = bigmod(x3ax, a->curve->p);
476         if (!x3axm) goto error;
477         freebn(x3ax); x3ax = NULL;
478         x3axb = bigadd(x3axm, a->curve->b);
479         if (!x3axb) goto error;
480         freebn(x3axm); x3axm = NULL;
481         rhs = bigmod(x3axb, a->curve->p);
482         if (!rhs) goto error;
483         freebn(x3axb);
484
485         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
486         freebn(lhs);
487         freebn(rhs);
488
489         return ret;
490
491       error:
492         if (x3) freebn(x3);
493         if (ax) freebn(ax);
494         if (x3ax) freebn(x3ax);
495         if (x3axm) freebn(x3axm);
496         if (x3axb) freebn(x3axb);
497         if (lhs) freebn(lhs);
498         return 0;
499     }
500 }
501
502 /* ----------------------------------------------------------------------
503  * Elliptic curve point maths
504  */
505
506 /* Returns 1 on success and 0 on memory error */
507 static int ecp_normalise(struct ec_point *a)
508 {
509     Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
510
511     /* In Jacobian Coordinates the triple (X, Y, Z) represents
512        the affine point (X / Z^2, Y / Z^3) */
513     if (!a) {
514         // No point
515         return 0;
516     }
517     if (a->infinity) {
518         // Point is at infinity - i.e. normalised
519         return 1;
520     } else if (!a->x || !a->y) {
521         // No point defined
522         return 0;
523     } else if (!a->z) {
524         // Already normalised
525         return 1;
526     }
527
528     Z2 = ecf_square(a->z, a->curve);
529     if (!Z2) {
530         return 0;
531     }
532     Z2inv = modinv(Z2, a->curve->p);
533     if (!Z2inv) {
534         freebn(Z2);
535         return 0;
536     }
537     tx = modmul(a->x, Z2inv, a->curve->p);
538     freebn(Z2inv);
539     if (!tx) {
540         freebn(Z2);
541         return 0;
542     }
543
544     Z3 = modmul(Z2, a->z, a->curve->p);
545     freebn(Z2);
546     if (!Z3) {
547         freebn(tx);
548         return 0;
549     }
550     Z3inv = modinv(Z3, a->curve->p);
551     freebn(Z3);
552     if (!Z3inv) {
553         freebn(tx);
554         return 0;
555     }
556     ty = modmul(a->y, Z3inv, a->curve->p);
557     freebn(Z3inv);
558     if (!ty) {
559         freebn(tx);
560         return 0;
561     }
562
563     freebn(a->x);
564     a->x = tx;
565     freebn(a->y);
566     a->y = ty;
567     freebn(a->z);
568     a->z = NULL;
569     return 1;
570 }
571
572 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
573 {
574     Bignum S, M, outx, outy, outz;
575
576     if (a->infinity || bignum_cmp(a->y, Zero) == 0)
577     {
578         /* Identity */
579         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
580     }
581
582     /* S = 4*X*Y^2 */
583     {
584         Bignum Y2, XY2, _2XY2;
585
586         Y2 = ecf_square(a->y, a->curve);
587         if (!Y2) {
588             return NULL;
589         }
590         XY2 = modmul(a->x, Y2, a->curve->p);
591         freebn(Y2);
592         if (!XY2) {
593             return NULL;
594         }
595
596         _2XY2 = ecf_double(XY2, a->curve);
597         freebn(XY2);
598         if (!_2XY2) {
599             return NULL;
600         }
601         S = ecf_double(_2XY2, a->curve);
602         freebn(_2XY2);
603         if (!S) {
604             return NULL;
605         }
606     }
607
608     /* Faster calculation if a = -3 */
609     if (aminus3) {
610         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
611         Bignum Z2, XpZ2, XmZ2, second;
612
613         if (a->z == NULL) {
614             Z2 = copybn(One);
615         } else {
616             Z2 = ecf_square(a->z, a->curve);
617         }
618         if (!Z2) {
619             freebn(S);
620             return NULL;
621         }
622
623         XpZ2 = ecf_add(a->x, Z2, a->curve);
624         if (!XpZ2) {
625             freebn(S);
626             freebn(Z2);
627             return NULL;
628         }
629         XmZ2 = modsub(a->x, Z2, a->curve->p);
630         freebn(Z2);
631         if (!XpZ2) {
632             freebn(S);
633             freebn(XpZ2);
634             return NULL;
635         }
636
637         second = modmul(XpZ2, XmZ2, a->curve->p);
638         freebn(XpZ2);
639         freebn(XmZ2);
640         if (!second) {
641             freebn(S);
642             return NULL;
643         }
644
645         M = ecf_treble(second, a->curve);
646         freebn(second);
647         if (!M) {
648             freebn(S);
649             return NULL;
650         }
651     } else {
652         /* M = 3*X^2 + a*Z^4 */
653         Bignum _3X2, X2, aZ4;
654
655         if (a->z == NULL) {
656             aZ4 = copybn(a->curve->a);
657         } else {
658             Bignum Z2, Z4;
659
660             Z2 = ecf_square(a->z, a->curve);
661             if (!Z2) {
662                 freebn(S);
663                 return NULL;
664             }
665             Z4 = ecf_square(Z2, a->curve);
666             freebn(Z2);
667             if (!Z4) {
668                 freebn(S);
669                 return NULL;
670             }
671             aZ4 = modmul(a->curve->a, Z4, a->curve->p);
672             freebn(Z4);
673         }
674         if (!aZ4) {
675             freebn(S);
676             return NULL;
677         }
678
679         X2 = modmul(a->x, a->x, a->curve->p);
680         if (!X2) {
681             freebn(S);
682             freebn(aZ4);
683             return NULL;
684         }
685         _3X2 = ecf_treble(X2, a->curve);
686         freebn(X2);
687         if (!_3X2) {
688             freebn(S);
689             freebn(aZ4);
690             return NULL;
691         }
692         M = ecf_add(_3X2, aZ4, a->curve);
693         freebn(_3X2);
694         freebn(aZ4);
695         if (!M) {
696             freebn(S);
697             return NULL;
698         }
699     }
700
701     /* X' = M^2 - 2*S */
702     {
703         Bignum M2, _2S;
704
705         M2 = ecf_square(M, a->curve);
706         if (!M2) {
707             freebn(S);
708             freebn(M);
709             return NULL;
710         }
711
712         _2S = ecf_double(S, a->curve);
713         if (!_2S) {
714             freebn(M2);
715             freebn(S);
716             freebn(M);
717             return NULL;
718         }
719
720         outx = modsub(M2, _2S, a->curve->p);
721         freebn(M2);
722         freebn(_2S);
723         if (!outx) {
724             freebn(S);
725             freebn(M);
726             return NULL;
727         }
728     }
729
730     /* Y' = M*(S - X') - 8*Y^4 */
731     {
732         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
733
734         SX = modsub(S, outx, a->curve->p);
735         freebn(S);
736         if (!SX) {
737             freebn(M);
738             freebn(outx);
739             return NULL;
740         }
741         MSX = modmul(M, SX, a->curve->p);
742         freebn(SX);
743         freebn(M);
744         if (!MSX) {
745             freebn(outx);
746             return NULL;
747         }
748         Y2 = ecf_square(a->y, a->curve);
749         if (!Y2) {
750             freebn(outx);
751             freebn(MSX);
752             return NULL;
753         }
754         Y4 = ecf_square(Y2, a->curve);
755         freebn(Y2);
756         if (!Y4) {
757             freebn(outx);
758             freebn(MSX);
759             return NULL;
760         }
761         Eight = bignum_from_long(8);
762         if (!Eight) {
763             freebn(outx);
764             freebn(MSX);
765             freebn(Y4);
766             return NULL;
767         }
768         _8Y4 = modmul(Eight, Y4, a->curve->p);
769         freebn(Eight);
770         freebn(Y4);
771         if (!_8Y4) {
772             freebn(outx);
773             freebn(MSX);
774             return NULL;
775         }
776         outy = modsub(MSX, _8Y4, a->curve->p);
777         freebn(MSX);
778         freebn(_8Y4);
779         if (!outy) {
780             freebn(outx);
781             return NULL;
782         }
783     }
784
785     /* Z' = 2*Y*Z */
786     {
787         Bignum YZ;
788
789         if (a->z == NULL) {
790             YZ = copybn(a->y);
791         } else {
792             YZ = modmul(a->y, a->z, a->curve->p);
793         }
794         if (!YZ) {
795             freebn(outx);
796             freebn(outy);
797             return NULL;
798         }
799
800         outz = ecf_double(YZ, a->curve);
801         freebn(YZ);
802         if (!outz) {
803             freebn(outx);
804             freebn(outy);
805             return NULL;
806         }
807     }
808
809     return ec_point_new(a->curve, outx, outy, outz, 0);
810 }
811
812 static struct ec_point *ecp_add(const struct ec_point *a,
813                                 const struct ec_point *b,
814                                 const int aminus3)
815 {
816     Bignum U1, U2, S1, S2, outx, outy, outz;
817
818     /* Check if multiplying by infinity */
819     if (a->infinity) return ec_point_copy(b);
820     if (b->infinity) return ec_point_copy(a);
821
822     /* U1 = X1*Z2^2 */
823     /* S1 = Y1*Z2^3 */
824     if (b->z) {
825         Bignum Z2, Z3;
826
827         Z2 = ecf_square(b->z, a->curve);
828         if (!Z2) {
829             return NULL;
830         }
831         U1 = modmul(a->x, Z2, a->curve->p);
832         if (!U1) {
833             freebn(Z2);
834             return NULL;
835         }
836         Z3 = modmul(Z2, b->z, a->curve->p);
837         freebn(Z2);
838         if (!Z3) {
839             freebn(U1);
840             return NULL;
841         }
842         S1 = modmul(a->y, Z3, a->curve->p);
843         freebn(Z3);
844         if (!S1) {
845             freebn(U1);
846             return NULL;
847         }
848     } else {
849         U1 = copybn(a->x);
850         if (!U1) {
851             return NULL;
852         }
853         S1 = copybn(a->y);
854         if (!S1) {
855             freebn(U1);
856             return NULL;
857         }
858     }
859
860     /* U2 = X2*Z1^2 */
861     /* S2 = Y2*Z1^3 */
862     if (a->z) {
863         Bignum Z2, Z3;
864
865         Z2 = ecf_square(a->z, b->curve);
866         if (!Z2) {
867             freebn(U1);
868             freebn(S1);
869             return NULL;
870         }
871         U2 = modmul(b->x, Z2, b->curve->p);
872         if (!U2) {
873             freebn(U1);
874             freebn(S1);
875             freebn(Z2);
876             return NULL;
877         }
878         Z3 = modmul(Z2, a->z, b->curve->p);
879         freebn(Z2);
880         if (!Z3) {
881             freebn(U1);
882             freebn(S1);
883             freebn(U2);
884             return NULL;
885         }
886         S2 = modmul(b->y, Z3, b->curve->p);
887         freebn(Z3);
888         if (!S2) {
889             freebn(U1);
890             freebn(S1);
891             freebn(U2);
892             return NULL;
893         }
894     } else {
895         U2 = copybn(b->x);
896         if (!U2) {
897             freebn(U1);
898             freebn(S1);
899             return NULL;
900         }
901         S2 = copybn(b->y);
902         if (!S2) {
903             freebn(U1);
904             freebn(S1);
905             freebn(U2);
906             return NULL;
907         }
908     }
909
910     /* Check if multiplying by self */
911     if (bignum_cmp(U1, U2) == 0)
912     {
913         freebn(U1);
914         freebn(U2);
915         if (bignum_cmp(S1, S2) == 0)
916         {
917             freebn(S1);
918             freebn(S2);
919             return ecp_double(a, aminus3);
920         }
921         else
922         {
923             freebn(S1);
924             freebn(S2);
925             /* Infinity */
926             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
927         }
928     }
929
930     {
931         Bignum H, R, UH2, H3;
932
933         /* H = U2 - U1 */
934         H = modsub(U2, U1, a->curve->p);
935         freebn(U2);
936         if (!H) {
937             freebn(U1);
938             freebn(S1);
939             freebn(S2);
940             return NULL;
941         }
942
943         /* R = S2 - S1 */
944         R = modsub(S2, S1, a->curve->p);
945         freebn(S2);
946         if (!R) {
947             freebn(H);
948             freebn(S1);
949             freebn(U1);
950             return NULL;
951         }
952
953         /* X3 = R^2 - H^3 - 2*U1*H^2 */
954         {
955             Bignum R2, H2, _2UH2, first;
956
957             H2 = ecf_square(H, a->curve);
958             if (!H2) {
959                 freebn(U1);
960                 freebn(S1);
961                 freebn(H);
962                 freebn(R);
963                 return NULL;
964             }
965             UH2 = modmul(U1, H2, a->curve->p);
966             freebn(U1);
967             if (!UH2) {
968                 freebn(H2);
969                 freebn(S1);
970                 freebn(H);
971                 freebn(R);
972                 return NULL;
973             }
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             if (!H3) {
977                 freebn(UH2);
978                 freebn(S1);
979                 freebn(H);
980                 freebn(R);
981                 return NULL;
982             }
983             R2 = ecf_square(R, a->curve);
984             if (!R2) {
985                 freebn(H3);
986                 freebn(UH2);
987                 freebn(S1);
988                 freebn(H);
989                 freebn(R);
990                 return NULL;
991             }
992             _2UH2 = ecf_double(UH2, a->curve);
993             if (!_2UH2) {
994                 freebn(R2);
995                 freebn(H3);
996                 freebn(UH2);
997                 freebn(S1);
998                 freebn(H);
999                 freebn(R);
1000                 return NULL;
1001             }
1002             first = modsub(R2, H3, a->curve->p);
1003             freebn(R2);
1004             if (!first) {
1005                 freebn(H3);
1006                 freebn(_2UH2);
1007                 freebn(UH2);
1008                 freebn(S1);
1009                 freebn(H);
1010                 freebn(R);
1011                 return NULL;
1012             }
1013             outx = modsub(first, _2UH2, a->curve->p);
1014             freebn(first);
1015             freebn(_2UH2);
1016             if (!outx) {
1017                 freebn(H3);
1018                 freebn(UH2);
1019                 freebn(S1);
1020                 freebn(H);
1021                 freebn(R);
1022                 return NULL;
1023             }
1024         }
1025
1026         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1027         {
1028             Bignum RUH2mX, UH2mX, SH3;
1029
1030             UH2mX = modsub(UH2, outx, a->curve->p);
1031             freebn(UH2);
1032             if (!UH2mX) {
1033                 freebn(outx);
1034                 freebn(H3);
1035                 freebn(S1);
1036                 freebn(H);
1037                 freebn(R);
1038                 return NULL;
1039             }
1040             RUH2mX = modmul(R, UH2mX, a->curve->p);
1041             freebn(UH2mX);
1042             freebn(R);
1043             if (!RUH2mX) {
1044                 freebn(outx);
1045                 freebn(H3);
1046                 freebn(S1);
1047                 freebn(H);
1048                 return NULL;
1049             }
1050             SH3 = modmul(S1, H3, a->curve->p);
1051             freebn(S1);
1052             freebn(H3);
1053             if (!SH3) {
1054                 freebn(RUH2mX);
1055                 freebn(outx);
1056                 freebn(H);
1057                 return NULL;
1058             }
1059
1060             outy = modsub(RUH2mX, SH3, a->curve->p);
1061             freebn(RUH2mX);
1062             freebn(SH3);
1063             if (!outy) {
1064                 freebn(outx);
1065                 freebn(H);
1066                 return NULL;
1067             }
1068         }
1069
1070         /* Z3 = H*Z1*Z2 */
1071         if (a->z && b->z) {
1072             Bignum ZZ;
1073
1074             ZZ = modmul(a->z, b->z, a->curve->p);
1075             if (!ZZ) {
1076                 freebn(outx);
1077                 freebn(outy);
1078                 freebn(H);
1079                 return NULL;
1080             }
1081             outz = modmul(H, ZZ, a->curve->p);
1082             freebn(H);
1083             freebn(ZZ);
1084             if (!outz) {
1085                 freebn(outx);
1086                 freebn(outy);
1087                 return NULL;
1088             }
1089         } else if (a->z) {
1090             outz = modmul(H, a->z, a->curve->p);
1091             freebn(H);
1092             if (!outz) {
1093                 freebn(outx);
1094                 freebn(outy);
1095                 return NULL;
1096             }
1097         } else if (b->z) {
1098             outz = modmul(H, b->z, a->curve->p);
1099             freebn(H);
1100             if (!outz) {
1101                 freebn(outx);
1102                 freebn(outy);
1103                 return NULL;
1104             }
1105         } else {
1106             outz = H;
1107         }
1108     }
1109
1110     return ec_point_new(a->curve, outx, outy, outz, 0);
1111 }
1112
1113 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1114 {
1115     struct ec_point *A, *ret;
1116     int bits, i;
1117
1118     A = ec_point_copy(a);
1119     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1120
1121     bits = bignum_bitcount(b);
1122     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1123     {
1124         if (bignum_bit(b, i))
1125         {
1126             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1127             ec_point_free(ret);
1128             ret = tmp;
1129         }
1130         if (i+1 != bits)
1131         {
1132             struct ec_point *tmp = ecp_double(A, aminus3);
1133             ec_point_free(A);
1134             A = tmp;
1135         }
1136     }
1137
1138     if (!A) {
1139         ec_point_free(ret);
1140         ret = NULL;
1141     } else {
1142         ec_point_free(A);
1143     }
1144
1145     return ret;
1146 }
1147
1148 /* Not static because it is used by sshecdsag.c to generate a new key */
1149 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1150 {
1151     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1152
1153     if (!ecp_normalise(ret)) {
1154         ec_point_free(ret);
1155         return NULL;
1156     }
1157
1158     return ret;
1159 }
1160
1161 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1162                                    const struct ec_point *point)
1163 {
1164     struct ec_point *aG, *bP, *ret;
1165     int aminus3 = ec_aminus3(point->curve);
1166
1167     aG = ecp_mul_(&point->curve->G, a, aminus3);
1168     if (!aG) return NULL;
1169     bP = ecp_mul_(point, b, aminus3);
1170     if (!bP) {
1171         ec_point_free(aG);
1172         return NULL;
1173     }
1174
1175     ret = ecp_add(aG, bP, aminus3);
1176
1177     ec_point_free(aG);
1178     ec_point_free(bP);
1179
1180     if (!ecp_normalise(ret)) {
1181         ec_point_free(ret);
1182         return NULL;
1183     }
1184
1185     return ret;
1186 }
1187
1188 /* ----------------------------------------------------------------------
1189  * Basic sign and verify routines
1190  */
1191
1192 static int _ecdsa_verify(const struct ec_point *publicKey,
1193                          const unsigned char *data, const int dataLen,
1194                          const Bignum r, const Bignum s)
1195 {
1196     int z_bits, n_bits;
1197     Bignum z;
1198     int valid = 0;
1199
1200     /* Sanity checks */
1201     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->n) >= 0
1202         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->n) >= 0)
1203     {
1204         return 0;
1205     }
1206
1207     /* z = left most bitlen(curve->n) of data */
1208     z = bignum_from_bytes(data, dataLen);
1209     if (!z) return 0;
1210     n_bits = bignum_bitcount(publicKey->curve->n);
1211     z_bits = bignum_bitcount(z);
1212     if (z_bits > n_bits)
1213     {
1214         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1215         freebn(z);
1216         z = tmp;
1217         if (!z) return 0;
1218     }
1219
1220     /* Ensure z in range of n */
1221     {
1222         Bignum tmp = bigmod(z, publicKey->curve->n);
1223         freebn(z);
1224         z = tmp;
1225         if (!z) return 0;
1226     }
1227
1228     /* Calculate signature */
1229     {
1230         Bignum w, x, u1, u2;
1231         struct ec_point *tmp;
1232
1233         w = modinv(s, publicKey->curve->n);
1234         if (!w) {
1235             freebn(z);
1236             return 0;
1237         }
1238         u1 = modmul(z, w, publicKey->curve->n);
1239         if (!u1) {
1240             freebn(z);
1241             freebn(w);
1242             return 0;
1243         }
1244         u2 = modmul(r, w, publicKey->curve->n);
1245         freebn(w);
1246         if (!u2) {
1247             freebn(z);
1248             freebn(u1);
1249             return 0;
1250         }
1251
1252         tmp = ecp_summul(u1, u2, publicKey);
1253         freebn(u1);
1254         freebn(u2);
1255         if (!tmp) {
1256             freebn(z);
1257             return 0;
1258         }
1259
1260         x = bigmod(tmp->x, publicKey->curve->n);
1261         ec_point_free(tmp);
1262         if (!x) {
1263             freebn(z);
1264             return 0;
1265         }
1266
1267         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1268         freebn(x);
1269     }
1270
1271     freebn(z);
1272
1273     return valid;
1274 }
1275
1276 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1277                         const unsigned char *data, const int dataLen,
1278                         Bignum *r, Bignum *s)
1279 {
1280     unsigned char digest[20];
1281     int z_bits, n_bits;
1282     Bignum z, k;
1283     struct ec_point *kG;
1284
1285     *r = NULL;
1286     *s = NULL;
1287
1288     /* z = left most bitlen(curve->n) of data */
1289     z = bignum_from_bytes(data, dataLen);
1290     if (!z) return;
1291     n_bits = bignum_bitcount(curve->n);
1292     z_bits = bignum_bitcount(z);
1293     if (z_bits > n_bits)
1294     {
1295         Bignum tmp;
1296         tmp = bignum_rshift(z, z_bits - n_bits);
1297         freebn(z);
1298         z = tmp;
1299         if (!z) return;
1300     }
1301
1302     /* Generate k between 1 and curve->n, using the same deterministic
1303      * k generation system we use for conventional DSA. */
1304     SHA_Simple(data, dataLen, digest);
1305     k = dss_gen_k("ECDSA deterministic k generator", curve->n, privateKey,
1306                   digest, sizeof(digest));
1307     if (!k) return;
1308
1309     kG = ecp_mul(&curve->G, k);
1310     if (!kG) {
1311         freebn(z);
1312         freebn(k);
1313         return;
1314     }
1315
1316     /* r = kG.x mod n */
1317     *r = bigmod(kG->x, curve->n);
1318     ec_point_free(kG);
1319     if (!*r) {
1320         freebn(z);
1321         freebn(k);
1322         return;
1323     }
1324
1325     /* s = (z + r * priv)/k mod n */
1326     {
1327         Bignum rPriv, zMod, first, firstMod, kInv;
1328         rPriv = modmul(*r, privateKey, curve->n);
1329         if (!rPriv) {
1330             freebn(*r);
1331             freebn(z);
1332             freebn(k);
1333             return;
1334         }
1335         zMod = bigmod(z, curve->n);
1336         freebn(z);
1337         if (!zMod) {
1338             freebn(rPriv);
1339             freebn(*r);
1340             freebn(k);
1341             return;
1342         }
1343         first = bigadd(rPriv, zMod);
1344         freebn(rPriv);
1345         freebn(zMod);
1346         if (!first) {
1347             freebn(*r);
1348             freebn(k);
1349             return;
1350         }
1351         firstMod = bigmod(first, curve->n);
1352         freebn(first);
1353         if (!firstMod) {
1354             freebn(*r);
1355             freebn(k);
1356             return;
1357         }
1358         kInv = modinv(k, curve->n);
1359         freebn(k);
1360         if (!kInv) {
1361             freebn(firstMod);
1362             freebn(*r);
1363             return;
1364         }
1365         *s = modmul(firstMod, kInv, curve->n);
1366         freebn(firstMod);
1367         freebn(kInv);
1368         if (!*s) {
1369             freebn(*r);
1370             return;
1371         }
1372     }
1373 }
1374
1375 /* ----------------------------------------------------------------------
1376  * Misc functions
1377  */
1378
1379 static void getstring(char **data, int *datalen, char **p, int *length)
1380 {
1381     *p = NULL;
1382     if (*datalen < 4)
1383         return;
1384     *length = toint(GET_32BIT(*data));
1385     if (*length < 0)
1386         return;
1387     *datalen -= 4;
1388     *data += 4;
1389     if (*datalen < *length)
1390         return;
1391     *p = *data;
1392     *data += *length;
1393     *datalen -= *length;
1394 }
1395
1396 static Bignum getmp(char **data, int *datalen)
1397 {
1398     char *p;
1399     int length;
1400
1401     getstring(data, datalen, &p, &length);
1402     if (!p)
1403         return NULL;
1404     if (p[0] & 0x80)
1405         return NULL;                   /* negative mp */
1406     return bignum_from_bytes((unsigned char *)p, length);
1407 }
1408
1409 static int decodepoint(char *p, int length, struct ec_point *point)
1410 {
1411     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1412         return 0;
1413     /* Skip compression flag */
1414     ++p;
1415     --length;
1416     /* The two values must be equal length */
1417     if (length % 2 != 0) {
1418         point->x = NULL;
1419         point->y = NULL;
1420         point->z = NULL;
1421         return 0;
1422     }
1423     length = length / 2;
1424     point->x = bignum_from_bytes((unsigned char *)p, length);
1425     if (!point->x) return 0;
1426     p += length;
1427     point->y = bignum_from_bytes((unsigned char *)p, length);
1428     if (!point->y) {
1429         freebn(point->x);
1430         point->x = NULL;
1431         return 0;
1432     }
1433     point->z = NULL;
1434
1435     /* Verify the point is on the curve */
1436     if (!ec_point_verify(point)) {
1437         ec_point_free(point);
1438         return 0;
1439     }
1440
1441     return 1;
1442 }
1443
1444 static int getmppoint(char **data, int *datalen, struct ec_point *point)
1445 {
1446     char *p;
1447     int length;
1448
1449     getstring(data, datalen, &p, &length);
1450     if (!p) return 0;
1451     return decodepoint(p, length, point);
1452 }
1453
1454 /* ----------------------------------------------------------------------
1455  * Exposed ECDSA interface
1456  */
1457
1458 static void ecdsa_freekey(void *key)
1459 {
1460     struct ec_key *ec = (struct ec_key *) key;
1461     if (!ec) return;
1462
1463     if (ec->publicKey.x)
1464         freebn(ec->publicKey.x);
1465     if (ec->publicKey.y)
1466         freebn(ec->publicKey.y);
1467     if (ec->publicKey.z)
1468         freebn(ec->publicKey.z);
1469     if (ec->privateKey)
1470         freebn(ec->privateKey);
1471     sfree(ec);
1472 }
1473
1474 static void *ecdsa_newkey(char *data, int len)
1475 {
1476     char *p;
1477     int slen;
1478     struct ec_key *ec;
1479     struct ec_curve *curve;
1480
1481     getstring(&data, &len, &p, &slen);
1482
1483     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1484         return NULL;
1485     }
1486     curve = ec_name_to_curve(p+11, slen-11);
1487     if (!curve) return NULL;
1488
1489     getstring(&data, &len, &p, &slen);
1490
1491     if (curve != ec_name_to_curve(p, slen)) return NULL;
1492
1493     ec = snew(struct ec_key);
1494
1495     ec->publicKey.curve = curve;
1496     ec->publicKey.infinity = 0;
1497     ec->publicKey.x = NULL;
1498     ec->publicKey.y = NULL;
1499     ec->publicKey.z = NULL;
1500     if (!getmppoint(&data, &len, &ec->publicKey)) {
1501         ecdsa_freekey(ec);
1502         return NULL;
1503     }
1504     ec->privateKey = NULL;
1505
1506     if (!ec->publicKey.x || !ec->publicKey.y ||
1507         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1508         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1509     {
1510         ecdsa_freekey(ec);
1511         ec = NULL;
1512     }
1513
1514     return ec;
1515 }
1516
1517 static char *ecdsa_fmtkey(void *key)
1518 {
1519     struct ec_key *ec = (struct ec_key *) key;
1520     char *p;
1521     int len, i, pos, nibbles;
1522     static const char hex[] = "0123456789abcdef";
1523     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1524         return NULL;
1525
1526     pos = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1527     if (pos == 0) return NULL;
1528
1529     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1530     len += pos; /* Curve name */
1531     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1532     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1533     p = snewn(len, char);
1534
1535     pos = ec_curve_to_name(ec->publicKey.curve, (unsigned char*)p, pos);
1536     pos += sprintf(p + pos, ",0x");
1537     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1538     if (nibbles < 1)
1539         nibbles = 1;
1540     for (i = nibbles; i--;) {
1541         p[pos++] =
1542             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1543     }
1544     pos += sprintf(p + pos, ",0x");
1545     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1546     if (nibbles < 1)
1547         nibbles = 1;
1548     for (i = nibbles; i--;) {
1549         p[pos++] =
1550             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1551     }
1552     p[pos] = '\0';
1553     return p;
1554 }
1555
1556 static unsigned char *ecdsa_public_blob(void *key, int *len)
1557 {
1558     struct ec_key *ec = (struct ec_key *) key;
1559     int pointlen, bloblen, namelen;
1560     int i;
1561     unsigned char *blob, *p;
1562
1563     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1564     if (namelen == 0) return NULL;
1565
1566     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1567
1568     /*
1569      * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1570      */
1571     bloblen = 4 + 11 + namelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1572     blob = snewn(bloblen, unsigned char);
1573
1574     p = blob;
1575     PUT_32BIT(p, 11 + namelen);
1576     p += 4;
1577     memcpy(p, "ecdsa-sha2-", 11);
1578     p += 11;
1579     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1580     PUT_32BIT(p, namelen);
1581     p += 4;
1582     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1583     PUT_32BIT(p, (2 * pointlen) + 1);
1584     p += 4;
1585     *p++ = 0x04;
1586     for (i = pointlen; i--;)
1587         *p++ = bignum_byte(ec->publicKey.x, i);
1588     for (i = pointlen; i--;)
1589         *p++ = bignum_byte(ec->publicKey.y, i);
1590
1591     assert(p == blob + bloblen);
1592     *len = bloblen;
1593
1594     return blob;
1595 }
1596
1597 static unsigned char *ecdsa_private_blob(void *key, int *len)
1598 {
1599     struct ec_key *ec = (struct ec_key *) key;
1600     int keylen, bloblen;
1601     int i;
1602     unsigned char *blob, *p;
1603
1604     if (!ec->privateKey) return NULL;
1605
1606     keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1607
1608     /*
1609      * mpint privateKey. Total 4 + keylen.
1610      */
1611     bloblen = 4 + keylen;
1612     blob = snewn(bloblen, unsigned char);
1613
1614     p = blob;
1615     PUT_32BIT(p, keylen);
1616     p += 4;
1617     for (i = keylen; i--;)
1618         *p++ = bignum_byte(ec->privateKey, i);
1619
1620     assert(p == blob + bloblen);
1621     *len = bloblen;
1622     return blob;
1623 }
1624
1625 static void *ecdsa_createkey(unsigned char *pub_blob, int pub_len,
1626                              unsigned char *priv_blob, int priv_len)
1627 {
1628     struct ec_key *ec;
1629     struct ec_point *publicKey;
1630     char *pb = (char *) priv_blob;
1631
1632     ec = (struct ec_key*)ecdsa_newkey((char *) pub_blob, pub_len);
1633     if (!ec) {
1634         return NULL;
1635     }
1636
1637     ec->privateKey = getmp(&pb, &priv_len);
1638     if (!ec->privateKey) {
1639         ecdsa_freekey(ec);
1640         return NULL;
1641     }
1642
1643     /* Check that private key generates public key */
1644     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1645
1646     if (!publicKey ||
1647         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1648         bignum_cmp(publicKey->y, ec->publicKey.y))
1649     {
1650         ecdsa_freekey(ec);
1651         ec = NULL;
1652     }
1653     ec_point_free(publicKey);
1654
1655     return ec;
1656 }
1657
1658 static void *ecdsa_openssh_createkey(unsigned char **blob, int *len)
1659 {
1660     char **b = (char **) blob;
1661     char *p;
1662     int slen;
1663     struct ec_key *ec;
1664     struct ec_curve *curve;
1665     struct ec_point *publicKey;
1666
1667     getstring(b, len, &p, &slen);
1668
1669     if (!p) {
1670         return NULL;
1671     }
1672     curve = ec_name_to_curve(p, slen);
1673     if (!curve) return NULL;
1674
1675     ec = snew(struct ec_key);
1676
1677     ec->publicKey.curve = curve;
1678     ec->publicKey.infinity = 0;
1679     ec->publicKey.x = NULL;
1680     ec->publicKey.y = NULL;
1681     ec->publicKey.z = NULL;
1682     if (!getmppoint(b, len, &ec->publicKey)) {
1683         ecdsa_freekey(ec);
1684         return NULL;
1685     }
1686     ec->privateKey = NULL;
1687
1688     if (!ec->publicKey.x || !ec->publicKey.y ||
1689         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1690         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1691     {
1692         ecdsa_freekey(ec);
1693         return NULL;
1694     }
1695
1696     ec->privateKey = getmp(b, len);
1697     if (ec->privateKey == NULL)
1698     {
1699         ecdsa_freekey(ec);
1700         return NULL;
1701     }
1702
1703     /* Now check that the private key makes the public key */
1704     publicKey = ecp_mul(&ec->publicKey.curve->G, ec->privateKey);
1705     if (!publicKey)
1706     {
1707         ecdsa_freekey(ec);
1708         return NULL;
1709     }
1710
1711     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
1712         bignum_cmp(ec->publicKey.y, publicKey->y))
1713     {
1714         /* Private key doesn't make the public key on the given curve */
1715         ecdsa_freekey(ec);
1716         ec_point_free(publicKey);
1717     }
1718
1719     ec_point_free(publicKey);
1720
1721     return ec;
1722 }
1723
1724 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
1725 {
1726     struct ec_key *ec = (struct ec_key *) key;
1727
1728     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1729
1730     int namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1731
1732     int bloblen =
1733         4 + namelen /* <LEN> nistpXXX */
1734         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
1735         + ssh2_bignum_length(ec->privateKey);
1736
1737     int i;
1738
1739     if (bloblen > len)
1740         return bloblen;
1741
1742     bloblen = 0;
1743
1744     PUT_32BIT(blob+bloblen, namelen);
1745     bloblen += 4;
1746
1747     bloblen += ec_curve_to_name(ec->publicKey.curve, blob+bloblen, namelen);
1748
1749     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
1750     bloblen += 4;
1751     blob[bloblen++] = 0x04;
1752     for (i = pointlen; i--; )
1753         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
1754     for (i = pointlen; i--; )
1755         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
1756
1757     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1758     PUT_32BIT(blob+bloblen, pointlen);
1759     bloblen += 4;
1760     for (i = pointlen; i--; )
1761         blob[bloblen++] = bignum_byte(ec->privateKey, i);
1762
1763     return bloblen;
1764 }
1765
1766 static int ecdsa_pubkey_bits(void *blob, int len)
1767 {
1768     struct ec_key *ec;
1769     int ret;
1770
1771     ec = (struct ec_key*)ecdsa_newkey((char *) blob, len);
1772     if (!ec)
1773         return -1;
1774     ret = ec->publicKey.curve->fieldBits;
1775     ecdsa_freekey(ec);
1776
1777     return ret;
1778 }
1779
1780 static char *ecdsa_fingerprint(void *key)
1781 {
1782     struct ec_key *ec = (struct ec_key *) key;
1783     struct MD5Context md5c;
1784     unsigned char digest[16], lenbuf[4];
1785     char *ret;
1786     unsigned char *name;
1787     int pointlen, namelen, i, j;
1788
1789     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1790     name = snewn(namelen, unsigned char);
1791     ec_curve_to_name(ec->publicKey.curve, name, namelen);
1792
1793     MD5Init(&md5c);
1794
1795     PUT_32BIT(lenbuf, namelen + 11);
1796     MD5Update(&md5c, lenbuf, 4);
1797     MD5Update(&md5c, (const unsigned char *)"ecdsa-sha2-", 11);
1798     MD5Update(&md5c, name, namelen);
1799
1800     PUT_32BIT(lenbuf, namelen);
1801     MD5Update(&md5c, lenbuf, 4);
1802     MD5Update(&md5c, name, namelen);
1803
1804     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1805     PUT_32BIT(lenbuf, 1 + (pointlen * 2));
1806     MD5Update(&md5c, lenbuf, 4);
1807     MD5Update(&md5c, (const unsigned char *)"\x04", 1);
1808     for (i = pointlen; i--; ) {
1809         unsigned char c = bignum_byte(ec->publicKey.x, i);
1810         MD5Update(&md5c, &c, 1);
1811     }
1812     for (i = pointlen; i--; ) {
1813         unsigned char c = bignum_byte(ec->publicKey.y, i);
1814         MD5Update(&md5c, &c, 1);
1815     }
1816
1817     MD5Final(digest, &md5c);
1818
1819     ret = snewn(11 + namelen + 1 + (16 * 3), char);
1820
1821     i = 11;
1822     memcpy(ret, "ecdsa-sha2-", 11);
1823     memcpy(ret+i, name, namelen);
1824     i += namelen;
1825     sfree(name);
1826     ret[i++] = ' ';
1827     for (j = 0; j < 16; j++)
1828         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
1829
1830     return ret;
1831 }
1832
1833 static int ecdsa_verifysig(void *key, char *sig, int siglen,
1834                            char *data, int datalen)
1835 {
1836     struct ec_key *ec = (struct ec_key *) key;
1837     char *p;
1838     int slen;
1839     unsigned char digest[512 / 8];
1840     int digestLen;
1841     Bignum r, s;
1842     int ret;
1843
1844     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1845         return 0;
1846
1847     /* Check the signature curve matches the key curve */
1848     getstring(&sig, &siglen, &p, &slen);
1849     if (!p || slen < 11 || memcmp(p, "ecdsa-sha2-", 11)) {
1850         return 0;
1851     }
1852     if (ec->publicKey.curve != ec_name_to_curve(p+11, slen-11)) {
1853         return 0;
1854     }
1855
1856     getstring(&sig, &siglen, &p, &slen);
1857     r = getmp(&p, &slen);
1858     if (!r) return 0;
1859     s = getmp(&p, &slen);
1860     if (!s) {
1861         freebn(r);
1862         return 0;
1863     }
1864
1865     /* Perform correct hash function depending on curve size */
1866     if (ec->publicKey.curve->fieldBits <= 256) {
1867         SHA256_Simple(data, datalen, digest);
1868         digestLen = 256 / 8;
1869     } else if (ec->publicKey.curve->fieldBits <= 384) {
1870         SHA384_Simple(data, datalen, digest);
1871         digestLen = 384 / 8;
1872     } else {
1873         SHA512_Simple(data, datalen, digest);
1874         digestLen = 512 / 8;
1875     }
1876
1877     /* Verify the signature */
1878     if (!_ecdsa_verify(&ec->publicKey, digest, digestLen, r, s)) {
1879         ret = 0;
1880     } else {
1881         ret = 1;
1882     }
1883
1884     freebn(r);
1885     freebn(s);
1886
1887     return ret;
1888 }
1889
1890 static unsigned char *ecdsa_sign(void *key, char *data, int datalen,
1891                                  int *siglen)
1892 {
1893     struct ec_key *ec = (struct ec_key *) key;
1894     unsigned char digest[512 / 8];
1895     int digestLen;
1896     Bignum r = NULL, s = NULL;
1897     unsigned char *buf, *p;
1898     int rlen, slen, namelen;
1899     int i;
1900
1901     if (!ec->privateKey || !ec->publicKey.curve) {
1902         return NULL;
1903     }
1904
1905     /* Perform correct hash function depending on curve size */
1906     if (ec->publicKey.curve->fieldBits <= 256) {
1907         SHA256_Simple(data, datalen, digest);
1908         digestLen = 256 / 8;
1909     } else if (ec->publicKey.curve->fieldBits <= 384) {
1910         SHA384_Simple(data, datalen, digest);
1911         digestLen = 384 / 8;
1912     } else {
1913         SHA512_Simple(data, datalen, digest);
1914         digestLen = 512 / 8;
1915     }
1916
1917     /* Do the signature */
1918     _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
1919     if (!r || !s) {
1920         if (r) freebn(r);
1921         if (s) freebn(s);
1922         return NULL;
1923     }
1924
1925     rlen = (bignum_bitcount(r) + 8) / 8;
1926     slen = (bignum_bitcount(s) + 8) / 8;
1927
1928     namelen = ec_curve_to_name(ec->publicKey.curve, NULL, 0);
1929
1930     /* Format the output */
1931     *siglen = 8+11+namelen+rlen+slen+8;
1932     buf = snewn(*siglen, unsigned char);
1933     p = buf;
1934     PUT_32BIT(p, 11+namelen);
1935     p += 4;
1936     memcpy(p, "ecdsa-sha2-", 11);
1937     p += 11;
1938     p += ec_curve_to_name(ec->publicKey.curve, p, namelen);
1939     PUT_32BIT(p, rlen + slen + 8);
1940     p += 4;
1941     PUT_32BIT(p, rlen);
1942     p += 4;
1943     for (i = rlen; i--;)
1944         *p++ = bignum_byte(r, i);
1945     PUT_32BIT(p, slen);
1946     p += 4;
1947     for (i = slen; i--;)
1948         *p++ = bignum_byte(s, i);
1949
1950     return buf;
1951 }
1952
1953 const struct ssh_signkey ssh_ecdsa_nistp256 = {
1954     ecdsa_newkey,
1955     ecdsa_freekey,
1956     ecdsa_fmtkey,
1957     ecdsa_public_blob,
1958     ecdsa_private_blob,
1959     ecdsa_createkey,
1960     ecdsa_openssh_createkey,
1961     ecdsa_openssh_fmtkey,
1962     ecdsa_pubkey_bits,
1963     ecdsa_fingerprint,
1964     ecdsa_verifysig,
1965     ecdsa_sign,
1966     "ecdsa-sha2-nistp256",
1967     "ecdsa-sha2-nistp256",
1968 };
1969
1970 const struct ssh_signkey ssh_ecdsa_nistp384 = {
1971     ecdsa_newkey,
1972     ecdsa_freekey,
1973     ecdsa_fmtkey,
1974     ecdsa_public_blob,
1975     ecdsa_private_blob,
1976     ecdsa_createkey,
1977     ecdsa_openssh_createkey,
1978     ecdsa_openssh_fmtkey,
1979     ecdsa_pubkey_bits,
1980     ecdsa_fingerprint,
1981     ecdsa_verifysig,
1982     ecdsa_sign,
1983     "ecdsa-sha2-nistp384",
1984     "ecdsa-sha2-nistp384",
1985 };
1986
1987 const struct ssh_signkey ssh_ecdsa_nistp521 = {
1988     ecdsa_newkey,
1989     ecdsa_freekey,
1990     ecdsa_fmtkey,
1991     ecdsa_public_blob,
1992     ecdsa_private_blob,
1993     ecdsa_createkey,
1994     ecdsa_openssh_createkey,
1995     ecdsa_openssh_fmtkey,
1996     ecdsa_pubkey_bits,
1997     ecdsa_fingerprint,
1998     ecdsa_verifysig,
1999     ecdsa_sign,
2000     "ecdsa-sha2-nistp521",
2001     "ecdsa-sha2-nistp521",
2002 };
2003
2004 /* ----------------------------------------------------------------------
2005  * Exposed ECDH interface
2006  */
2007
2008 static Bignum ecdh_calculate(const Bignum private,
2009                              const struct ec_point *public)
2010 {
2011     struct ec_point *p;
2012     Bignum ret;
2013     p = ecp_mul(public, private);
2014     if (!p) return NULL;
2015     ret = p->x;
2016     p->x = NULL;
2017     ec_point_free(p);
2018     return ret;
2019 }
2020
2021 void *ssh_ecdhkex_newkey(struct ec_curve *curve)
2022 {
2023     struct ec_key *key = snew(struct ec_key);
2024     struct ec_point *publicKey;
2025     key->publicKey.curve = curve;
2026     key->privateKey = bignum_random_in_range(One, key->publicKey.curve->n);
2027     if (!key->privateKey) {
2028         sfree(key);
2029         return NULL;
2030     }
2031     publicKey = ecp_mul(&key->publicKey.curve->G, key->privateKey);
2032     if (!publicKey) {
2033         freebn(key->privateKey);
2034         sfree(key);
2035         return NULL;
2036     }
2037     key->publicKey.x = publicKey->x;
2038     key->publicKey.y = publicKey->y;
2039     key->publicKey.z = NULL;
2040     sfree(publicKey);
2041     return key;
2042 }
2043
2044 char *ssh_ecdhkex_getpublic(void *key, int *len)
2045 {
2046     struct ec_key *ec = (struct ec_key*)key;
2047     char *point, *p;
2048     int i;
2049     int pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2050
2051     *len = 1 + pointlen * 2;
2052     point = (char*)snewn(*len, char);
2053
2054     p = point;
2055     *p++ = 0x04;
2056     for (i = pointlen; i--;)
2057         *p++ = bignum_byte(ec->publicKey.x, i);
2058     for (i = pointlen; i--;)
2059         *p++ = bignum_byte(ec->publicKey.y, i);
2060
2061     return point;
2062 }
2063
2064 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2065 {
2066     struct ec_key *ec = (struct ec_key*) key;
2067     struct ec_point remote;
2068
2069     remote.curve = ec->publicKey.curve;
2070     remote.infinity = 0;
2071     if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2072         return NULL;
2073     }
2074
2075     return ecdh_calculate(ec->privateKey, &remote);
2076 }
2077
2078 void ssh_ecdhkex_freekey(void *key)
2079 {
2080     ecdsa_freekey(key);
2081 }
2082
2083 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2084     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
2085 };
2086
2087 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2088     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
2089 };
2090
2091 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2092     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
2093 };
2094
2095 static const struct ssh_kex *const ec_kex_list[] = {
2096     &ssh_ec_kex_nistp256,
2097     &ssh_ec_kex_nistp384,
2098     &ssh_ec_kex_nistp521
2099 };
2100
2101 const struct ssh_kexes ssh_ecdh_kex = {
2102     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2103     ec_kex_list
2104 };