]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Clean up elliptic curve selection and naming.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static int initialise_wcurve(struct ec_curve *curve, int bits, unsigned char *p,
41                              unsigned char *a, unsigned char *b,
42                              unsigned char *n, unsigned char *Gx,
43                              unsigned char *Gy)
44 {
45     int length = bits / 8;
46     if (bits % 8) ++length;
47
48     curve->type = EC_WEIERSTRASS;
49
50     curve->fieldBits = bits;
51     curve->p = bignum_from_bytes(p, length);
52     if (!curve->p) goto error;
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     if (!curve->w.a) goto error;
57     curve->w.b = bignum_from_bytes(b, length);
58     if (!curve->w.b) goto error;
59
60     /* Group order and generator */
61     curve->w.n = bignum_from_bytes(n, length);
62     if (!curve->w.n) goto error;
63     curve->w.G.x = bignum_from_bytes(Gx, length);
64     if (!curve->w.G.x) goto error;
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     if (!curve->w.G.y) goto error;
67     curve->w.G.curve = curve;
68     curve->w.G.infinity = 0;
69
70     return 1;
71   error:
72     if (curve->p) freebn(curve->p);
73     if (curve->w.a) freebn(curve->w.a);
74     if (curve->w.b) freebn(curve->w.b);
75     if (curve->w.n) freebn(curve->w.n);
76     if (curve->w.G.x) freebn(curve->w.G.x);
77     return 0;
78 }
79
80 static int initialise_mcurve(struct ec_curve *curve, int bits, unsigned char *p,
81                              unsigned char *a, unsigned char *b,
82                              unsigned char *Gx)
83 {
84     int length = bits / 8;
85     if (bits % 8) ++length;
86
87     curve->type = EC_MONTGOMERY;
88
89     curve->fieldBits = bits;
90     curve->p = bignum_from_bytes(p, length);
91     if (!curve->p) goto error;
92
93     /* Curve co-efficients */
94     curve->m.a = bignum_from_bytes(a, length);
95     if (!curve->m.a) goto error;
96     curve->m.b = bignum_from_bytes(b, length);
97     if (!curve->m.b) goto error;
98
99     /* Generator */
100     curve->m.G.x = bignum_from_bytes(Gx, length);
101     if (!curve->m.G.x) goto error;
102     curve->m.G.y = NULL;
103     curve->m.G.z = NULL;
104     curve->m.G.curve = curve;
105     curve->m.G.infinity = 0;
106
107     return 1;
108   error:
109     if (curve->p) freebn(curve->p);
110     if (curve->m.a) freebn(curve->m.a);
111     if (curve->m.b) freebn(curve->m.b);
112     return 0;
113 }
114
115 static int initialise_ecurve(struct ec_curve *curve, int bits, unsigned char *p,
116                              unsigned char *l, unsigned char *d,
117                              unsigned char *Bx, unsigned char *By)
118 {
119     int length = bits / 8;
120     if (bits % 8) ++length;
121
122     curve->type = EC_EDWARDS;
123
124     curve->fieldBits = bits;
125     curve->p = bignum_from_bytes(p, length);
126     if (!curve->p) goto error;
127
128     /* Curve co-efficients */
129     curve->e.l = bignum_from_bytes(l, length);
130     if (!curve->e.l) goto error;
131     curve->e.d = bignum_from_bytes(d, length);
132     if (!curve->e.d) goto error;
133
134     /* Group order and generator */
135     curve->e.B.x = bignum_from_bytes(Bx, length);
136     if (!curve->e.B.x) goto error;
137     curve->e.B.y = bignum_from_bytes(By, length);
138     if (!curve->e.B.y) goto error;
139     curve->e.B.curve = curve;
140     curve->e.B.infinity = 0;
141
142     return 1;
143   error:
144     if (curve->p) freebn(curve->p);
145     if (curve->e.l) freebn(curve->e.l);
146     if (curve->e.d) freebn(curve->e.d);
147     if (curve->e.B.x) freebn(curve->e.B.x);
148     return 0;
149 }
150
151 static struct ec_curve *ec_p256(void)
152 {
153     static struct ec_curve curve = { 0 };
154     static unsigned char initialised = 0;
155
156     if (!initialised)
157     {
158         unsigned char p[] = {
159             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
160             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
161             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
162             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
163         };
164         unsigned char a[] = {
165             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
166             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
167             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
168             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
169         };
170         unsigned char b[] = {
171             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
172             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
173             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
174             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
175         };
176         unsigned char n[] = {
177             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
178             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
179             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
180             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
181         };
182         unsigned char Gx[] = {
183             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
184             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
185             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
186             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
187         };
188         unsigned char Gy[] = {
189             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
190             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
191             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
192             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
193         };
194
195         if (!initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy)) {
196             return NULL;
197         }
198
199         curve.name = "nistp256";
200
201         /* Now initialised, no need to do it again */
202         initialised = 1;
203     }
204
205     return &curve;
206 }
207
208 static struct ec_curve *ec_p384(void)
209 {
210     static struct ec_curve curve = { 0 };
211     static unsigned char initialised = 0;
212
213     if (!initialised)
214     {
215         unsigned char p[] = {
216             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
217             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
218             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
219             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
220             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
221             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
222         };
223         unsigned char a[] = {
224             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
225             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
226             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
227             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
228             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
229             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
230         };
231         unsigned char b[] = {
232             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
233             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
234             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
235             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
236             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
237             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
238         };
239         unsigned char n[] = {
240             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
241             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
242             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
243             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
244             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
245             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
246         };
247         unsigned char Gx[] = {
248             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
249             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
250             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
251             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
252             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
253             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
254         };
255         unsigned char Gy[] = {
256             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
257             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
258             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
259             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
260             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
261             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
262         };
263
264         if (!initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy)) {
265             return NULL;
266         }
267
268         curve.name = "nistp384";
269
270         /* Now initialised, no need to do it again */
271         initialised = 1;
272     }
273
274     return &curve;
275 }
276
277 static struct ec_curve *ec_p521(void)
278 {
279     static struct ec_curve curve = { 0 };
280     static unsigned char initialised = 0;
281
282     if (!initialised)
283     {
284         unsigned char p[] = {
285             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
286             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
287             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
288             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
289             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
290             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
291             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
292             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
293             0xff, 0xff
294         };
295         unsigned char a[] = {
296             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
297             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
298             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
299             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
300             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
301             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
302             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
303             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
304             0xff, 0xfc
305         };
306         unsigned char b[] = {
307             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
308             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
309             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
310             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
311             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
312             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
313             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
314             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
315             0x3f, 0x00
316         };
317         unsigned char n[] = {
318             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
319             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
320             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
321             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
322             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
323             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
324             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
325             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
326             0x64, 0x09
327         };
328         unsigned char Gx[] = {
329             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
330             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
331             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
332             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
333             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
334             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
335             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
336             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
337             0xbd, 0x66
338         };
339         unsigned char Gy[] = {
340             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
341             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
342             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
343             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
344             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
345             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
346             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
347             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
348             0x66, 0x50
349         };
350
351         if (!initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy)) {
352             return NULL;
353         }
354
355         curve.name = "nistp521";
356
357         /* Now initialised, no need to do it again */
358         initialised = 1;
359     }
360
361     return &curve;
362 }
363
364 static struct ec_curve *ec_curve25519(void)
365 {
366     static struct ec_curve curve = { 0 };
367     static unsigned char initialised = 0;
368
369     if (!initialised)
370     {
371         unsigned char p[] = {
372             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
373             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
374             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
375             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
376         };
377         unsigned char a[] = {
378             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
379             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
380             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
381             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
382         };
383         unsigned char b[] = {
384             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
385             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
386             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
387             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
388         };
389         unsigned char gx[32] = {
390             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
391             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
392             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
393             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
394         };
395
396         if (!initialise_mcurve(&curve, 256, p, a, b, gx)) {
397             return NULL;
398         }
399
400         /* This curve doesn't need a name, because it's never used in
401          * any format that embeds the curve name */
402         curve.name = NULL;
403
404         /* Now initialised, no need to do it again */
405         initialised = 1;
406     }
407
408     return &curve;
409 }
410
411 static struct ec_curve *ec_ed25519(void)
412 {
413     static struct ec_curve curve = { 0 };
414     static unsigned char initialised = 0;
415
416     if (!initialised)
417     {
418         unsigned char q[] = {
419             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
420             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
421             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
422             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
423         };
424         unsigned char l[32] = {
425             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
426             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
427             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
428             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
429         };
430         unsigned char d[32] = {
431             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
432             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
433             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
434             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
435         };
436         unsigned char Bx[32] = {
437             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
438             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
439             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
440             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
441         };
442         unsigned char By[32] = {
443             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
444             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
445             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
446             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
447         };
448
449         /* This curve doesn't need a name, because it's never used in
450          * any format that embeds the curve name */
451         curve.name = NULL;
452
453         if (!initialise_ecurve(&curve, 256, q, l, d, Bx, By)) {
454             return NULL;
455         }
456
457         /* Now initialised, no need to do it again */
458         initialised = 1;
459     }
460
461     return &curve;
462 }
463
464 /* Return 1 if a is -3 % p, otherwise return 0
465  * This is used because there are some maths optimisations */
466 static int ec_aminus3(const struct ec_curve *curve)
467 {
468     int ret;
469     Bignum _p;
470
471     if (curve->type != EC_WEIERSTRASS) {
472         return 0;
473     }
474
475     _p = bignum_add_long(curve->w.a, 3);
476     if (!_p) return 0;
477
478     ret = !bignum_cmp(curve->p, _p);
479     freebn(_p);
480     return ret;
481 }
482
483 /* ----------------------------------------------------------------------
484  * Elliptic curve field maths
485  */
486
487 static Bignum ecf_add(const Bignum a, const Bignum b,
488                       const struct ec_curve *curve)
489 {
490     Bignum a1, b1, ab, ret;
491
492     a1 = bigmod(a, curve->p);
493     if (!a1) return NULL;
494     b1 = bigmod(b, curve->p);
495     if (!b1)
496     {
497         freebn(a1);
498         return NULL;
499     }
500
501     ab = bigadd(a1, b1);
502     freebn(a1);
503     freebn(b1);
504     if (!ab) return NULL;
505
506     ret = bigmod(ab, curve->p);
507     freebn(ab);
508
509     return ret;
510 }
511
512 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
513 {
514     return modmul(a, a, curve->p);
515 }
516
517 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
518 {
519     Bignum ret, tmp;
520
521     /* Double */
522     tmp = bignum_lshift(a, 1);
523     if (!tmp) return NULL;
524
525     /* Add itself (i.e. treble) */
526     ret = bigadd(tmp, a);
527     freebn(tmp);
528
529     /* Normalise */
530     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
531     {
532         tmp = bigsub(ret, curve->p);
533         freebn(ret);
534         ret = tmp;
535     }
536
537     return ret;
538 }
539
540 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
541 {
542     Bignum ret = bignum_lshift(a, 1);
543     if (!ret) return NULL;
544     if (bignum_cmp(ret, curve->p) >= 0)
545     {
546         Bignum tmp = bigsub(ret, curve->p);
547         freebn(ret);
548         return tmp;
549     }
550     else
551     {
552         return ret;
553     }
554 }
555
556 /* ----------------------------------------------------------------------
557  * Memory functions
558  */
559
560 void ec_point_free(struct ec_point *point)
561 {
562     if (point == NULL) return;
563     point->curve = 0;
564     if (point->x) freebn(point->x);
565     if (point->y) freebn(point->y);
566     if (point->z) freebn(point->z);
567     point->infinity = 0;
568     sfree(point);
569 }
570
571 static struct ec_point *ec_point_new(const struct ec_curve *curve,
572                                      const Bignum x, const Bignum y, const Bignum z,
573                                      unsigned char infinity)
574 {
575     struct ec_point *point = snewn(1, struct ec_point);
576     point->curve = curve;
577     point->x = x;
578     point->y = y;
579     point->z = z;
580     point->infinity = infinity ? 1 : 0;
581     return point;
582 }
583
584 static struct ec_point *ec_point_copy(const struct ec_point *a)
585 {
586     if (a == NULL) return NULL;
587     return ec_point_new(a->curve,
588                         a->x ? copybn(a->x) : NULL,
589                         a->y ? copybn(a->y) : NULL,
590                         a->z ? copybn(a->z) : NULL,
591                         a->infinity);
592 }
593
594 static int ec_point_verify(const struct ec_point *a)
595 {
596     if (a->infinity) {
597         return 1;
598     } else if (a->curve->type == EC_EDWARDS) {
599         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
600         Bignum y2, x2, tmp, tmp2, tmp3;
601         int ret;
602
603         y2 = ecf_square(a->y, a->curve);
604         if (!y2) {
605             return 0;
606         }
607         x2 = ecf_square(a->x, a->curve);
608         if (!x2) {
609             freebn(y2);
610             return 0;
611         }
612         tmp = modmul(a->curve->e.d, x2, a->curve->p);
613         if (!tmp) {
614             freebn(x2);
615             freebn(y2);
616             return 0;
617         }
618         tmp2 = modmul(tmp, y2, a->curve->p);
619         freebn(tmp);
620         if (!tmp2) {
621             freebn(x2);
622             freebn(y2);
623             return 0;
624         }
625         tmp = modsub(y2, x2, a->curve->p);
626         freebn(y2);
627         freebn(x2);
628         if (!tmp) {
629             freebn(tmp2);
630             return 0;
631         }
632         tmp3 = modsub(tmp, tmp2, a->curve->p);
633         freebn(tmp);
634         freebn(tmp2);
635         if (!tmp3) {
636             return 0;
637         }
638         ret = !bignum_cmp(tmp3, One);
639         freebn(tmp3);
640         return ret;
641     } else if (a->curve->type == EC_WEIERSTRASS) {
642         /* Verify y^2 = x^3 + ax + b */
643         int ret = 0;
644
645         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
646
647         Bignum Three = bignum_from_long(3);
648         if (!Three) return 0;
649
650         lhs = modmul(a->y, a->y, a->curve->p);
651         if (!lhs) goto error;
652
653         /* This uses montgomery multiplication to optimise */
654         x3 = modpow(a->x, Three, a->curve->p);
655         freebn(Three);
656         if (!x3) goto error;
657         ax = modmul(a->curve->w.a, a->x, a->curve->p);
658         if (!ax) goto error;
659         x3ax = bigadd(x3, ax);
660         if (!x3ax) goto error;
661         freebn(x3); x3 = NULL;
662         freebn(ax); ax = NULL;
663         x3axm = bigmod(x3ax, a->curve->p);
664         if (!x3axm) goto error;
665         freebn(x3ax); x3ax = NULL;
666         x3axb = bigadd(x3axm, a->curve->w.b);
667         if (!x3axb) goto error;
668         freebn(x3axm); x3axm = NULL;
669         rhs = bigmod(x3axb, a->curve->p);
670         if (!rhs) goto error;
671         freebn(x3axb);
672
673         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
674         freebn(lhs);
675         freebn(rhs);
676
677         return ret;
678
679       error:
680         if (x3) freebn(x3);
681         if (ax) freebn(ax);
682         if (x3ax) freebn(x3ax);
683         if (x3axm) freebn(x3axm);
684         if (x3axb) freebn(x3axb);
685         if (lhs) freebn(lhs);
686         return 0;
687     } else {
688         return 0;
689     }
690 }
691
692 /* ----------------------------------------------------------------------
693  * Elliptic curve point maths
694  */
695
696 /* Returns 1 on success and 0 on memory error */
697 static int ecp_normalise(struct ec_point *a)
698 {
699     if (!a) {
700         /* No point */
701         return 0;
702     }
703
704     if (a->infinity) {
705         /* Point is at infinity - i.e. normalised */
706         return 1;
707     }
708
709     if (a->curve->type == EC_WEIERSTRASS) {
710         /* In Jacobian Coordinates the triple (X, Y, Z) represents
711            the affine point (X / Z^2, Y / Z^3) */
712
713         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
714
715         if (!a->x || !a->y) {
716             /* No point defined */
717             return 0;
718         } else if (!a->z) {
719             /* Already normalised */
720             return 1;
721         }
722
723         Z2 = ecf_square(a->z, a->curve);
724         if (!Z2) {
725             return 0;
726         }
727         Z2inv = modinv(Z2, a->curve->p);
728         if (!Z2inv) {
729             freebn(Z2);
730             return 0;
731         }
732         tx = modmul(a->x, Z2inv, a->curve->p);
733         freebn(Z2inv);
734         if (!tx) {
735             freebn(Z2);
736             return 0;
737         }
738
739         Z3 = modmul(Z2, a->z, a->curve->p);
740         freebn(Z2);
741         if (!Z3) {
742             freebn(tx);
743             return 0;
744         }
745         Z3inv = modinv(Z3, a->curve->p);
746         freebn(Z3);
747         if (!Z3inv) {
748             freebn(tx);
749             return 0;
750         }
751         ty = modmul(a->y, Z3inv, a->curve->p);
752         freebn(Z3inv);
753         if (!ty) {
754             freebn(tx);
755             return 0;
756         }
757
758         freebn(a->x);
759         a->x = tx;
760         freebn(a->y);
761         a->y = ty;
762         freebn(a->z);
763         a->z = NULL;
764         return 1;
765     } else if (a->curve->type == EC_MONTGOMERY) {
766         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
767
768         Bignum tmp, tmp2;
769
770         if (!a->x) {
771             /* No point defined */
772             return 0;
773         } else if (!a->z) {
774             /* Already normalised */
775             return 1;
776         }
777
778         tmp = modinv(a->z, a->curve->p);
779         if (!tmp) {
780             return 0;
781         }
782         tmp2 = modmul(a->x, tmp, a->curve->p);
783         freebn(tmp);
784         if (!tmp2) {
785             return 0;
786         }
787
788         freebn(a->z);
789         a->z = NULL;
790         freebn(a->x);
791         a->x = tmp2;
792         return 1;
793     } else if (a->curve->type == EC_EDWARDS) {
794         /* Always normalised */
795         return 1;
796     } else {
797         return 0;
798     }
799 }
800
801 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
802 {
803     Bignum S, M, outx, outy, outz;
804
805     if (bignum_cmp(a->y, Zero) == 0)
806     {
807         /* Identity */
808         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
809     }
810
811     /* S = 4*X*Y^2 */
812     {
813         Bignum Y2, XY2, _2XY2;
814
815         Y2 = ecf_square(a->y, a->curve);
816         if (!Y2) {
817             return NULL;
818         }
819         XY2 = modmul(a->x, Y2, a->curve->p);
820         freebn(Y2);
821         if (!XY2) {
822             return NULL;
823         }
824
825         _2XY2 = ecf_double(XY2, a->curve);
826         freebn(XY2);
827         if (!_2XY2) {
828             return NULL;
829         }
830         S = ecf_double(_2XY2, a->curve);
831         freebn(_2XY2);
832         if (!S) {
833             return NULL;
834         }
835     }
836
837     /* Faster calculation if a = -3 */
838     if (aminus3) {
839         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
840         Bignum Z2, XpZ2, XmZ2, second;
841
842         if (a->z == NULL) {
843             Z2 = copybn(One);
844         } else {
845             Z2 = ecf_square(a->z, a->curve);
846         }
847         if (!Z2) {
848             freebn(S);
849             return NULL;
850         }
851
852         XpZ2 = ecf_add(a->x, Z2, a->curve);
853         if (!XpZ2) {
854             freebn(S);
855             freebn(Z2);
856             return NULL;
857         }
858         XmZ2 = modsub(a->x, Z2, a->curve->p);
859         freebn(Z2);
860         if (!XmZ2) {
861             freebn(S);
862             freebn(XpZ2);
863             return NULL;
864         }
865
866         second = modmul(XpZ2, XmZ2, a->curve->p);
867         freebn(XpZ2);
868         freebn(XmZ2);
869         if (!second) {
870             freebn(S);
871             return NULL;
872         }
873
874         M = ecf_treble(second, a->curve);
875         freebn(second);
876         if (!M) {
877             freebn(S);
878             return NULL;
879         }
880     } else {
881         /* M = 3*X^2 + a*Z^4 */
882         Bignum _3X2, X2, aZ4;
883
884         if (a->z == NULL) {
885             aZ4 = copybn(a->curve->w.a);
886         } else {
887             Bignum Z2, Z4;
888
889             Z2 = ecf_square(a->z, a->curve);
890             if (!Z2) {
891                 freebn(S);
892                 return NULL;
893             }
894             Z4 = ecf_square(Z2, a->curve);
895             freebn(Z2);
896             if (!Z4) {
897                 freebn(S);
898                 return NULL;
899             }
900             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
901             freebn(Z4);
902         }
903         if (!aZ4) {
904             freebn(S);
905             return NULL;
906         }
907
908         X2 = modmul(a->x, a->x, a->curve->p);
909         if (!X2) {
910             freebn(S);
911             freebn(aZ4);
912             return NULL;
913         }
914         _3X2 = ecf_treble(X2, a->curve);
915         freebn(X2);
916         if (!_3X2) {
917             freebn(S);
918             freebn(aZ4);
919             return NULL;
920         }
921         M = ecf_add(_3X2, aZ4, a->curve);
922         freebn(_3X2);
923         freebn(aZ4);
924         if (!M) {
925             freebn(S);
926             return NULL;
927         }
928     }
929
930     /* X' = M^2 - 2*S */
931     {
932         Bignum M2, _2S;
933
934         M2 = ecf_square(M, a->curve);
935         if (!M2) {
936             freebn(S);
937             freebn(M);
938             return NULL;
939         }
940
941         _2S = ecf_double(S, a->curve);
942         if (!_2S) {
943             freebn(M2);
944             freebn(S);
945             freebn(M);
946             return NULL;
947         }
948
949         outx = modsub(M2, _2S, a->curve->p);
950         freebn(M2);
951         freebn(_2S);
952         if (!outx) {
953             freebn(S);
954             freebn(M);
955             return NULL;
956         }
957     }
958
959     /* Y' = M*(S - X') - 8*Y^4 */
960     {
961         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
962
963         SX = modsub(S, outx, a->curve->p);
964         freebn(S);
965         if (!SX) {
966             freebn(M);
967             freebn(outx);
968             return NULL;
969         }
970         MSX = modmul(M, SX, a->curve->p);
971         freebn(SX);
972         freebn(M);
973         if (!MSX) {
974             freebn(outx);
975             return NULL;
976         }
977         Y2 = ecf_square(a->y, a->curve);
978         if (!Y2) {
979             freebn(outx);
980             freebn(MSX);
981             return NULL;
982         }
983         Y4 = ecf_square(Y2, a->curve);
984         freebn(Y2);
985         if (!Y4) {
986             freebn(outx);
987             freebn(MSX);
988             return NULL;
989         }
990         Eight = bignum_from_long(8);
991         if (!Eight) {
992             freebn(outx);
993             freebn(MSX);
994             freebn(Y4);
995             return NULL;
996         }
997         _8Y4 = modmul(Eight, Y4, a->curve->p);
998         freebn(Eight);
999         freebn(Y4);
1000         if (!_8Y4) {
1001             freebn(outx);
1002             freebn(MSX);
1003             return NULL;
1004         }
1005         outy = modsub(MSX, _8Y4, a->curve->p);
1006         freebn(MSX);
1007         freebn(_8Y4);
1008         if (!outy) {
1009             freebn(outx);
1010             return NULL;
1011         }
1012     }
1013
1014     /* Z' = 2*Y*Z */
1015     {
1016         Bignum YZ;
1017
1018         if (a->z == NULL) {
1019             YZ = copybn(a->y);
1020         } else {
1021             YZ = modmul(a->y, a->z, a->curve->p);
1022         }
1023         if (!YZ) {
1024             freebn(outx);
1025             freebn(outy);
1026             return NULL;
1027         }
1028
1029         outz = ecf_double(YZ, a->curve);
1030         freebn(YZ);
1031         if (!outz) {
1032             freebn(outx);
1033             freebn(outy);
1034             return NULL;
1035         }
1036     }
1037
1038     return ec_point_new(a->curve, outx, outy, outz, 0);
1039 }
1040
1041 static struct ec_point *ecp_doublem(const struct ec_point *a)
1042 {
1043     Bignum z, outx, outz, xpz, xmz;
1044
1045     z = a->z;
1046     if (!z) {
1047         z = One;
1048     }
1049
1050     /* 4xz = (x + z)^2 - (x - z)^2 */
1051     {
1052         Bignum tmp;
1053
1054         tmp = ecf_add(a->x, z, a->curve);
1055         if (!tmp) {
1056             return NULL;
1057         }
1058         xpz = ecf_square(tmp, a->curve);
1059         freebn(tmp);
1060         if (!xpz) {
1061             return NULL;
1062         }
1063
1064         tmp = modsub(a->x, z, a->curve->p);
1065         if (!tmp) {
1066             freebn(xpz);
1067             return NULL;
1068         }
1069         xmz = ecf_square(tmp, a->curve);
1070         freebn(tmp);
1071         if (!xmz) {
1072             freebn(xpz);
1073             return NULL;
1074         }
1075     }
1076
1077     /* outx = (x + z)^2 * (x - z)^2 */
1078     outx = modmul(xpz, xmz, a->curve->p);
1079     if (!outx) {
1080         freebn(xpz);
1081         freebn(xmz);
1082         return NULL;
1083     }
1084
1085     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
1086     {
1087         Bignum _4xz, tmp, tmp2, tmp3;
1088
1089         tmp = bignum_from_long(2);
1090         if (!tmp) {
1091             freebn(xpz);
1092             freebn(outx);
1093             freebn(xmz);
1094             return NULL;
1095         }
1096         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
1097         freebn(tmp);
1098         if (!tmp2) {
1099             freebn(xpz);
1100             freebn(outx);
1101             freebn(xmz);
1102             return NULL;
1103         }
1104
1105         _4xz = modsub(xpz, xmz, a->curve->p);
1106         freebn(xpz);
1107         if (!_4xz) {
1108             freebn(tmp2);
1109             freebn(outx);
1110             freebn(xmz);
1111             return NULL;
1112         }
1113         tmp = modmul(tmp2, _4xz, a->curve->p);
1114         freebn(tmp2);
1115         if (!tmp) {
1116             freebn(_4xz);
1117             freebn(outx);
1118             freebn(xmz);
1119             return NULL;
1120         }
1121
1122         tmp2 = bignum_from_long(4);
1123         if (!tmp2) {
1124             freebn(tmp);
1125             freebn(_4xz);
1126             freebn(outx);
1127             freebn(xmz);
1128             return NULL;
1129         }
1130         tmp3 = modinv(tmp2, a->curve->p);
1131         freebn(tmp2);
1132         if (!tmp3) {
1133             freebn(tmp);
1134             freebn(_4xz);
1135             freebn(outx);
1136             freebn(xmz);
1137             return NULL;
1138         }
1139         tmp2 = modmul(tmp, tmp3, a->curve->p);
1140         freebn(tmp);
1141         freebn(tmp3);
1142         if (!tmp2) {
1143             freebn(_4xz);
1144             freebn(outx);
1145             freebn(xmz);
1146             return NULL;
1147         }
1148
1149         tmp = ecf_add(xmz, tmp2, a->curve);
1150         freebn(xmz);
1151         freebn(tmp2);
1152         if (!tmp) {
1153             freebn(_4xz);
1154             freebn(outx);
1155             return NULL;
1156         }
1157         outz = modmul(_4xz, tmp, a->curve->p);
1158         freebn(_4xz);
1159         freebn(tmp);
1160         if (!outz) {
1161             freebn(outx);
1162             return NULL;
1163         }
1164     }
1165
1166     return ec_point_new(a->curve, outx, NULL, outz, 0);
1167 }
1168
1169 /* Forward declaration for Edwards curve doubling */
1170 static struct ec_point *ecp_add(const struct ec_point *a,
1171                                 const struct ec_point *b,
1172                                 const int aminus3);
1173
1174 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
1175 {
1176     if (a->infinity)
1177     {
1178         /* Identity */
1179         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1180     }
1181
1182     if (a->curve->type == EC_EDWARDS)
1183     {
1184         return ecp_add(a, a, aminus3);
1185     }
1186     else if (a->curve->type == EC_WEIERSTRASS)
1187     {
1188         return ecp_doublew(a, aminus3);
1189     }
1190     else
1191     {
1192         return ecp_doublem(a);
1193     }
1194 }
1195
1196 static struct ec_point *ecp_addw(const struct ec_point *a,
1197                                  const struct ec_point *b,
1198                                  const int aminus3)
1199 {
1200     Bignum U1, U2, S1, S2, outx, outy, outz;
1201
1202     /* U1 = X1*Z2^2 */
1203     /* S1 = Y1*Z2^3 */
1204     if (b->z) {
1205         Bignum Z2, Z3;
1206
1207         Z2 = ecf_square(b->z, a->curve);
1208         if (!Z2) {
1209             return NULL;
1210         }
1211         U1 = modmul(a->x, Z2, a->curve->p);
1212         if (!U1) {
1213             freebn(Z2);
1214             return NULL;
1215         }
1216         Z3 = modmul(Z2, b->z, a->curve->p);
1217         freebn(Z2);
1218         if (!Z3) {
1219             freebn(U1);
1220             return NULL;
1221         }
1222         S1 = modmul(a->y, Z3, a->curve->p);
1223         freebn(Z3);
1224         if (!S1) {
1225             freebn(U1);
1226             return NULL;
1227         }
1228     } else {
1229         U1 = copybn(a->x);
1230         if (!U1) {
1231             return NULL;
1232         }
1233         S1 = copybn(a->y);
1234         if (!S1) {
1235             freebn(U1);
1236             return NULL;
1237         }
1238     }
1239
1240     /* U2 = X2*Z1^2 */
1241     /* S2 = Y2*Z1^3 */
1242     if (a->z) {
1243         Bignum Z2, Z3;
1244
1245         Z2 = ecf_square(a->z, b->curve);
1246         if (!Z2) {
1247             freebn(U1);
1248             freebn(S1);
1249             return NULL;
1250         }
1251         U2 = modmul(b->x, Z2, b->curve->p);
1252         if (!U2) {
1253             freebn(U1);
1254             freebn(S1);
1255             freebn(Z2);
1256             return NULL;
1257         }
1258         Z3 = modmul(Z2, a->z, b->curve->p);
1259         freebn(Z2);
1260         if (!Z3) {
1261             freebn(U1);
1262             freebn(S1);
1263             freebn(U2);
1264             return NULL;
1265         }
1266         S2 = modmul(b->y, Z3, b->curve->p);
1267         freebn(Z3);
1268         if (!S2) {
1269             freebn(U1);
1270             freebn(S1);
1271             freebn(U2);
1272             return NULL;
1273         }
1274     } else {
1275         U2 = copybn(b->x);
1276         if (!U2) {
1277             freebn(U1);
1278             freebn(S1);
1279             return NULL;
1280         }
1281         S2 = copybn(b->y);
1282         if (!S2) {
1283             freebn(U1);
1284             freebn(S1);
1285             freebn(U2);
1286             return NULL;
1287         }
1288     }
1289
1290     /* Check if multiplying by self */
1291     if (bignum_cmp(U1, U2) == 0)
1292     {
1293         freebn(U1);
1294         freebn(U2);
1295         if (bignum_cmp(S1, S2) == 0)
1296         {
1297             freebn(S1);
1298             freebn(S2);
1299             return ecp_double(a, aminus3);
1300         }
1301         else
1302         {
1303             freebn(S1);
1304             freebn(S2);
1305             /* Infinity */
1306             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1307         }
1308     }
1309
1310     {
1311         Bignum H, R, UH2, H3;
1312
1313         /* H = U2 - U1 */
1314         H = modsub(U2, U1, a->curve->p);
1315         freebn(U2);
1316         if (!H) {
1317             freebn(U1);
1318             freebn(S1);
1319             freebn(S2);
1320             return NULL;
1321         }
1322
1323         /* R = S2 - S1 */
1324         R = modsub(S2, S1, a->curve->p);
1325         freebn(S2);
1326         if (!R) {
1327             freebn(H);
1328             freebn(S1);
1329             freebn(U1);
1330             return NULL;
1331         }
1332
1333         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1334         {
1335             Bignum R2, H2, _2UH2, first;
1336
1337             H2 = ecf_square(H, a->curve);
1338             if (!H2) {
1339                 freebn(U1);
1340                 freebn(S1);
1341                 freebn(H);
1342                 freebn(R);
1343                 return NULL;
1344             }
1345             UH2 = modmul(U1, H2, a->curve->p);
1346             freebn(U1);
1347             if (!UH2) {
1348                 freebn(H2);
1349                 freebn(S1);
1350                 freebn(H);
1351                 freebn(R);
1352                 return NULL;
1353             }
1354             H3 = modmul(H2, H, a->curve->p);
1355             freebn(H2);
1356             if (!H3) {
1357                 freebn(UH2);
1358                 freebn(S1);
1359                 freebn(H);
1360                 freebn(R);
1361                 return NULL;
1362             }
1363             R2 = ecf_square(R, a->curve);
1364             if (!R2) {
1365                 freebn(H3);
1366                 freebn(UH2);
1367                 freebn(S1);
1368                 freebn(H);
1369                 freebn(R);
1370                 return NULL;
1371             }
1372             _2UH2 = ecf_double(UH2, a->curve);
1373             if (!_2UH2) {
1374                 freebn(R2);
1375                 freebn(H3);
1376                 freebn(UH2);
1377                 freebn(S1);
1378                 freebn(H);
1379                 freebn(R);
1380                 return NULL;
1381             }
1382             first = modsub(R2, H3, a->curve->p);
1383             freebn(R2);
1384             if (!first) {
1385                 freebn(H3);
1386                 freebn(_2UH2);
1387                 freebn(UH2);
1388                 freebn(S1);
1389                 freebn(H);
1390                 freebn(R);
1391                 return NULL;
1392             }
1393             outx = modsub(first, _2UH2, a->curve->p);
1394             freebn(first);
1395             freebn(_2UH2);
1396             if (!outx) {
1397                 freebn(H3);
1398                 freebn(UH2);
1399                 freebn(S1);
1400                 freebn(H);
1401                 freebn(R);
1402                 return NULL;
1403             }
1404         }
1405
1406         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1407         {
1408             Bignum RUH2mX, UH2mX, SH3;
1409
1410             UH2mX = modsub(UH2, outx, a->curve->p);
1411             freebn(UH2);
1412             if (!UH2mX) {
1413                 freebn(outx);
1414                 freebn(H3);
1415                 freebn(S1);
1416                 freebn(H);
1417                 freebn(R);
1418                 return NULL;
1419             }
1420             RUH2mX = modmul(R, UH2mX, a->curve->p);
1421             freebn(UH2mX);
1422             freebn(R);
1423             if (!RUH2mX) {
1424                 freebn(outx);
1425                 freebn(H3);
1426                 freebn(S1);
1427                 freebn(H);
1428                 return NULL;
1429             }
1430             SH3 = modmul(S1, H3, a->curve->p);
1431             freebn(S1);
1432             freebn(H3);
1433             if (!SH3) {
1434                 freebn(RUH2mX);
1435                 freebn(outx);
1436                 freebn(H);
1437                 return NULL;
1438             }
1439
1440             outy = modsub(RUH2mX, SH3, a->curve->p);
1441             freebn(RUH2mX);
1442             freebn(SH3);
1443             if (!outy) {
1444                 freebn(outx);
1445                 freebn(H);
1446                 return NULL;
1447             }
1448         }
1449
1450         /* Z3 = H*Z1*Z2 */
1451         if (a->z && b->z) {
1452             Bignum ZZ;
1453
1454             ZZ = modmul(a->z, b->z, a->curve->p);
1455             if (!ZZ) {
1456                 freebn(outx);
1457                 freebn(outy);
1458                 freebn(H);
1459                 return NULL;
1460             }
1461             outz = modmul(H, ZZ, a->curve->p);
1462             freebn(H);
1463             freebn(ZZ);
1464             if (!outz) {
1465                 freebn(outx);
1466                 freebn(outy);
1467                 return NULL;
1468             }
1469         } else if (a->z) {
1470             outz = modmul(H, a->z, a->curve->p);
1471             freebn(H);
1472             if (!outz) {
1473                 freebn(outx);
1474                 freebn(outy);
1475                 return NULL;
1476             }
1477         } else if (b->z) {
1478             outz = modmul(H, b->z, a->curve->p);
1479             freebn(H);
1480             if (!outz) {
1481                 freebn(outx);
1482                 freebn(outy);
1483                 return NULL;
1484             }
1485         } else {
1486             outz = H;
1487         }
1488     }
1489
1490     return ec_point_new(a->curve, outx, outy, outz, 0);
1491 }
1492
1493 static struct ec_point *ecp_addm(const struct ec_point *a,
1494                                  const struct ec_point *b,
1495                                  const struct ec_point *base)
1496 {
1497     Bignum outx, outz, az, bz;
1498
1499     az = a->z;
1500     if (!az) {
1501         az = One;
1502     }
1503     bz = b->z;
1504     if (!bz) {
1505         bz = One;
1506     }
1507
1508     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1509     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1510     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1511     {
1512         Bignum tmp, tmp2, tmp3, tmp4;
1513
1514         /* (Xa + Za) * (Xb - Zb) */
1515         tmp = ecf_add(a->x, az, a->curve);
1516         if (!tmp) {
1517             return NULL;
1518         }
1519         tmp2 = modsub(b->x, bz, a->curve->p);
1520         if (!tmp2) {
1521             freebn(tmp);
1522             return NULL;
1523         }
1524         tmp3 = modmul(tmp, tmp2, a->curve->p);
1525         freebn(tmp);
1526         freebn(tmp2);
1527         if (!tmp3) {
1528             return NULL;
1529         }
1530
1531         /* (Xa - Za) * (Xb + Zb) */
1532         tmp = modsub(a->x, az, a->curve->p);
1533         if (!tmp) {
1534             freebn(tmp3);
1535             return NULL;
1536         }
1537         tmp2 = ecf_add(b->x, bz, a->curve);
1538         if (!tmp2) {
1539             freebn(tmp);
1540             freebn(tmp3);
1541             return NULL;
1542         }
1543         tmp4 = modmul(tmp, tmp2, a->curve->p);
1544         freebn(tmp);
1545         freebn(tmp2);
1546         if (!tmp4) {
1547             freebn(tmp3);
1548             return NULL;
1549         }
1550
1551         tmp = ecf_add(tmp3, tmp4, a->curve);
1552         if (!tmp) {
1553             freebn(tmp3);
1554             freebn(tmp4);
1555             return NULL;
1556         }
1557         outx = ecf_square(tmp, a->curve);
1558         freebn(tmp);
1559         if (!outx) {
1560             freebn(tmp3);
1561             freebn(tmp4);
1562             return NULL;
1563         }
1564
1565         tmp = modsub(tmp3, tmp4, a->curve->p);
1566         freebn(tmp3);
1567         freebn(tmp4);
1568         if (!tmp) {
1569             freebn(outx);
1570             return NULL;
1571         }
1572         tmp2 = ecf_square(tmp, a->curve);
1573         freebn(tmp);
1574         if (!tmp2) {
1575             freebn(outx);
1576             return NULL;
1577         }
1578         outz = modmul(base->x, tmp2, a->curve->p);
1579         freebn(tmp2);
1580         if (!outz) {
1581             freebn(outx);
1582             return NULL;
1583         }
1584     }
1585
1586     return ec_point_new(a->curve, outx, NULL, outz, 0);
1587 }
1588
1589 static struct ec_point *ecp_adde(const struct ec_point *a,
1590                                  const struct ec_point *b)
1591 {
1592     Bignum outx, outy, dmul;
1593
1594     /* outx = (a->x * b->y + b->x * a->y) /
1595      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1596     {
1597         Bignum tmp, tmp2, tmp3, tmp4;
1598
1599         tmp = modmul(a->x, b->y, a->curve->p);
1600         if (!tmp)
1601         {
1602             return NULL;
1603         }
1604         tmp2 = modmul(b->x, a->y, a->curve->p);
1605         if (!tmp2)
1606         {
1607             freebn(tmp);
1608             return NULL;
1609         }
1610         tmp3 = ecf_add(tmp, tmp2, a->curve);
1611         if (!tmp3)
1612         {
1613             freebn(tmp);
1614             freebn(tmp2);
1615             return NULL;
1616         }
1617
1618         tmp4 = modmul(tmp, tmp2, a->curve->p);
1619         freebn(tmp);
1620         freebn(tmp2);
1621         if (!tmp4)
1622         {
1623             freebn(tmp3);
1624             return NULL;
1625         }
1626         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1627         freebn(tmp4);
1628         if (!dmul) {
1629             freebn(tmp3);
1630             return NULL;
1631         }
1632
1633         tmp = ecf_add(One, dmul, a->curve);
1634         if (!tmp)
1635         {
1636             freebn(tmp3);
1637             freebn(dmul);
1638             return NULL;
1639         }
1640         tmp2 = modinv(tmp, a->curve->p);
1641         freebn(tmp);
1642         if (!tmp2)
1643         {
1644             freebn(tmp3);
1645             freebn(dmul);
1646             return NULL;
1647         }
1648
1649         outx = modmul(tmp3, tmp2, a->curve->p);
1650         freebn(tmp3);
1651         freebn(tmp2);
1652         if (!outx)
1653         {
1654             freebn(dmul);
1655             return NULL;
1656         }
1657     }
1658
1659     /* outy = (a->y * b->y + a->x * b->x) /
1660      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1661     {
1662         Bignum tmp, tmp2, tmp3, tmp4;
1663
1664         tmp = modsub(One, dmul, a->curve->p);
1665         freebn(dmul);
1666         if (!tmp)
1667         {
1668             freebn(outx);
1669             return NULL;
1670         }
1671
1672         tmp2 = modinv(tmp, a->curve->p);
1673         freebn(tmp);
1674         if (!tmp2)
1675         {
1676             freebn(outx);
1677             return NULL;
1678         }
1679
1680         tmp = modmul(a->y, b->y, a->curve->p);
1681         if (!tmp)
1682         {
1683             freebn(tmp2);
1684             freebn(outx);
1685             return NULL;
1686         }
1687         tmp3 = modmul(a->x, b->x, a->curve->p);
1688         if (!tmp3)
1689         {
1690             freebn(tmp);
1691             freebn(tmp2);
1692             freebn(outx);
1693             return NULL;
1694         }
1695         tmp4 = ecf_add(tmp, tmp3, a->curve);
1696         freebn(tmp);
1697         freebn(tmp3);
1698         if (!tmp4)
1699         {
1700             freebn(tmp2);
1701             freebn(outx);
1702             return NULL;
1703         }
1704
1705         outy = modmul(tmp4, tmp2, a->curve->p);
1706         freebn(tmp4);
1707         freebn(tmp2);
1708         if (!outy)
1709         {
1710             freebn(outx);
1711             return NULL;
1712         }
1713     }
1714
1715     return ec_point_new(a->curve, outx, outy, NULL, 0);
1716 }
1717
1718 static struct ec_point *ecp_add(const struct ec_point *a,
1719                                 const struct ec_point *b,
1720                                 const int aminus3)
1721 {
1722     if (a->curve != b->curve) {
1723         return NULL;
1724     }
1725
1726     /* Check if multiplying by infinity */
1727     if (a->infinity) return ec_point_copy(b);
1728     if (b->infinity) return ec_point_copy(a);
1729
1730     if (a->curve->type == EC_EDWARDS)
1731     {
1732         return ecp_adde(a, b);
1733     }
1734
1735     if (a->curve->type == EC_WEIERSTRASS)
1736     {
1737         return ecp_addw(a, b, aminus3);
1738     }
1739
1740     return NULL;
1741 }
1742
1743 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1744 {
1745     struct ec_point *A, *ret;
1746     int bits, i;
1747
1748     A = ec_point_copy(a);
1749     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1750
1751     bits = bignum_bitcount(b);
1752     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1753     {
1754         if (bignum_bit(b, i))
1755         {
1756             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1757             ec_point_free(ret);
1758             ret = tmp;
1759         }
1760         if (i+1 != bits)
1761         {
1762             struct ec_point *tmp = ecp_double(A, aminus3);
1763             ec_point_free(A);
1764             A = tmp;
1765         }
1766     }
1767
1768     if (!A) {
1769         ec_point_free(ret);
1770         ret = NULL;
1771     } else {
1772         ec_point_free(A);
1773     }
1774
1775     return ret;
1776 }
1777
1778 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1779 {
1780     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1781
1782     if (!ecp_normalise(ret)) {
1783         ec_point_free(ret);
1784         return NULL;
1785     }
1786
1787     return ret;
1788 }
1789
1790 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1791 {
1792     int i;
1793     struct ec_point *ret;
1794
1795     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1796
1797     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1798     {
1799         {
1800             struct ec_point *tmp = ecp_double(ret, 0);
1801             ec_point_free(ret);
1802             ret = tmp;
1803         }
1804         if (ret && bignum_bit(b, i))
1805         {
1806             struct ec_point *tmp = ecp_add(ret, a, 0);
1807             ec_point_free(ret);
1808             ret = tmp;
1809         }
1810     }
1811
1812     return ret;
1813 }
1814
1815 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1816 {
1817     struct ec_point *P1, *P2;
1818     int bits, i;
1819
1820     /* P1 <- P and P2 <- [2]P */
1821     P2 = ecp_double(p, 0);
1822     if (!P2) {
1823         return NULL;
1824     }
1825     P1 = ec_point_copy(p);
1826     if (!P1) {
1827         ec_point_free(P2);
1828         return NULL;
1829     }
1830
1831     /* for i = bits âˆ’ 2 down to 0 */
1832     bits = bignum_bitcount(n);
1833     for (i = bits - 2; P1 != NULL && P2 != NULL && i >= 0; --i)
1834     {
1835         if (!bignum_bit(n, i))
1836         {
1837             /* P2 <- P1 + P2 */
1838             struct ec_point *tmp = ecp_addm(P1, P2, p);
1839             ec_point_free(P2);
1840             P2 = tmp;
1841
1842             /* P1 <- [2]P1 */
1843             tmp = ecp_double(P1, 0);
1844             ec_point_free(P1);
1845             P1 = tmp;
1846         }
1847         else
1848         {
1849             /* P1 <- P1 + P2 */
1850             struct ec_point *tmp = ecp_addm(P1, P2, p);
1851             ec_point_free(P1);
1852             P1 = tmp;
1853
1854             /* P2 <- [2]P2 */
1855             tmp = ecp_double(P2, 0);
1856             ec_point_free(P2);
1857             P2 = tmp;
1858         }
1859     }
1860
1861     if (!P2) {
1862         if (P1) ec_point_free(P1);
1863         P1 = NULL;
1864     } else {
1865         ec_point_free(P2);
1866     }
1867
1868     if (!ecp_normalise(P1)) {
1869         ec_point_free(P1);
1870         return NULL;
1871     }
1872
1873     return P1;
1874 }
1875
1876 /* Not static because it is used by sshecdsag.c to generate a new key */
1877 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1878 {
1879     if (a->curve->type == EC_WEIERSTRASS) {
1880         return ecp_mulw(a, b);
1881     } else if (a->curve->type == EC_EDWARDS) {
1882         return ecp_mule(a, b);
1883     } else {
1884         return ecp_mulm(a, b);
1885     }
1886 }
1887
1888 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1889                                    const struct ec_point *point)
1890 {
1891     struct ec_point *aG, *bP, *ret;
1892     int aminus3;
1893
1894     if (point->curve->type != EC_WEIERSTRASS) {
1895         return NULL;
1896     }
1897
1898     aminus3 = ec_aminus3(point->curve);
1899
1900     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1901     if (!aG) return NULL;
1902     bP = ecp_mul_(point, b, aminus3);
1903     if (!bP) {
1904         ec_point_free(aG);
1905         return NULL;
1906     }
1907
1908     ret = ecp_add(aG, bP, aminus3);
1909
1910     ec_point_free(aG);
1911     ec_point_free(bP);
1912
1913     if (!ecp_normalise(ret)) {
1914         ec_point_free(ret);
1915         return NULL;
1916     }
1917
1918     return ret;
1919 }
1920 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1921 {
1922     /* Get the x value on the given Edwards curve for a given y */
1923     Bignum x, xx;
1924
1925     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1926     {
1927         Bignum tmp, tmp2, tmp3;
1928
1929         tmp = ecf_square(y, curve);
1930         if (!tmp) {
1931             return NULL;
1932         }
1933         tmp2 = modmul(curve->e.d, tmp, curve->p);
1934         if (!tmp2) {
1935             freebn(tmp);
1936             return NULL;
1937         }
1938         tmp3 = ecf_add(tmp2, One, curve);
1939         freebn(tmp2);
1940         if (!tmp3) {
1941             freebn(tmp);
1942             return NULL;
1943         }
1944         tmp2 = modinv(tmp3, curve->p);
1945         freebn(tmp3);
1946         if (!tmp2) {
1947             freebn(tmp);
1948             return NULL;
1949         }
1950
1951         tmp3 = modsub(tmp, One, curve->p);
1952         freebn(tmp);
1953         if (!tmp3) {
1954             freebn(tmp2);
1955             return NULL;
1956         }
1957         xx = modmul(tmp3, tmp2, curve->p);
1958         freebn(tmp3);
1959         freebn(tmp2);
1960         if (!xx) {
1961             return NULL;
1962         }
1963     }
1964
1965     /* x = xx^((p + 3) / 8) */
1966     {
1967         Bignum tmp, tmp2;
1968
1969         tmp = bignum_add_long(curve->p, 3);
1970         if (!tmp) {
1971             freebn(xx);
1972             return NULL;
1973         }
1974         tmp2 = bignum_rshift(tmp, 3);
1975         freebn(tmp);
1976         if (!tmp2) {
1977             freebn(xx);
1978             return NULL;
1979         }
1980         x = modpow(xx, tmp2, curve->p);
1981         freebn(tmp2);
1982         if (!x) {
1983             freebn(xx);
1984             return NULL;
1985         }
1986     }
1987
1988     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1989     {
1990         Bignum tmp, tmp2;
1991
1992         tmp = ecf_square(x, curve);
1993         if (!tmp) {
1994             freebn(x);
1995             freebn(xx);
1996             return NULL;
1997         }
1998         tmp2 = modsub(tmp, xx, curve->p);
1999         freebn(tmp);
2000         freebn(xx);
2001         if (!tmp2) {
2002             freebn(x);
2003             return NULL;
2004         }
2005         if (bignum_cmp(tmp2, Zero)) {
2006             Bignum tmp3;
2007
2008             freebn(tmp2);
2009
2010             tmp = modsub(curve->p, One, curve->p);
2011             if (!tmp) {
2012                 freebn(x);
2013                 return NULL;
2014             }
2015             tmp2 = bignum_rshift(tmp, 2);
2016             freebn(tmp);
2017             if (!tmp2) {
2018                 freebn(x);
2019                 return NULL;
2020             }
2021             tmp = bignum_from_long(2);
2022             if (!tmp) {
2023                 freebn(tmp2);
2024                 freebn(x);
2025                 return NULL;
2026             }
2027             tmp3 = modpow(tmp, tmp2, curve->p);
2028             freebn(tmp);
2029             freebn(tmp2);
2030             if (!tmp3) {
2031                 freebn(x);
2032                 return NULL;
2033             }
2034
2035             tmp = modmul(x, tmp3, curve->p);
2036             freebn(x);
2037             freebn(tmp3);
2038             x = tmp;
2039             if (!tmp) {
2040                 return NULL;
2041             }
2042         } else {
2043             freebn(tmp2);
2044         }
2045     }
2046
2047     /* if x % 2 != 0 then x = p - x */
2048     if (bignum_bit(x, 0)) {
2049         Bignum tmp = modsub(curve->p, x, curve->p);
2050         freebn(x);
2051         x = tmp;
2052         if (!tmp) {
2053             return NULL;
2054         }
2055     }
2056
2057     return x;
2058 }
2059
2060 /* ----------------------------------------------------------------------
2061  * Public point from private
2062  */
2063
2064 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
2065 {
2066     if (curve->type == EC_WEIERSTRASS) {
2067         return ecp_mul(&curve->w.G, privateKey);
2068     } else if (curve->type == EC_EDWARDS) {
2069         /* hash = H(sk) (where hash creates 2 * fieldBits)
2070          * b = fieldBits
2071          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2072          * publicKey = aB */
2073         struct ec_point *ret;
2074         unsigned char hash[512/8];
2075         Bignum a;
2076         int i, keylen;
2077         SHA512_State s;
2078         SHA512_Init(&s);
2079
2080         keylen = curve->fieldBits / 8;
2081         for (i = 0; i < keylen; ++i) {
2082             unsigned char b = bignum_byte(privateKey, i);
2083             SHA512_Bytes(&s, &b, 1);
2084         }
2085         SHA512_Final(&s, hash);
2086
2087         /* The second part is simply turning the hash into a Bignum,
2088          * however the 2^(b-2) bit *must* be set, and the bottom 3
2089          * bits *must* not be */
2090         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2091         hash[31] &= 0x7f; /* Unset above (b-2) */
2092         hash[31] |= 0x40; /* Set 2^(b-2) */
2093         /* Chop off the top part and convert to int */
2094         a = bignum_from_bytes_le(hash, 32);
2095         if (!a) {
2096             return NULL;
2097         }
2098
2099         ret = ecp_mul(&curve->e.B, a);
2100         freebn(a);
2101         return ret;
2102     } else {
2103         return NULL;
2104     }
2105 }
2106
2107 /* ----------------------------------------------------------------------
2108  * Basic sign and verify routines
2109  */
2110
2111 static int _ecdsa_verify(const struct ec_point *publicKey,
2112                          const unsigned char *data, const int dataLen,
2113                          const Bignum r, const Bignum s)
2114 {
2115     int z_bits, n_bits;
2116     Bignum z;
2117     int valid = 0;
2118
2119     if (publicKey->curve->type != EC_WEIERSTRASS) {
2120         return 0;
2121     }
2122
2123     /* Sanity checks */
2124     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
2125         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
2126     {
2127         return 0;
2128     }
2129
2130     /* z = left most bitlen(curve->n) of data */
2131     z = bignum_from_bytes(data, dataLen);
2132     if (!z) return 0;
2133     n_bits = bignum_bitcount(publicKey->curve->w.n);
2134     z_bits = bignum_bitcount(z);
2135     if (z_bits > n_bits)
2136     {
2137         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
2138         freebn(z);
2139         z = tmp;
2140         if (!z) return 0;
2141     }
2142
2143     /* Ensure z in range of n */
2144     {
2145         Bignum tmp = bigmod(z, publicKey->curve->w.n);
2146         freebn(z);
2147         z = tmp;
2148         if (!z) return 0;
2149     }
2150
2151     /* Calculate signature */
2152     {
2153         Bignum w, x, u1, u2;
2154         struct ec_point *tmp;
2155
2156         w = modinv(s, publicKey->curve->w.n);
2157         if (!w) {
2158             freebn(z);
2159             return 0;
2160         }
2161         u1 = modmul(z, w, publicKey->curve->w.n);
2162         if (!u1) {
2163             freebn(z);
2164             freebn(w);
2165             return 0;
2166         }
2167         u2 = modmul(r, w, publicKey->curve->w.n);
2168         freebn(w);
2169         if (!u2) {
2170             freebn(z);
2171             freebn(u1);
2172             return 0;
2173         }
2174
2175         tmp = ecp_summul(u1, u2, publicKey);
2176         freebn(u1);
2177         freebn(u2);
2178         if (!tmp) {
2179             freebn(z);
2180             return 0;
2181         }
2182
2183         x = bigmod(tmp->x, publicKey->curve->w.n);
2184         ec_point_free(tmp);
2185         if (!x) {
2186             freebn(z);
2187             return 0;
2188         }
2189
2190         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
2191         freebn(x);
2192     }
2193
2194     freebn(z);
2195
2196     return valid;
2197 }
2198
2199 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
2200                         const unsigned char *data, const int dataLen,
2201                         Bignum *r, Bignum *s)
2202 {
2203     unsigned char digest[20];
2204     int z_bits, n_bits;
2205     Bignum z, k;
2206     struct ec_point *kG;
2207
2208     *r = NULL;
2209     *s = NULL;
2210
2211     if (curve->type != EC_WEIERSTRASS) {
2212         return;
2213     }
2214
2215     /* z = left most bitlen(curve->n) of data */
2216     z = bignum_from_bytes(data, dataLen);
2217     if (!z) return;
2218     n_bits = bignum_bitcount(curve->w.n);
2219     z_bits = bignum_bitcount(z);
2220     if (z_bits > n_bits)
2221     {
2222         Bignum tmp;
2223         tmp = bignum_rshift(z, z_bits - n_bits);
2224         freebn(z);
2225         z = tmp;
2226         if (!z) return;
2227     }
2228
2229     /* Generate k between 1 and curve->n, using the same deterministic
2230      * k generation system we use for conventional DSA. */
2231     SHA_Simple(data, dataLen, digest);
2232     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
2233                   digest, sizeof(digest));
2234     if (!k) return;
2235
2236     kG = ecp_mul(&curve->w.G, k);
2237     if (!kG) {
2238         freebn(z);
2239         freebn(k);
2240         return;
2241     }
2242
2243     /* r = kG.x mod n */
2244     *r = bigmod(kG->x, curve->w.n);
2245     ec_point_free(kG);
2246     if (!*r) {
2247         freebn(z);
2248         freebn(k);
2249         return;
2250     }
2251
2252     /* s = (z + r * priv)/k mod n */
2253     {
2254         Bignum rPriv, zMod, first, firstMod, kInv;
2255         rPriv = modmul(*r, privateKey, curve->w.n);
2256         if (!rPriv) {
2257             freebn(*r);
2258             freebn(z);
2259             freebn(k);
2260             return;
2261         }
2262         zMod = bigmod(z, curve->w.n);
2263         freebn(z);
2264         if (!zMod) {
2265             freebn(rPriv);
2266             freebn(*r);
2267             freebn(k);
2268             return;
2269         }
2270         first = bigadd(rPriv, zMod);
2271         freebn(rPriv);
2272         freebn(zMod);
2273         if (!first) {
2274             freebn(*r);
2275             freebn(k);
2276             return;
2277         }
2278         firstMod = bigmod(first, curve->w.n);
2279         freebn(first);
2280         if (!firstMod) {
2281             freebn(*r);
2282             freebn(k);
2283             return;
2284         }
2285         kInv = modinv(k, curve->w.n);
2286         freebn(k);
2287         if (!kInv) {
2288             freebn(firstMod);
2289             freebn(*r);
2290             return;
2291         }
2292         *s = modmul(firstMod, kInv, curve->w.n);
2293         freebn(firstMod);
2294         freebn(kInv);
2295         if (!*s) {
2296             freebn(*r);
2297             return;
2298         }
2299     }
2300 }
2301
2302 /* ----------------------------------------------------------------------
2303  * Misc functions
2304  */
2305
2306 static void getstring(const char **data, int *datalen,
2307                       const char **p, int *length)
2308 {
2309     *p = NULL;
2310     if (*datalen < 4)
2311         return;
2312     *length = toint(GET_32BIT(*data));
2313     if (*length < 0)
2314         return;
2315     *datalen -= 4;
2316     *data += 4;
2317     if (*datalen < *length)
2318         return;
2319     *p = *data;
2320     *data += *length;
2321     *datalen -= *length;
2322 }
2323
2324 static Bignum getmp(const char **data, int *datalen)
2325 {
2326     const char *p;
2327     int length;
2328
2329     getstring(data, datalen, &p, &length);
2330     if (!p)
2331         return NULL;
2332     if (p[0] & 0x80)
2333         return NULL;                   /* negative mp */
2334     return bignum_from_bytes((unsigned char *)p, length);
2335 }
2336
2337 static Bignum getmp_le(const char **data, int *datalen)
2338 {
2339     const char *p;
2340     int length;
2341
2342     getstring(data, datalen, &p, &length);
2343     if (!p)
2344         return NULL;
2345     return bignum_from_bytes_le((const unsigned char *)p, length);
2346 }
2347
2348 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
2349 {
2350     /* Got some conversion to do, first read in the y co-ord */
2351     int negative;
2352
2353     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
2354     if (!point->y) {
2355         return 0;
2356     }
2357     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
2358         freebn(point->y);
2359         point->y = NULL;
2360         return 0;
2361     }
2362     /* Read x bit and then reset it */
2363     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
2364     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
2365
2366     /* Get the x from the y */
2367     point->x = ecp_edx(point->curve, point->y);
2368     if (!point->x) {
2369         freebn(point->y);
2370         point->y = NULL;
2371         return 0;
2372     }
2373     if (negative) {
2374         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
2375         freebn(point->x);
2376         point->x = tmp;
2377         if (!tmp) {
2378             freebn(point->y);
2379             point->y = NULL;
2380             return 0;
2381         }
2382     }
2383
2384     /* Verify the point is on the curve */
2385     if (!ec_point_verify(point)) {
2386         freebn(point->x);
2387         point->x = NULL;
2388         freebn(point->y);
2389         point->y = NULL;
2390         return 0;
2391     }
2392
2393     return 1;
2394 }
2395
2396 static int decodepoint(const char *p, int length, struct ec_point *point)
2397 {
2398     if (point->curve->type == EC_EDWARDS) {
2399         return decodepoint_ed(p, length, point);
2400     }
2401
2402     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
2403         return 0;
2404     /* Skip compression flag */
2405     ++p;
2406     --length;
2407     /* The two values must be equal length */
2408     if (length % 2 != 0) {
2409         point->x = NULL;
2410         point->y = NULL;
2411         point->z = NULL;
2412         return 0;
2413     }
2414     length = length / 2;
2415     point->x = bignum_from_bytes((const unsigned char *)p, length);
2416     if (!point->x) return 0;
2417     p += length;
2418     point->y = bignum_from_bytes((const unsigned char *)p, length);
2419     if (!point->y) {
2420         freebn(point->x);
2421         point->x = NULL;
2422         return 0;
2423     }
2424     point->z = NULL;
2425
2426     /* Verify the point is on the curve */
2427     if (!ec_point_verify(point)) {
2428         freebn(point->x);
2429         point->x = NULL;
2430         freebn(point->y);
2431         point->y = NULL;
2432         return 0;
2433     }
2434
2435     return 1;
2436 }
2437
2438 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
2439 {
2440     const char *p;
2441     int length;
2442
2443     getstring(data, datalen, &p, &length);
2444     if (!p) return 0;
2445     return decodepoint(p, length, point);
2446 }
2447
2448 /* ----------------------------------------------------------------------
2449  * Exposed ECDSA interface
2450  */
2451
2452 struct ecsign_extra {
2453     struct ec_curve *(*curve)(void);
2454
2455     /* These fields are used by the OpenSSH PEM format importer/exporter */
2456     const unsigned char *oid;
2457     int oidlen;
2458 };
2459
2460 static void ecdsa_freekey(void *key)
2461 {
2462     struct ec_key *ec = (struct ec_key *) key;
2463     if (!ec) return;
2464
2465     if (ec->publicKey.x)
2466         freebn(ec->publicKey.x);
2467     if (ec->publicKey.y)
2468         freebn(ec->publicKey.y);
2469     if (ec->publicKey.z)
2470         freebn(ec->publicKey.z);
2471     if (ec->privateKey)
2472         freebn(ec->privateKey);
2473     sfree(ec);
2474 }
2475
2476 static void *ecdsa_newkey(const struct ssh_signkey *self,
2477                           const char *data, int len)
2478 {
2479     const struct ecsign_extra *extra =
2480         (const struct ecsign_extra *)self->extra;
2481     const char *p;
2482     int slen;
2483     struct ec_key *ec;
2484     struct ec_curve *curve;
2485
2486     getstring(&data, &len, &p, &slen);
2487
2488     if (!p) {
2489         return NULL;
2490     }
2491     curve = extra->curve();
2492     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
2493
2494     /* Curve name is duplicated for Weierstrass form */
2495     if (curve->type == EC_WEIERSTRASS) {
2496         getstring(&data, &len, &p, &slen);
2497         if (!match_ssh_id(slen, p, curve->name)) return NULL;
2498     }
2499
2500     ec = snew(struct ec_key);
2501
2502     ec->signalg = self;
2503     ec->publicKey.curve = curve;
2504     ec->publicKey.infinity = 0;
2505     ec->publicKey.x = NULL;
2506     ec->publicKey.y = NULL;
2507     ec->publicKey.z = NULL;
2508     if (!getmppoint(&data, &len, &ec->publicKey)) {
2509         ecdsa_freekey(ec);
2510         return NULL;
2511     }
2512     ec->privateKey = NULL;
2513
2514     if (!ec->publicKey.x || !ec->publicKey.y ||
2515         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2516         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2517     {
2518         ecdsa_freekey(ec);
2519         ec = NULL;
2520     }
2521
2522     return ec;
2523 }
2524
2525 static char *ecdsa_fmtkey(void *key)
2526 {
2527     struct ec_key *ec = (struct ec_key *) key;
2528     char *p;
2529     int len, i, pos, nibbles;
2530     static const char hex[] = "0123456789abcdef";
2531     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2532         return NULL;
2533
2534     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
2535     if (ec->publicKey.curve->name)
2536         len += strlen(ec->publicKey.curve->name); /* Curve name */
2537     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
2538     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
2539     p = snewn(len, char);
2540
2541     pos = 0;
2542     if (ec->publicKey.curve->name)
2543         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
2544     pos += sprintf(p + pos, "0x");
2545     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
2546     if (nibbles < 1)
2547         nibbles = 1;
2548     for (i = nibbles; i--;) {
2549         p[pos++] =
2550             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
2551     }
2552     pos += sprintf(p + pos, ",0x");
2553     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
2554     if (nibbles < 1)
2555         nibbles = 1;
2556     for (i = nibbles; i--;) {
2557         p[pos++] =
2558             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
2559     }
2560     p[pos] = '\0';
2561     return p;
2562 }
2563
2564 static unsigned char *ecdsa_public_blob(void *key, int *len)
2565 {
2566     struct ec_key *ec = (struct ec_key *) key;
2567     int pointlen, bloblen, fullnamelen, namelen;
2568     int i;
2569     unsigned char *blob, *p;
2570
2571     fullnamelen = strlen(ec->signalg->name);
2572
2573     if (ec->publicKey.curve->type == EC_EDWARDS) {
2574         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
2575
2576         pointlen = ec->publicKey.curve->fieldBits / 8;
2577
2578         /* Can't handle this in our loop */
2579         if (pointlen < 2) return NULL;
2580
2581         bloblen = 4 + fullnamelen + 4 + pointlen;
2582         blob = snewn(bloblen, unsigned char);
2583         if (!blob) return NULL;
2584
2585         p = blob;
2586         PUT_32BIT(p, fullnamelen);
2587         p += 4;
2588         memcpy(p, ec->signalg->name, fullnamelen);
2589         p += fullnamelen;
2590         PUT_32BIT(p, pointlen);
2591         p += 4;
2592
2593         /* Unset last bit of y and set first bit of x in its place */
2594         for (i = 0; i < pointlen - 1; ++i) {
2595             *p++ = bignum_byte(ec->publicKey.y, i);
2596         }
2597         /* Unset last bit of y and set first bit of x in its place */
2598         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
2599         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2600     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2601         assert(ec->publicKey.curve->name);
2602         namelen = strlen(ec->publicKey.curve->name);
2603
2604         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2605
2606         /*
2607          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
2608          */
2609         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
2610         blob = snewn(bloblen, unsigned char);
2611
2612         p = blob;
2613         PUT_32BIT(p, fullnamelen);
2614         p += 4;
2615         memcpy(p, ec->signalg->name, fullnamelen);
2616         p += fullnamelen;
2617         PUT_32BIT(p, namelen);
2618         p += 4;
2619         memcpy(p, ec->publicKey.curve->name, namelen);
2620         p += namelen;
2621         PUT_32BIT(p, (2 * pointlen) + 1);
2622         p += 4;
2623         *p++ = 0x04;
2624         for (i = pointlen; i--;) {
2625             *p++ = bignum_byte(ec->publicKey.x, i);
2626         }
2627         for (i = pointlen; i--;) {
2628             *p++ = bignum_byte(ec->publicKey.y, i);
2629         }
2630     } else {
2631         return NULL;
2632     }
2633
2634     assert(p == blob + bloblen);
2635     *len = bloblen;
2636
2637     return blob;
2638 }
2639
2640 static unsigned char *ecdsa_private_blob(void *key, int *len)
2641 {
2642     struct ec_key *ec = (struct ec_key *) key;
2643     int keylen, bloblen;
2644     int i;
2645     unsigned char *blob, *p;
2646
2647     if (!ec->privateKey) return NULL;
2648
2649     if (ec->publicKey.curve->type == EC_EDWARDS) {
2650         /* Unsigned */
2651         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2652     } else {
2653         /* Signed */
2654         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2655     }
2656
2657     /*
2658      * mpint privateKey. Total 4 + keylen.
2659      */
2660     bloblen = 4 + keylen;
2661     blob = snewn(bloblen, unsigned char);
2662
2663     p = blob;
2664     PUT_32BIT(p, keylen);
2665     p += 4;
2666     if (ec->publicKey.curve->type == EC_EDWARDS) {
2667         /* Little endian */
2668         for (i = 0; i < keylen; ++i)
2669             *p++ = bignum_byte(ec->privateKey, i);
2670     } else {
2671         for (i = keylen; i--;)
2672             *p++ = bignum_byte(ec->privateKey, i);
2673     }
2674
2675     assert(p == blob + bloblen);
2676     *len = bloblen;
2677     return blob;
2678 }
2679
2680 static void *ecdsa_createkey(const struct ssh_signkey *self,
2681                              const unsigned char *pub_blob, int pub_len,
2682                              const unsigned char *priv_blob, int priv_len)
2683 {
2684     struct ec_key *ec;
2685     struct ec_point *publicKey;
2686     const char *pb = (const char *) priv_blob;
2687
2688     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
2689     if (!ec) {
2690         return NULL;
2691     }
2692
2693     if (ec->publicKey.curve->type != EC_WEIERSTRASS
2694         && ec->publicKey.curve->type != EC_EDWARDS) {
2695         ecdsa_freekey(ec);
2696         return NULL;
2697     }
2698
2699     if (ec->publicKey.curve->type == EC_EDWARDS) {
2700         ec->privateKey = getmp_le(&pb, &priv_len);
2701     } else {
2702         ec->privateKey = getmp(&pb, &priv_len);
2703     }
2704     if (!ec->privateKey) {
2705         ecdsa_freekey(ec);
2706         return NULL;
2707     }
2708
2709     /* Check that private key generates public key */
2710     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2711
2712     if (!publicKey ||
2713         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2714         bignum_cmp(publicKey->y, ec->publicKey.y))
2715     {
2716         ecdsa_freekey(ec);
2717         ec = NULL;
2718     }
2719     ec_point_free(publicKey);
2720
2721     return ec;
2722 }
2723
2724 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2725                                        const unsigned char **blob, int *len)
2726 {
2727     struct ec_key *ec;
2728     struct ec_point *publicKey;
2729     const char *p, *q;
2730     int plen, qlen;
2731
2732     getstring((const char**)blob, len, &p, &plen);
2733     if (!p)
2734     {
2735         return NULL;
2736     }
2737
2738     ec = snew(struct ec_key);
2739     if (!ec)
2740     {
2741         return NULL;
2742     }
2743
2744     ec->signalg = self;
2745     ec->publicKey.curve = ec_ed25519();
2746     ec->publicKey.infinity = 0;
2747     ec->privateKey = NULL;
2748     ec->publicKey.x = NULL;
2749     ec->publicKey.z = NULL;
2750     ec->publicKey.y = NULL;
2751
2752     if (!decodepoint_ed(p, plen, &ec->publicKey))
2753     {
2754         ecdsa_freekey(ec);
2755         return NULL;
2756     }
2757
2758     getstring((const char**)blob, len, &q, &qlen);
2759     if (!q)
2760         return NULL;
2761     if (qlen != 64)
2762         return NULL;
2763
2764     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2765     if (!ec->privateKey) {
2766         ecdsa_freekey(ec);
2767         return NULL;
2768     }
2769
2770     /* Check that private key generates public key */
2771     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2772
2773     if (!publicKey ||
2774         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2775         bignum_cmp(publicKey->y, ec->publicKey.y))
2776     {
2777         ecdsa_freekey(ec);
2778         ec = NULL;
2779     }
2780     ec_point_free(publicKey);
2781
2782     /* The OpenSSH format for ed25519 private keys also for some
2783      * reason encodes an extra copy of the public key in the second
2784      * half of the secret-key string. Check that that's present and
2785      * correct as well, otherwise the key we think we've imported
2786      * won't behave identically to the way OpenSSH would have treated
2787      * it. */
2788     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2789         ecdsa_freekey(ec);
2790         return NULL;
2791     }
2792
2793     return ec;
2794 }
2795
2796 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2797 {
2798     struct ec_key *ec = (struct ec_key *) key;
2799
2800     int pointlen;
2801     int keylen;
2802     int bloblen;
2803     int i;
2804
2805     if (ec->publicKey.curve->type != EC_EDWARDS) {
2806         return 0;
2807     }
2808
2809     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2810     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2811     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2812
2813     if (bloblen > len)
2814         return bloblen;
2815
2816     /* Encode the public point */
2817     PUT_32BIT(blob, pointlen);
2818     blob += 4;
2819
2820     for (i = 0; i < pointlen - 1; ++i) {
2821          *blob++ = bignum_byte(ec->publicKey.y, i);
2822     }
2823     /* Unset last bit of y and set first bit of x in its place */
2824     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2825     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2826
2827     PUT_32BIT(blob, keylen + pointlen);
2828     blob += 4;
2829     for (i = 0; i < keylen; ++i) {
2830          *blob++ = bignum_byte(ec->privateKey, i);
2831     }
2832     /* Now encode an extra copy of the public point as the second half
2833      * of the private key string, as the OpenSSH format for some
2834      * reason requires */
2835     for (i = 0; i < pointlen - 1; ++i) {
2836          *blob++ = bignum_byte(ec->publicKey.y, i);
2837     }
2838     /* Unset last bit of y and set first bit of x in its place */
2839     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2840     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2841
2842     return bloblen;
2843 }
2844
2845 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2846                                      const unsigned char **blob, int *len)
2847 {
2848     const struct ecsign_extra *extra =
2849         (const struct ecsign_extra *)self->extra;
2850     const char **b = (const char **) blob;
2851     const char *p;
2852     int slen;
2853     struct ec_key *ec;
2854     struct ec_curve *curve;
2855     struct ec_point *publicKey;
2856
2857     getstring(b, len, &p, &slen);
2858
2859     if (!p) {
2860         return NULL;
2861     }
2862     curve = extra->curve();
2863     assert(curve->type == EC_WEIERSTRASS);
2864
2865     ec = snew(struct ec_key);
2866
2867     ec->signalg = self;
2868     ec->publicKey.curve = curve;
2869     ec->publicKey.infinity = 0;
2870     ec->publicKey.x = NULL;
2871     ec->publicKey.y = NULL;
2872     ec->publicKey.z = NULL;
2873     if (!getmppoint(b, len, &ec->publicKey)) {
2874         ecdsa_freekey(ec);
2875         return NULL;
2876     }
2877     ec->privateKey = NULL;
2878
2879     if (!ec->publicKey.x || !ec->publicKey.y ||
2880         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2881         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2882     {
2883         ecdsa_freekey(ec);
2884         return NULL;
2885     }
2886
2887     ec->privateKey = getmp(b, len);
2888     if (ec->privateKey == NULL)
2889     {
2890         ecdsa_freekey(ec);
2891         return NULL;
2892     }
2893
2894     /* Now check that the private key makes the public key */
2895     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2896     if (!publicKey)
2897     {
2898         ecdsa_freekey(ec);
2899         return NULL;
2900     }
2901
2902     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2903         bignum_cmp(ec->publicKey.y, publicKey->y))
2904     {
2905         /* Private key doesn't make the public key on the given curve */
2906         ecdsa_freekey(ec);
2907         ec_point_free(publicKey);
2908         return NULL;
2909     }
2910
2911     ec_point_free(publicKey);
2912
2913     return ec;
2914 }
2915
2916 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2917 {
2918     struct ec_key *ec = (struct ec_key *) key;
2919
2920     int pointlen;
2921     int namelen;
2922     int bloblen;
2923     int i;
2924
2925     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2926         return 0;
2927     }
2928
2929     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2930     namelen = strlen(ec->publicKey.curve->name);
2931     bloblen =
2932         4 + namelen /* <LEN> nistpXXX */
2933         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2934         + ssh2_bignum_length(ec->privateKey);
2935
2936     if (bloblen > len)
2937         return bloblen;
2938
2939     bloblen = 0;
2940
2941     PUT_32BIT(blob+bloblen, namelen);
2942     bloblen += 4;
2943     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2944     bloblen += namelen;
2945
2946     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2947     bloblen += 4;
2948     blob[bloblen++] = 0x04;
2949     for (i = pointlen; i--; )
2950         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2951     for (i = pointlen; i--; )
2952         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2953
2954     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2955     PUT_32BIT(blob+bloblen, pointlen);
2956     bloblen += 4;
2957     for (i = pointlen; i--; )
2958         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2959
2960     return bloblen;
2961 }
2962
2963 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2964                              const void *blob, int len)
2965 {
2966     struct ec_key *ec;
2967     int ret;
2968
2969     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2970     if (!ec)
2971         return -1;
2972     ret = ec->publicKey.curve->fieldBits;
2973     ecdsa_freekey(ec);
2974
2975     return ret;
2976 }
2977
2978 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2979                            const char *data, int datalen)
2980 {
2981     struct ec_key *ec = (struct ec_key *) key;
2982     const char *p;
2983     int slen;
2984     int digestLen;
2985     int ret;
2986
2987     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2988         return 0;
2989
2990     /* Check the signature starts with the algorithm name */
2991     getstring(&sig, &siglen, &p, &slen);
2992     if (!p) {
2993         return 0;
2994     }
2995     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2996         return 0;
2997     }
2998
2999     getstring(&sig, &siglen, &p, &slen);
3000     if (ec->publicKey.curve->type == EC_EDWARDS) {
3001         struct ec_point *r;
3002         Bignum s, h;
3003
3004         /* Check that the signature is two times the length of a point */
3005         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
3006             return 0;
3007         }
3008
3009         /* Check it's the 256 bit field so that SHA512 is the correct hash */
3010         if (ec->publicKey.curve->fieldBits != 256) {
3011             return 0;
3012         }
3013
3014         /* Get the signature */
3015         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
3016         if (!r) {
3017             return 0;
3018         }
3019         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
3020             ec_point_free(r);
3021             return 0;
3022         }
3023         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
3024                                  ec->publicKey.curve->fieldBits / 8);
3025         if (!s) {
3026             ec_point_free(r);
3027             return 0;
3028         }
3029
3030         /* Get the hash of the encoded value of R + encoded value of pk + message */
3031         {
3032             int i, pointlen;
3033             unsigned char b;
3034             unsigned char digest[512 / 8];
3035             SHA512_State hs;
3036             SHA512_Init(&hs);
3037
3038             /* Add encoded r (no need to encode it again, it was in the signature) */
3039             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
3040
3041             /* Encode pk and add it */
3042             pointlen = ec->publicKey.curve->fieldBits / 8;
3043             for (i = 0; i < pointlen - 1; ++i) {
3044                 b = bignum_byte(ec->publicKey.y, i);
3045                 SHA512_Bytes(&hs, &b, 1);
3046             }
3047             /* Unset last bit of y and set first bit of x in its place */
3048             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3049             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3050             SHA512_Bytes(&hs, &b, 1);
3051
3052             /* Add the message itself */
3053             SHA512_Bytes(&hs, data, datalen);
3054
3055             /* Get the hash */
3056             SHA512_Final(&hs, digest);
3057
3058             /* Convert to Bignum */
3059             h = bignum_from_bytes_le(digest, sizeof(digest));
3060             if (!h) {
3061                 ec_point_free(r);
3062                 freebn(s);
3063                 return 0;
3064             }
3065         }
3066
3067         /* Verify sB == r + h*publicKey */
3068         {
3069             struct ec_point *lhs, *rhs, *tmp;
3070
3071             /* lhs = sB */
3072             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
3073             freebn(s);
3074             if (!lhs) {
3075                 ec_point_free(r);
3076                 freebn(h);
3077                 return 0;
3078             }
3079
3080             /* rhs = r + h*publicKey */
3081             tmp = ecp_mul(&ec->publicKey, h);
3082             freebn(h);
3083             if (!tmp) {
3084                 ec_point_free(lhs);
3085                 ec_point_free(r);
3086                 return 0;
3087             }
3088             rhs = ecp_add(r, tmp, 0);
3089             ec_point_free(r);
3090             ec_point_free(tmp);
3091             if (!rhs) {
3092                 ec_point_free(lhs);
3093                 return 0;
3094             }
3095
3096             /* Check the point is the same */
3097             ret = !bignum_cmp(lhs->x, rhs->x);
3098             if (ret) {
3099                 ret = !bignum_cmp(lhs->y, rhs->y);
3100                 if (ret) {
3101                     ret = 1;
3102                 }
3103             }
3104             ec_point_free(lhs);
3105             ec_point_free(rhs);
3106         }
3107     } else {
3108         Bignum r, s;
3109         unsigned char digest[512 / 8];
3110
3111         r = getmp(&p, &slen);
3112         if (!r) return 0;
3113         s = getmp(&p, &slen);
3114         if (!s) {
3115             freebn(r);
3116             return 0;
3117         }
3118
3119         /* Perform correct hash function depending on curve size */
3120         if (ec->publicKey.curve->fieldBits <= 256) {
3121             SHA256_Simple(data, datalen, digest);
3122             digestLen = 256 / 8;
3123         } else if (ec->publicKey.curve->fieldBits <= 384) {
3124             SHA384_Simple(data, datalen, digest);
3125             digestLen = 384 / 8;
3126         } else {
3127             SHA512_Simple(data, datalen, digest);
3128             digestLen = 512 / 8;
3129         }
3130
3131         /* Verify the signature */
3132         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
3133
3134         freebn(r);
3135         freebn(s);
3136     }
3137
3138     return ret;
3139 }
3140
3141 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
3142                                  int *siglen)
3143 {
3144     struct ec_key *ec = (struct ec_key *) key;
3145     unsigned char digest[512 / 8];
3146     int digestLen;
3147     Bignum r = NULL, s = NULL;
3148     unsigned char *buf, *p;
3149     int rlen, slen, namelen;
3150     int i;
3151
3152     if (!ec->privateKey || !ec->publicKey.curve) {
3153         return NULL;
3154     }
3155
3156     if (ec->publicKey.curve->type == EC_EDWARDS) {
3157         struct ec_point *rp;
3158         int pointlen = ec->publicKey.curve->fieldBits / 8;
3159
3160         /* hash = H(sk) (where hash creates 2 * fieldBits)
3161          * b = fieldBits
3162          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
3163          * r = H(h[b/8:b/4] + m)
3164          * R = rB
3165          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
3166         {
3167             unsigned char hash[512/8];
3168             unsigned char b;
3169             Bignum a;
3170             SHA512_State hs;
3171             SHA512_Init(&hs);
3172
3173             for (i = 0; i < pointlen; ++i) {
3174                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
3175                 SHA512_Bytes(&hs, &b, 1);
3176             }
3177
3178             SHA512_Final(&hs, hash);
3179
3180             /* The second part is simply turning the hash into a
3181              * Bignum, however the 2^(b-2) bit *must* be set, and the
3182              * bottom 3 bits *must* not be */
3183             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
3184             hash[31] &= 0x7f; /* Unset above (b-2) */
3185             hash[31] |= 0x40; /* Set 2^(b-2) */
3186             /* Chop off the top part and convert to int */
3187             a = bignum_from_bytes_le(hash, 32);
3188             if (!a) {
3189                 return NULL;
3190             }
3191
3192             SHA512_Init(&hs);
3193             SHA512_Bytes(&hs,
3194                          hash+(ec->publicKey.curve->fieldBits / 8),
3195                          (ec->publicKey.curve->fieldBits / 4)
3196                          - (ec->publicKey.curve->fieldBits / 8));
3197             SHA512_Bytes(&hs, data, datalen);
3198             SHA512_Final(&hs, hash);
3199
3200             r = bignum_from_bytes_le(hash, 512/8);
3201             if (!r) {
3202                 freebn(a);
3203                 return NULL;
3204             }
3205             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
3206             if (!rp) {
3207                 freebn(r);
3208                 freebn(a);
3209                 return NULL;
3210             }
3211
3212             /* Now calculate s */
3213             SHA512_Init(&hs);
3214             /* Encode the point R */
3215             for (i = 0; i < pointlen - 1; ++i) {
3216                 b = bignum_byte(rp->y, i);
3217                 SHA512_Bytes(&hs, &b, 1);
3218             }
3219             /* Unset last bit of y and set first bit of x in its place */
3220             b = bignum_byte(rp->y, i) & 0x7f;
3221             b |= bignum_bit(rp->x, 0) << 7;
3222             SHA512_Bytes(&hs, &b, 1);
3223
3224             /* Encode the point pk */
3225             for (i = 0; i < pointlen - 1; ++i) {
3226                 b = bignum_byte(ec->publicKey.y, i);
3227                 SHA512_Bytes(&hs, &b, 1);
3228             }
3229             /* Unset last bit of y and set first bit of x in its place */
3230             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3231             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3232             SHA512_Bytes(&hs, &b, 1);
3233
3234             /* Add the message */
3235             SHA512_Bytes(&hs, data, datalen);
3236             SHA512_Final(&hs, hash);
3237
3238             {
3239                 Bignum tmp, tmp2;
3240
3241                 tmp = bignum_from_bytes_le(hash, 512/8);
3242                 if (!tmp) {
3243                     ec_point_free(rp);
3244                     freebn(r);
3245                     freebn(a);
3246                     return NULL;
3247                 }
3248                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
3249                 freebn(a);
3250                 freebn(tmp);
3251                 if (!tmp2) {
3252                     ec_point_free(rp);
3253                     freebn(r);
3254                     return NULL;
3255                 }
3256                 tmp = bigadd(r, tmp2);
3257                 freebn(r);
3258                 freebn(tmp2);
3259                 if (!tmp) {
3260                     ec_point_free(rp);
3261                     return NULL;
3262                 }
3263                 s = bigmod(tmp, ec->publicKey.curve->e.l);
3264                 freebn(tmp);
3265                 if (!s) {
3266                     ec_point_free(rp);
3267                     return NULL;
3268                 }
3269             }
3270         }
3271
3272         /* Format the output */
3273         namelen = strlen(ec->signalg->name);
3274         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
3275         buf = snewn(*siglen, unsigned char);
3276         p = buf;
3277         PUT_32BIT(p, namelen);
3278         p += 4;
3279         memcpy(p, ec->signalg->name, namelen);
3280         p += namelen;
3281         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
3282         p += 4;
3283
3284         /* Encode the point */
3285         pointlen = ec->publicKey.curve->fieldBits / 8;
3286         for (i = 0; i < pointlen - 1; ++i) {
3287             *p++ = bignum_byte(rp->y, i);
3288         }
3289         /* Unset last bit of y and set first bit of x in its place */
3290         *p = bignum_byte(rp->y, i) & 0x7f;
3291         *p++ |= bignum_bit(rp->x, 0) << 7;
3292         ec_point_free(rp);
3293
3294         /* Encode the int */
3295         for (i = 0; i < pointlen; ++i) {
3296             *p++ = bignum_byte(s, i);
3297         }
3298         freebn(s);
3299     } else {
3300         /* Perform correct hash function depending on curve size */
3301         if (ec->publicKey.curve->fieldBits <= 256) {
3302             SHA256_Simple(data, datalen, digest);
3303             digestLen = 256 / 8;
3304         } else if (ec->publicKey.curve->fieldBits <= 384) {
3305             SHA384_Simple(data, datalen, digest);
3306             digestLen = 384 / 8;
3307         } else {
3308             SHA512_Simple(data, datalen, digest);
3309             digestLen = 512 / 8;
3310         }
3311
3312         /* Do the signature */
3313         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
3314         if (!r || !s) {
3315             if (r) freebn(r);
3316             if (s) freebn(s);
3317             return NULL;
3318         }
3319
3320         rlen = (bignum_bitcount(r) + 8) / 8;
3321         slen = (bignum_bitcount(s) + 8) / 8;
3322
3323         namelen = strlen(ec->signalg->name);
3324
3325         /* Format the output */
3326         *siglen = 8+namelen+rlen+slen+8;
3327         buf = snewn(*siglen, unsigned char);
3328         p = buf;
3329         PUT_32BIT(p, namelen);
3330         p += 4;
3331         memcpy(p, ec->signalg->name, namelen);
3332         p += namelen;
3333         PUT_32BIT(p, rlen + slen + 8);
3334         p += 4;
3335         PUT_32BIT(p, rlen);
3336         p += 4;
3337         for (i = rlen; i--;)
3338             *p++ = bignum_byte(r, i);
3339         PUT_32BIT(p, slen);
3340         p += 4;
3341         for (i = slen; i--;)
3342             *p++ = bignum_byte(s, i);
3343
3344         freebn(r);
3345         freebn(s);
3346     }
3347
3348     return buf;
3349 }
3350
3351 const struct ecsign_extra sign_extra_ed25519 = {
3352     ec_ed25519,
3353     NULL, 0,
3354 };
3355 const struct ssh_signkey ssh_ecdsa_ed25519 = {
3356     ecdsa_newkey,
3357     ecdsa_freekey,
3358     ecdsa_fmtkey,
3359     ecdsa_public_blob,
3360     ecdsa_private_blob,
3361     ecdsa_createkey,
3362     ed25519_openssh_createkey,
3363     ed25519_openssh_fmtkey,
3364     2 /* point, private exponent */,
3365     ecdsa_pubkey_bits,
3366     ecdsa_verifysig,
3367     ecdsa_sign,
3368     "ssh-ed25519",
3369     "ssh-ed25519",
3370     &sign_extra_ed25519,
3371 };
3372
3373 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
3374 static const unsigned char nistp256_oid[] = {
3375     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
3376 };
3377 const struct ecsign_extra sign_extra_nistp256 = {
3378     ec_p256,
3379     nistp256_oid, lenof(nistp256_oid),
3380 };
3381 const struct ssh_signkey ssh_ecdsa_nistp256 = {
3382     ecdsa_newkey,
3383     ecdsa_freekey,
3384     ecdsa_fmtkey,
3385     ecdsa_public_blob,
3386     ecdsa_private_blob,
3387     ecdsa_createkey,
3388     ecdsa_openssh_createkey,
3389     ecdsa_openssh_fmtkey,
3390     3 /* curve name, point, private exponent */,
3391     ecdsa_pubkey_bits,
3392     ecdsa_verifysig,
3393     ecdsa_sign,
3394     "ecdsa-sha2-nistp256",
3395     "ecdsa-sha2-nistp256",
3396     &sign_extra_nistp256,
3397 };
3398
3399 /* OID: 1.3.132.0.34 (secp384r1) */
3400 static const unsigned char nistp384_oid[] = {
3401     0x2b, 0x81, 0x04, 0x00, 0x22
3402 };
3403 const struct ecsign_extra sign_extra_nistp384 = {
3404     ec_p384,
3405     nistp384_oid, lenof(nistp384_oid),
3406 };
3407 const struct ssh_signkey ssh_ecdsa_nistp384 = {
3408     ecdsa_newkey,
3409     ecdsa_freekey,
3410     ecdsa_fmtkey,
3411     ecdsa_public_blob,
3412     ecdsa_private_blob,
3413     ecdsa_createkey,
3414     ecdsa_openssh_createkey,
3415     ecdsa_openssh_fmtkey,
3416     3 /* curve name, point, private exponent */,
3417     ecdsa_pubkey_bits,
3418     ecdsa_verifysig,
3419     ecdsa_sign,
3420     "ecdsa-sha2-nistp384",
3421     "ecdsa-sha2-nistp384",
3422     &sign_extra_nistp384,
3423 };
3424
3425 /* OID: 1.3.132.0.35 (secp521r1) */
3426 static const unsigned char nistp521_oid[] = {
3427     0x2b, 0x81, 0x04, 0x00, 0x23
3428 };
3429 const struct ecsign_extra sign_extra_nistp521 = {
3430     ec_p521,
3431     nistp521_oid, lenof(nistp521_oid),
3432 };
3433 const struct ssh_signkey ssh_ecdsa_nistp521 = {
3434     ecdsa_newkey,
3435     ecdsa_freekey,
3436     ecdsa_fmtkey,
3437     ecdsa_public_blob,
3438     ecdsa_private_blob,
3439     ecdsa_createkey,
3440     ecdsa_openssh_createkey,
3441     ecdsa_openssh_fmtkey,
3442     3 /* curve name, point, private exponent */,
3443     ecdsa_pubkey_bits,
3444     ecdsa_verifysig,
3445     ecdsa_sign,
3446     "ecdsa-sha2-nistp521",
3447     "ecdsa-sha2-nistp521",
3448     &sign_extra_nistp521,
3449 };
3450
3451 /* ----------------------------------------------------------------------
3452  * Exposed ECDH interface
3453  */
3454
3455 struct eckex_extra {
3456     struct ec_curve *(*curve)(void);
3457 };
3458
3459 static Bignum ecdh_calculate(const Bignum private,
3460                              const struct ec_point *public)
3461 {
3462     struct ec_point *p;
3463     Bignum ret;
3464     p = ecp_mul(public, private);
3465     if (!p) return NULL;
3466     ret = p->x;
3467     p->x = NULL;
3468
3469     if (p->curve->type == EC_MONTGOMERY) {
3470         /* Do conversion in network byte order */
3471         int i;
3472         int bytes = (bignum_bitcount(ret)+7) / 8;
3473         unsigned char *byteorder = snewn(bytes, unsigned char);
3474         if (!byteorder) {
3475             ec_point_free(p);
3476             freebn(ret);
3477             return NULL;
3478         }
3479         for (i = 0; i < bytes; ++i) {
3480             byteorder[i] = bignum_byte(ret, i);
3481         }
3482         freebn(ret);
3483         ret = bignum_from_bytes(byteorder, bytes);
3484         sfree(byteorder);
3485     }
3486
3487     ec_point_free(p);
3488     return ret;
3489 }
3490
3491 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
3492 {
3493     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
3494     struct ec_curve *curve;
3495     struct ec_key *key;
3496     struct ec_point *publicKey;
3497
3498     curve = extra->curve();
3499
3500     key = snew(struct ec_key);
3501     if (!key) {
3502         return NULL;
3503     }
3504
3505     key->signalg = NULL;
3506     key->publicKey.curve = curve;
3507
3508     if (curve->type == EC_MONTGOMERY) {
3509         unsigned char bytes[32] = {0};
3510         int i;
3511
3512         for (i = 0; i < sizeof(bytes); ++i)
3513         {
3514             bytes[i] = (unsigned char)random_byte();
3515         }
3516         bytes[0] &= 248;
3517         bytes[31] &= 127;
3518         bytes[31] |= 64;
3519         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
3520         for (i = 0; i < sizeof(bytes); ++i)
3521         {
3522             ((volatile char*)bytes)[i] = 0;
3523         }
3524         if (!key->privateKey) {
3525             sfree(key);
3526             return NULL;
3527         }
3528         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
3529         if (!publicKey) {
3530             freebn(key->privateKey);
3531             sfree(key);
3532             return NULL;
3533         }
3534         key->publicKey.x = publicKey->x;
3535         key->publicKey.y = publicKey->y;
3536         key->publicKey.z = NULL;
3537         sfree(publicKey);
3538     } else {
3539         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
3540         if (!key->privateKey) {
3541             sfree(key);
3542             return NULL;
3543         }
3544         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
3545         if (!publicKey) {
3546             freebn(key->privateKey);
3547             sfree(key);
3548             return NULL;
3549         }
3550         key->publicKey.x = publicKey->x;
3551         key->publicKey.y = publicKey->y;
3552         key->publicKey.z = NULL;
3553         sfree(publicKey);
3554     }
3555     return key;
3556 }
3557
3558 char *ssh_ecdhkex_getpublic(void *key, int *len)
3559 {
3560     struct ec_key *ec = (struct ec_key*)key;
3561     char *point, *p;
3562     int i;
3563     int pointlen;
3564
3565     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3566
3567     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3568         *len = 1 + pointlen * 2;
3569     } else {
3570         *len = pointlen;
3571     }
3572     point = (char*)snewn(*len, char);
3573     if (!point) {
3574         return NULL;
3575     }
3576
3577     p = point;
3578     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3579         *p++ = 0x04;
3580         for (i = pointlen; i--;) {
3581             *p++ = bignum_byte(ec->publicKey.x, i);
3582         }
3583         for (i = pointlen; i--;) {
3584             *p++ = bignum_byte(ec->publicKey.y, i);
3585         }
3586     } else {
3587         for (i = 0; i < pointlen; ++i) {
3588             *p++ = bignum_byte(ec->publicKey.x, i);
3589         }
3590     }
3591
3592     return point;
3593 }
3594
3595 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
3596 {
3597     struct ec_key *ec = (struct ec_key*) key;
3598     struct ec_point remote;
3599     Bignum ret;
3600
3601     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3602         remote.curve = ec->publicKey.curve;
3603         remote.infinity = 0;
3604         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
3605             return NULL;
3606         }
3607     } else {
3608         /* Point length has to be the same length */
3609         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
3610             return NULL;
3611         }
3612
3613         remote.curve = ec->publicKey.curve;
3614         remote.infinity = 0;
3615         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
3616         remote.y = NULL;
3617         remote.z = NULL;
3618     }
3619
3620     ret = ecdh_calculate(ec->privateKey, &remote);
3621     if (remote.x) freebn(remote.x);
3622     if (remote.y) freebn(remote.y);
3623     return ret;
3624 }
3625
3626 void ssh_ecdhkex_freekey(void *key)
3627 {
3628     ecdsa_freekey(key);
3629 }
3630
3631 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
3632 static const struct ssh_kex ssh_ec_kex_curve25519 = {
3633     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
3634     &ssh_sha256, &kex_extra_curve25519,
3635 };
3636
3637 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
3638 static const struct ssh_kex ssh_ec_kex_nistp256 = {
3639     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
3640     &ssh_sha256, &kex_extra_nistp256,
3641 };
3642
3643 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
3644 static const struct ssh_kex ssh_ec_kex_nistp384 = {
3645     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
3646     &ssh_sha384, &kex_extra_nistp384,
3647 };
3648
3649 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
3650 static const struct ssh_kex ssh_ec_kex_nistp521 = {
3651     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
3652     &ssh_sha512, &kex_extra_nistp521,
3653 };
3654
3655 static const struct ssh_kex *const ec_kex_list[] = {
3656     &ssh_ec_kex_curve25519,
3657     &ssh_ec_kex_nistp256,
3658     &ssh_ec_kex_nistp384,
3659     &ssh_ec_kex_nistp521,
3660 };
3661
3662 const struct ssh_kexes ssh_ecdh_kex = {
3663     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
3664     ec_kex_list
3665 };
3666
3667 /* ----------------------------------------------------------------------
3668  * Helper functions for finding key algorithms and returning auxiliary
3669  * data.
3670  */
3671
3672 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
3673                                         const struct ec_curve **curve)
3674 {
3675     static const struct ssh_signkey *algs_with_oid[] = {
3676         &ssh_ecdsa_nistp256,
3677         &ssh_ecdsa_nistp384,
3678         &ssh_ecdsa_nistp521,
3679     };
3680     int i;
3681
3682     for (i = 0; i < lenof(algs_with_oid); i++) {
3683         const struct ssh_signkey *alg = algs_with_oid[i];
3684         const struct ecsign_extra *extra =
3685             (const struct ecsign_extra *)alg->extra;
3686         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
3687             *curve = extra->curve();
3688             return alg;
3689         }
3690     }
3691     return NULL;
3692 }
3693
3694 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
3695                                 int *oidlen)
3696 {
3697     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
3698     *oidlen = extra->oidlen;
3699     return extra->oid;
3700 }
3701
3702 const int ec_nist_alg_and_curve_by_bits(int bits,
3703                                         const struct ec_curve **curve,
3704                                         const struct ssh_signkey **alg)
3705 {
3706     switch (bits) {
3707       case 256: *alg = &ssh_ecdsa_nistp256; break;
3708       case 384: *alg = &ssh_ecdsa_nistp384; break;
3709       case 521: *alg = &ssh_ecdsa_nistp521; break;
3710       default: return FALSE;
3711     }
3712     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3713     return TRUE;
3714 }
3715
3716 const int ec_ed_alg_and_curve_by_bits(int bits,
3717                                       const struct ec_curve **curve,
3718                                       const struct ssh_signkey **alg)
3719 {
3720     switch (bits) {
3721       case 256: *alg = &ssh_ecdsa_ed25519; break;
3722       default: return FALSE;
3723     }
3724     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3725     return TRUE;
3726 }