]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Clean up hash selection in ECDSA.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static int initialise_wcurve(struct ec_curve *curve, int bits, unsigned char *p,
41                              unsigned char *a, unsigned char *b,
42                              unsigned char *n, unsigned char *Gx,
43                              unsigned char *Gy)
44 {
45     int length = bits / 8;
46     if (bits % 8) ++length;
47
48     curve->type = EC_WEIERSTRASS;
49
50     curve->fieldBits = bits;
51     curve->p = bignum_from_bytes(p, length);
52     if (!curve->p) goto error;
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     if (!curve->w.a) goto error;
57     curve->w.b = bignum_from_bytes(b, length);
58     if (!curve->w.b) goto error;
59
60     /* Group order and generator */
61     curve->w.n = bignum_from_bytes(n, length);
62     if (!curve->w.n) goto error;
63     curve->w.G.x = bignum_from_bytes(Gx, length);
64     if (!curve->w.G.x) goto error;
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     if (!curve->w.G.y) goto error;
67     curve->w.G.curve = curve;
68     curve->w.G.infinity = 0;
69
70     return 1;
71   error:
72     if (curve->p) freebn(curve->p);
73     if (curve->w.a) freebn(curve->w.a);
74     if (curve->w.b) freebn(curve->w.b);
75     if (curve->w.n) freebn(curve->w.n);
76     if (curve->w.G.x) freebn(curve->w.G.x);
77     return 0;
78 }
79
80 static int initialise_mcurve(struct ec_curve *curve, int bits, unsigned char *p,
81                              unsigned char *a, unsigned char *b,
82                              unsigned char *Gx)
83 {
84     int length = bits / 8;
85     if (bits % 8) ++length;
86
87     curve->type = EC_MONTGOMERY;
88
89     curve->fieldBits = bits;
90     curve->p = bignum_from_bytes(p, length);
91     if (!curve->p) goto error;
92
93     /* Curve co-efficients */
94     curve->m.a = bignum_from_bytes(a, length);
95     if (!curve->m.a) goto error;
96     curve->m.b = bignum_from_bytes(b, length);
97     if (!curve->m.b) goto error;
98
99     /* Generator */
100     curve->m.G.x = bignum_from_bytes(Gx, length);
101     if (!curve->m.G.x) goto error;
102     curve->m.G.y = NULL;
103     curve->m.G.z = NULL;
104     curve->m.G.curve = curve;
105     curve->m.G.infinity = 0;
106
107     return 1;
108   error:
109     if (curve->p) freebn(curve->p);
110     if (curve->m.a) freebn(curve->m.a);
111     if (curve->m.b) freebn(curve->m.b);
112     return 0;
113 }
114
115 static int initialise_ecurve(struct ec_curve *curve, int bits, unsigned char *p,
116                              unsigned char *l, unsigned char *d,
117                              unsigned char *Bx, unsigned char *By)
118 {
119     int length = bits / 8;
120     if (bits % 8) ++length;
121
122     curve->type = EC_EDWARDS;
123
124     curve->fieldBits = bits;
125     curve->p = bignum_from_bytes(p, length);
126     if (!curve->p) goto error;
127
128     /* Curve co-efficients */
129     curve->e.l = bignum_from_bytes(l, length);
130     if (!curve->e.l) goto error;
131     curve->e.d = bignum_from_bytes(d, length);
132     if (!curve->e.d) goto error;
133
134     /* Group order and generator */
135     curve->e.B.x = bignum_from_bytes(Bx, length);
136     if (!curve->e.B.x) goto error;
137     curve->e.B.y = bignum_from_bytes(By, length);
138     if (!curve->e.B.y) goto error;
139     curve->e.B.curve = curve;
140     curve->e.B.infinity = 0;
141
142     return 1;
143   error:
144     if (curve->p) freebn(curve->p);
145     if (curve->e.l) freebn(curve->e.l);
146     if (curve->e.d) freebn(curve->e.d);
147     if (curve->e.B.x) freebn(curve->e.B.x);
148     return 0;
149 }
150
151 static struct ec_curve *ec_p256(void)
152 {
153     static struct ec_curve curve = { 0 };
154     static unsigned char initialised = 0;
155
156     if (!initialised)
157     {
158         unsigned char p[] = {
159             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
160             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
161             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
162             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
163         };
164         unsigned char a[] = {
165             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
166             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
167             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
168             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
169         };
170         unsigned char b[] = {
171             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
172             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
173             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
174             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
175         };
176         unsigned char n[] = {
177             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
178             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
179             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
180             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
181         };
182         unsigned char Gx[] = {
183             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
184             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
185             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
186             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
187         };
188         unsigned char Gy[] = {
189             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
190             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
191             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
192             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
193         };
194
195         if (!initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy)) {
196             return NULL;
197         }
198
199         curve.name = "nistp256";
200
201         /* Now initialised, no need to do it again */
202         initialised = 1;
203     }
204
205     return &curve;
206 }
207
208 static struct ec_curve *ec_p384(void)
209 {
210     static struct ec_curve curve = { 0 };
211     static unsigned char initialised = 0;
212
213     if (!initialised)
214     {
215         unsigned char p[] = {
216             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
217             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
218             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
219             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
220             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
221             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
222         };
223         unsigned char a[] = {
224             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
225             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
226             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
227             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
228             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
229             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
230         };
231         unsigned char b[] = {
232             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
233             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
234             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
235             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
236             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
237             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
238         };
239         unsigned char n[] = {
240             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
241             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
242             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
243             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
244             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
245             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
246         };
247         unsigned char Gx[] = {
248             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
249             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
250             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
251             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
252             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
253             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
254         };
255         unsigned char Gy[] = {
256             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
257             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
258             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
259             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
260             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
261             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
262         };
263
264         if (!initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy)) {
265             return NULL;
266         }
267
268         curve.name = "nistp384";
269
270         /* Now initialised, no need to do it again */
271         initialised = 1;
272     }
273
274     return &curve;
275 }
276
277 static struct ec_curve *ec_p521(void)
278 {
279     static struct ec_curve curve = { 0 };
280     static unsigned char initialised = 0;
281
282     if (!initialised)
283     {
284         unsigned char p[] = {
285             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
286             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
287             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
288             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
289             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
290             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
291             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
292             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
293             0xff, 0xff
294         };
295         unsigned char a[] = {
296             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
297             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
298             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
299             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
300             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
301             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
302             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
303             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
304             0xff, 0xfc
305         };
306         unsigned char b[] = {
307             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
308             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
309             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
310             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
311             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
312             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
313             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
314             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
315             0x3f, 0x00
316         };
317         unsigned char n[] = {
318             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
319             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
320             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
321             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
322             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
323             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
324             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
325             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
326             0x64, 0x09
327         };
328         unsigned char Gx[] = {
329             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
330             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
331             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
332             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
333             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
334             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
335             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
336             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
337             0xbd, 0x66
338         };
339         unsigned char Gy[] = {
340             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
341             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
342             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
343             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
344             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
345             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
346             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
347             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
348             0x66, 0x50
349         };
350
351         if (!initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy)) {
352             return NULL;
353         }
354
355         curve.name = "nistp521";
356
357         /* Now initialised, no need to do it again */
358         initialised = 1;
359     }
360
361     return &curve;
362 }
363
364 static struct ec_curve *ec_curve25519(void)
365 {
366     static struct ec_curve curve = { 0 };
367     static unsigned char initialised = 0;
368
369     if (!initialised)
370     {
371         unsigned char p[] = {
372             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
373             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
374             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
375             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
376         };
377         unsigned char a[] = {
378             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
379             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
380             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
381             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
382         };
383         unsigned char b[] = {
384             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
385             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
386             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
387             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
388         };
389         unsigned char gx[32] = {
390             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
391             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
392             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
393             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
394         };
395
396         if (!initialise_mcurve(&curve, 256, p, a, b, gx)) {
397             return NULL;
398         }
399
400         /* This curve doesn't need a name, because it's never used in
401          * any format that embeds the curve name */
402         curve.name = NULL;
403
404         /* Now initialised, no need to do it again */
405         initialised = 1;
406     }
407
408     return &curve;
409 }
410
411 static struct ec_curve *ec_ed25519(void)
412 {
413     static struct ec_curve curve = { 0 };
414     static unsigned char initialised = 0;
415
416     if (!initialised)
417     {
418         unsigned char q[] = {
419             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
420             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
421             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
422             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
423         };
424         unsigned char l[32] = {
425             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
426             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
427             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
428             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
429         };
430         unsigned char d[32] = {
431             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
432             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
433             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
434             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
435         };
436         unsigned char Bx[32] = {
437             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
438             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
439             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
440             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
441         };
442         unsigned char By[32] = {
443             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
444             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
445             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
446             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
447         };
448
449         /* This curve doesn't need a name, because it's never used in
450          * any format that embeds the curve name */
451         curve.name = NULL;
452
453         if (!initialise_ecurve(&curve, 256, q, l, d, Bx, By)) {
454             return NULL;
455         }
456
457         /* Now initialised, no need to do it again */
458         initialised = 1;
459     }
460
461     return &curve;
462 }
463
464 /* Return 1 if a is -3 % p, otherwise return 0
465  * This is used because there are some maths optimisations */
466 static int ec_aminus3(const struct ec_curve *curve)
467 {
468     int ret;
469     Bignum _p;
470
471     if (curve->type != EC_WEIERSTRASS) {
472         return 0;
473     }
474
475     _p = bignum_add_long(curve->w.a, 3);
476     if (!_p) return 0;
477
478     ret = !bignum_cmp(curve->p, _p);
479     freebn(_p);
480     return ret;
481 }
482
483 /* ----------------------------------------------------------------------
484  * Elliptic curve field maths
485  */
486
487 static Bignum ecf_add(const Bignum a, const Bignum b,
488                       const struct ec_curve *curve)
489 {
490     Bignum a1, b1, ab, ret;
491
492     a1 = bigmod(a, curve->p);
493     if (!a1) return NULL;
494     b1 = bigmod(b, curve->p);
495     if (!b1)
496     {
497         freebn(a1);
498         return NULL;
499     }
500
501     ab = bigadd(a1, b1);
502     freebn(a1);
503     freebn(b1);
504     if (!ab) return NULL;
505
506     ret = bigmod(ab, curve->p);
507     freebn(ab);
508
509     return ret;
510 }
511
512 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
513 {
514     return modmul(a, a, curve->p);
515 }
516
517 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
518 {
519     Bignum ret, tmp;
520
521     /* Double */
522     tmp = bignum_lshift(a, 1);
523     if (!tmp) return NULL;
524
525     /* Add itself (i.e. treble) */
526     ret = bigadd(tmp, a);
527     freebn(tmp);
528
529     /* Normalise */
530     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
531     {
532         tmp = bigsub(ret, curve->p);
533         freebn(ret);
534         ret = tmp;
535     }
536
537     return ret;
538 }
539
540 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
541 {
542     Bignum ret = bignum_lshift(a, 1);
543     if (!ret) return NULL;
544     if (bignum_cmp(ret, curve->p) >= 0)
545     {
546         Bignum tmp = bigsub(ret, curve->p);
547         freebn(ret);
548         return tmp;
549     }
550     else
551     {
552         return ret;
553     }
554 }
555
556 /* ----------------------------------------------------------------------
557  * Memory functions
558  */
559
560 void ec_point_free(struct ec_point *point)
561 {
562     if (point == NULL) return;
563     point->curve = 0;
564     if (point->x) freebn(point->x);
565     if (point->y) freebn(point->y);
566     if (point->z) freebn(point->z);
567     point->infinity = 0;
568     sfree(point);
569 }
570
571 static struct ec_point *ec_point_new(const struct ec_curve *curve,
572                                      const Bignum x, const Bignum y, const Bignum z,
573                                      unsigned char infinity)
574 {
575     struct ec_point *point = snewn(1, struct ec_point);
576     point->curve = curve;
577     point->x = x;
578     point->y = y;
579     point->z = z;
580     point->infinity = infinity ? 1 : 0;
581     return point;
582 }
583
584 static struct ec_point *ec_point_copy(const struct ec_point *a)
585 {
586     if (a == NULL) return NULL;
587     return ec_point_new(a->curve,
588                         a->x ? copybn(a->x) : NULL,
589                         a->y ? copybn(a->y) : NULL,
590                         a->z ? copybn(a->z) : NULL,
591                         a->infinity);
592 }
593
594 static int ec_point_verify(const struct ec_point *a)
595 {
596     if (a->infinity) {
597         return 1;
598     } else if (a->curve->type == EC_EDWARDS) {
599         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
600         Bignum y2, x2, tmp, tmp2, tmp3;
601         int ret;
602
603         y2 = ecf_square(a->y, a->curve);
604         if (!y2) {
605             return 0;
606         }
607         x2 = ecf_square(a->x, a->curve);
608         if (!x2) {
609             freebn(y2);
610             return 0;
611         }
612         tmp = modmul(a->curve->e.d, x2, a->curve->p);
613         if (!tmp) {
614             freebn(x2);
615             freebn(y2);
616             return 0;
617         }
618         tmp2 = modmul(tmp, y2, a->curve->p);
619         freebn(tmp);
620         if (!tmp2) {
621             freebn(x2);
622             freebn(y2);
623             return 0;
624         }
625         tmp = modsub(y2, x2, a->curve->p);
626         freebn(y2);
627         freebn(x2);
628         if (!tmp) {
629             freebn(tmp2);
630             return 0;
631         }
632         tmp3 = modsub(tmp, tmp2, a->curve->p);
633         freebn(tmp);
634         freebn(tmp2);
635         if (!tmp3) {
636             return 0;
637         }
638         ret = !bignum_cmp(tmp3, One);
639         freebn(tmp3);
640         return ret;
641     } else if (a->curve->type == EC_WEIERSTRASS) {
642         /* Verify y^2 = x^3 + ax + b */
643         int ret = 0;
644
645         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
646
647         Bignum Three = bignum_from_long(3);
648         if (!Three) return 0;
649
650         lhs = modmul(a->y, a->y, a->curve->p);
651         if (!lhs) goto error;
652
653         /* This uses montgomery multiplication to optimise */
654         x3 = modpow(a->x, Three, a->curve->p);
655         freebn(Three);
656         if (!x3) goto error;
657         ax = modmul(a->curve->w.a, a->x, a->curve->p);
658         if (!ax) goto error;
659         x3ax = bigadd(x3, ax);
660         if (!x3ax) goto error;
661         freebn(x3); x3 = NULL;
662         freebn(ax); ax = NULL;
663         x3axm = bigmod(x3ax, a->curve->p);
664         if (!x3axm) goto error;
665         freebn(x3ax); x3ax = NULL;
666         x3axb = bigadd(x3axm, a->curve->w.b);
667         if (!x3axb) goto error;
668         freebn(x3axm); x3axm = NULL;
669         rhs = bigmod(x3axb, a->curve->p);
670         if (!rhs) goto error;
671         freebn(x3axb);
672
673         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
674         freebn(lhs);
675         freebn(rhs);
676
677         return ret;
678
679       error:
680         if (x3) freebn(x3);
681         if (ax) freebn(ax);
682         if (x3ax) freebn(x3ax);
683         if (x3axm) freebn(x3axm);
684         if (x3axb) freebn(x3axb);
685         if (lhs) freebn(lhs);
686         return 0;
687     } else {
688         return 0;
689     }
690 }
691
692 /* ----------------------------------------------------------------------
693  * Elliptic curve point maths
694  */
695
696 /* Returns 1 on success and 0 on memory error */
697 static int ecp_normalise(struct ec_point *a)
698 {
699     if (!a) {
700         /* No point */
701         return 0;
702     }
703
704     if (a->infinity) {
705         /* Point is at infinity - i.e. normalised */
706         return 1;
707     }
708
709     if (a->curve->type == EC_WEIERSTRASS) {
710         /* In Jacobian Coordinates the triple (X, Y, Z) represents
711            the affine point (X / Z^2, Y / Z^3) */
712
713         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
714
715         if (!a->x || !a->y) {
716             /* No point defined */
717             return 0;
718         } else if (!a->z) {
719             /* Already normalised */
720             return 1;
721         }
722
723         Z2 = ecf_square(a->z, a->curve);
724         if (!Z2) {
725             return 0;
726         }
727         Z2inv = modinv(Z2, a->curve->p);
728         if (!Z2inv) {
729             freebn(Z2);
730             return 0;
731         }
732         tx = modmul(a->x, Z2inv, a->curve->p);
733         freebn(Z2inv);
734         if (!tx) {
735             freebn(Z2);
736             return 0;
737         }
738
739         Z3 = modmul(Z2, a->z, a->curve->p);
740         freebn(Z2);
741         if (!Z3) {
742             freebn(tx);
743             return 0;
744         }
745         Z3inv = modinv(Z3, a->curve->p);
746         freebn(Z3);
747         if (!Z3inv) {
748             freebn(tx);
749             return 0;
750         }
751         ty = modmul(a->y, Z3inv, a->curve->p);
752         freebn(Z3inv);
753         if (!ty) {
754             freebn(tx);
755             return 0;
756         }
757
758         freebn(a->x);
759         a->x = tx;
760         freebn(a->y);
761         a->y = ty;
762         freebn(a->z);
763         a->z = NULL;
764         return 1;
765     } else if (a->curve->type == EC_MONTGOMERY) {
766         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
767
768         Bignum tmp, tmp2;
769
770         if (!a->x) {
771             /* No point defined */
772             return 0;
773         } else if (!a->z) {
774             /* Already normalised */
775             return 1;
776         }
777
778         tmp = modinv(a->z, a->curve->p);
779         if (!tmp) {
780             return 0;
781         }
782         tmp2 = modmul(a->x, tmp, a->curve->p);
783         freebn(tmp);
784         if (!tmp2) {
785             return 0;
786         }
787
788         freebn(a->z);
789         a->z = NULL;
790         freebn(a->x);
791         a->x = tmp2;
792         return 1;
793     } else if (a->curve->type == EC_EDWARDS) {
794         /* Always normalised */
795         return 1;
796     } else {
797         return 0;
798     }
799 }
800
801 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
802 {
803     Bignum S, M, outx, outy, outz;
804
805     if (bignum_cmp(a->y, Zero) == 0)
806     {
807         /* Identity */
808         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
809     }
810
811     /* S = 4*X*Y^2 */
812     {
813         Bignum Y2, XY2, _2XY2;
814
815         Y2 = ecf_square(a->y, a->curve);
816         if (!Y2) {
817             return NULL;
818         }
819         XY2 = modmul(a->x, Y2, a->curve->p);
820         freebn(Y2);
821         if (!XY2) {
822             return NULL;
823         }
824
825         _2XY2 = ecf_double(XY2, a->curve);
826         freebn(XY2);
827         if (!_2XY2) {
828             return NULL;
829         }
830         S = ecf_double(_2XY2, a->curve);
831         freebn(_2XY2);
832         if (!S) {
833             return NULL;
834         }
835     }
836
837     /* Faster calculation if a = -3 */
838     if (aminus3) {
839         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
840         Bignum Z2, XpZ2, XmZ2, second;
841
842         if (a->z == NULL) {
843             Z2 = copybn(One);
844         } else {
845             Z2 = ecf_square(a->z, a->curve);
846         }
847         if (!Z2) {
848             freebn(S);
849             return NULL;
850         }
851
852         XpZ2 = ecf_add(a->x, Z2, a->curve);
853         if (!XpZ2) {
854             freebn(S);
855             freebn(Z2);
856             return NULL;
857         }
858         XmZ2 = modsub(a->x, Z2, a->curve->p);
859         freebn(Z2);
860         if (!XmZ2) {
861             freebn(S);
862             freebn(XpZ2);
863             return NULL;
864         }
865
866         second = modmul(XpZ2, XmZ2, a->curve->p);
867         freebn(XpZ2);
868         freebn(XmZ2);
869         if (!second) {
870             freebn(S);
871             return NULL;
872         }
873
874         M = ecf_treble(second, a->curve);
875         freebn(second);
876         if (!M) {
877             freebn(S);
878             return NULL;
879         }
880     } else {
881         /* M = 3*X^2 + a*Z^4 */
882         Bignum _3X2, X2, aZ4;
883
884         if (a->z == NULL) {
885             aZ4 = copybn(a->curve->w.a);
886         } else {
887             Bignum Z2, Z4;
888
889             Z2 = ecf_square(a->z, a->curve);
890             if (!Z2) {
891                 freebn(S);
892                 return NULL;
893             }
894             Z4 = ecf_square(Z2, a->curve);
895             freebn(Z2);
896             if (!Z4) {
897                 freebn(S);
898                 return NULL;
899             }
900             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
901             freebn(Z4);
902         }
903         if (!aZ4) {
904             freebn(S);
905             return NULL;
906         }
907
908         X2 = modmul(a->x, a->x, a->curve->p);
909         if (!X2) {
910             freebn(S);
911             freebn(aZ4);
912             return NULL;
913         }
914         _3X2 = ecf_treble(X2, a->curve);
915         freebn(X2);
916         if (!_3X2) {
917             freebn(S);
918             freebn(aZ4);
919             return NULL;
920         }
921         M = ecf_add(_3X2, aZ4, a->curve);
922         freebn(_3X2);
923         freebn(aZ4);
924         if (!M) {
925             freebn(S);
926             return NULL;
927         }
928     }
929
930     /* X' = M^2 - 2*S */
931     {
932         Bignum M2, _2S;
933
934         M2 = ecf_square(M, a->curve);
935         if (!M2) {
936             freebn(S);
937             freebn(M);
938             return NULL;
939         }
940
941         _2S = ecf_double(S, a->curve);
942         if (!_2S) {
943             freebn(M2);
944             freebn(S);
945             freebn(M);
946             return NULL;
947         }
948
949         outx = modsub(M2, _2S, a->curve->p);
950         freebn(M2);
951         freebn(_2S);
952         if (!outx) {
953             freebn(S);
954             freebn(M);
955             return NULL;
956         }
957     }
958
959     /* Y' = M*(S - X') - 8*Y^4 */
960     {
961         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
962
963         SX = modsub(S, outx, a->curve->p);
964         freebn(S);
965         if (!SX) {
966             freebn(M);
967             freebn(outx);
968             return NULL;
969         }
970         MSX = modmul(M, SX, a->curve->p);
971         freebn(SX);
972         freebn(M);
973         if (!MSX) {
974             freebn(outx);
975             return NULL;
976         }
977         Y2 = ecf_square(a->y, a->curve);
978         if (!Y2) {
979             freebn(outx);
980             freebn(MSX);
981             return NULL;
982         }
983         Y4 = ecf_square(Y2, a->curve);
984         freebn(Y2);
985         if (!Y4) {
986             freebn(outx);
987             freebn(MSX);
988             return NULL;
989         }
990         Eight = bignum_from_long(8);
991         if (!Eight) {
992             freebn(outx);
993             freebn(MSX);
994             freebn(Y4);
995             return NULL;
996         }
997         _8Y4 = modmul(Eight, Y4, a->curve->p);
998         freebn(Eight);
999         freebn(Y4);
1000         if (!_8Y4) {
1001             freebn(outx);
1002             freebn(MSX);
1003             return NULL;
1004         }
1005         outy = modsub(MSX, _8Y4, a->curve->p);
1006         freebn(MSX);
1007         freebn(_8Y4);
1008         if (!outy) {
1009             freebn(outx);
1010             return NULL;
1011         }
1012     }
1013
1014     /* Z' = 2*Y*Z */
1015     {
1016         Bignum YZ;
1017
1018         if (a->z == NULL) {
1019             YZ = copybn(a->y);
1020         } else {
1021             YZ = modmul(a->y, a->z, a->curve->p);
1022         }
1023         if (!YZ) {
1024             freebn(outx);
1025             freebn(outy);
1026             return NULL;
1027         }
1028
1029         outz = ecf_double(YZ, a->curve);
1030         freebn(YZ);
1031         if (!outz) {
1032             freebn(outx);
1033             freebn(outy);
1034             return NULL;
1035         }
1036     }
1037
1038     return ec_point_new(a->curve, outx, outy, outz, 0);
1039 }
1040
1041 static struct ec_point *ecp_doublem(const struct ec_point *a)
1042 {
1043     Bignum z, outx, outz, xpz, xmz;
1044
1045     z = a->z;
1046     if (!z) {
1047         z = One;
1048     }
1049
1050     /* 4xz = (x + z)^2 - (x - z)^2 */
1051     {
1052         Bignum tmp;
1053
1054         tmp = ecf_add(a->x, z, a->curve);
1055         if (!tmp) {
1056             return NULL;
1057         }
1058         xpz = ecf_square(tmp, a->curve);
1059         freebn(tmp);
1060         if (!xpz) {
1061             return NULL;
1062         }
1063
1064         tmp = modsub(a->x, z, a->curve->p);
1065         if (!tmp) {
1066             freebn(xpz);
1067             return NULL;
1068         }
1069         xmz = ecf_square(tmp, a->curve);
1070         freebn(tmp);
1071         if (!xmz) {
1072             freebn(xpz);
1073             return NULL;
1074         }
1075     }
1076
1077     /* outx = (x + z)^2 * (x - z)^2 */
1078     outx = modmul(xpz, xmz, a->curve->p);
1079     if (!outx) {
1080         freebn(xpz);
1081         freebn(xmz);
1082         return NULL;
1083     }
1084
1085     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
1086     {
1087         Bignum _4xz, tmp, tmp2, tmp3;
1088
1089         tmp = bignum_from_long(2);
1090         if (!tmp) {
1091             freebn(xpz);
1092             freebn(outx);
1093             freebn(xmz);
1094             return NULL;
1095         }
1096         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
1097         freebn(tmp);
1098         if (!tmp2) {
1099             freebn(xpz);
1100             freebn(outx);
1101             freebn(xmz);
1102             return NULL;
1103         }
1104
1105         _4xz = modsub(xpz, xmz, a->curve->p);
1106         freebn(xpz);
1107         if (!_4xz) {
1108             freebn(tmp2);
1109             freebn(outx);
1110             freebn(xmz);
1111             return NULL;
1112         }
1113         tmp = modmul(tmp2, _4xz, a->curve->p);
1114         freebn(tmp2);
1115         if (!tmp) {
1116             freebn(_4xz);
1117             freebn(outx);
1118             freebn(xmz);
1119             return NULL;
1120         }
1121
1122         tmp2 = bignum_from_long(4);
1123         if (!tmp2) {
1124             freebn(tmp);
1125             freebn(_4xz);
1126             freebn(outx);
1127             freebn(xmz);
1128             return NULL;
1129         }
1130         tmp3 = modinv(tmp2, a->curve->p);
1131         freebn(tmp2);
1132         if (!tmp3) {
1133             freebn(tmp);
1134             freebn(_4xz);
1135             freebn(outx);
1136             freebn(xmz);
1137             return NULL;
1138         }
1139         tmp2 = modmul(tmp, tmp3, a->curve->p);
1140         freebn(tmp);
1141         freebn(tmp3);
1142         if (!tmp2) {
1143             freebn(_4xz);
1144             freebn(outx);
1145             freebn(xmz);
1146             return NULL;
1147         }
1148
1149         tmp = ecf_add(xmz, tmp2, a->curve);
1150         freebn(xmz);
1151         freebn(tmp2);
1152         if (!tmp) {
1153             freebn(_4xz);
1154             freebn(outx);
1155             return NULL;
1156         }
1157         outz = modmul(_4xz, tmp, a->curve->p);
1158         freebn(_4xz);
1159         freebn(tmp);
1160         if (!outz) {
1161             freebn(outx);
1162             return NULL;
1163         }
1164     }
1165
1166     return ec_point_new(a->curve, outx, NULL, outz, 0);
1167 }
1168
1169 /* Forward declaration for Edwards curve doubling */
1170 static struct ec_point *ecp_add(const struct ec_point *a,
1171                                 const struct ec_point *b,
1172                                 const int aminus3);
1173
1174 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
1175 {
1176     if (a->infinity)
1177     {
1178         /* Identity */
1179         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1180     }
1181
1182     if (a->curve->type == EC_EDWARDS)
1183     {
1184         return ecp_add(a, a, aminus3);
1185     }
1186     else if (a->curve->type == EC_WEIERSTRASS)
1187     {
1188         return ecp_doublew(a, aminus3);
1189     }
1190     else
1191     {
1192         return ecp_doublem(a);
1193     }
1194 }
1195
1196 static struct ec_point *ecp_addw(const struct ec_point *a,
1197                                  const struct ec_point *b,
1198                                  const int aminus3)
1199 {
1200     Bignum U1, U2, S1, S2, outx, outy, outz;
1201
1202     /* U1 = X1*Z2^2 */
1203     /* S1 = Y1*Z2^3 */
1204     if (b->z) {
1205         Bignum Z2, Z3;
1206
1207         Z2 = ecf_square(b->z, a->curve);
1208         if (!Z2) {
1209             return NULL;
1210         }
1211         U1 = modmul(a->x, Z2, a->curve->p);
1212         if (!U1) {
1213             freebn(Z2);
1214             return NULL;
1215         }
1216         Z3 = modmul(Z2, b->z, a->curve->p);
1217         freebn(Z2);
1218         if (!Z3) {
1219             freebn(U1);
1220             return NULL;
1221         }
1222         S1 = modmul(a->y, Z3, a->curve->p);
1223         freebn(Z3);
1224         if (!S1) {
1225             freebn(U1);
1226             return NULL;
1227         }
1228     } else {
1229         U1 = copybn(a->x);
1230         if (!U1) {
1231             return NULL;
1232         }
1233         S1 = copybn(a->y);
1234         if (!S1) {
1235             freebn(U1);
1236             return NULL;
1237         }
1238     }
1239
1240     /* U2 = X2*Z1^2 */
1241     /* S2 = Y2*Z1^3 */
1242     if (a->z) {
1243         Bignum Z2, Z3;
1244
1245         Z2 = ecf_square(a->z, b->curve);
1246         if (!Z2) {
1247             freebn(U1);
1248             freebn(S1);
1249             return NULL;
1250         }
1251         U2 = modmul(b->x, Z2, b->curve->p);
1252         if (!U2) {
1253             freebn(U1);
1254             freebn(S1);
1255             freebn(Z2);
1256             return NULL;
1257         }
1258         Z3 = modmul(Z2, a->z, b->curve->p);
1259         freebn(Z2);
1260         if (!Z3) {
1261             freebn(U1);
1262             freebn(S1);
1263             freebn(U2);
1264             return NULL;
1265         }
1266         S2 = modmul(b->y, Z3, b->curve->p);
1267         freebn(Z3);
1268         if (!S2) {
1269             freebn(U1);
1270             freebn(S1);
1271             freebn(U2);
1272             return NULL;
1273         }
1274     } else {
1275         U2 = copybn(b->x);
1276         if (!U2) {
1277             freebn(U1);
1278             freebn(S1);
1279             return NULL;
1280         }
1281         S2 = copybn(b->y);
1282         if (!S2) {
1283             freebn(U1);
1284             freebn(S1);
1285             freebn(U2);
1286             return NULL;
1287         }
1288     }
1289
1290     /* Check if multiplying by self */
1291     if (bignum_cmp(U1, U2) == 0)
1292     {
1293         freebn(U1);
1294         freebn(U2);
1295         if (bignum_cmp(S1, S2) == 0)
1296         {
1297             freebn(S1);
1298             freebn(S2);
1299             return ecp_double(a, aminus3);
1300         }
1301         else
1302         {
1303             freebn(S1);
1304             freebn(S2);
1305             /* Infinity */
1306             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1307         }
1308     }
1309
1310     {
1311         Bignum H, R, UH2, H3;
1312
1313         /* H = U2 - U1 */
1314         H = modsub(U2, U1, a->curve->p);
1315         freebn(U2);
1316         if (!H) {
1317             freebn(U1);
1318             freebn(S1);
1319             freebn(S2);
1320             return NULL;
1321         }
1322
1323         /* R = S2 - S1 */
1324         R = modsub(S2, S1, a->curve->p);
1325         freebn(S2);
1326         if (!R) {
1327             freebn(H);
1328             freebn(S1);
1329             freebn(U1);
1330             return NULL;
1331         }
1332
1333         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1334         {
1335             Bignum R2, H2, _2UH2, first;
1336
1337             H2 = ecf_square(H, a->curve);
1338             if (!H2) {
1339                 freebn(U1);
1340                 freebn(S1);
1341                 freebn(H);
1342                 freebn(R);
1343                 return NULL;
1344             }
1345             UH2 = modmul(U1, H2, a->curve->p);
1346             freebn(U1);
1347             if (!UH2) {
1348                 freebn(H2);
1349                 freebn(S1);
1350                 freebn(H);
1351                 freebn(R);
1352                 return NULL;
1353             }
1354             H3 = modmul(H2, H, a->curve->p);
1355             freebn(H2);
1356             if (!H3) {
1357                 freebn(UH2);
1358                 freebn(S1);
1359                 freebn(H);
1360                 freebn(R);
1361                 return NULL;
1362             }
1363             R2 = ecf_square(R, a->curve);
1364             if (!R2) {
1365                 freebn(H3);
1366                 freebn(UH2);
1367                 freebn(S1);
1368                 freebn(H);
1369                 freebn(R);
1370                 return NULL;
1371             }
1372             _2UH2 = ecf_double(UH2, a->curve);
1373             if (!_2UH2) {
1374                 freebn(R2);
1375                 freebn(H3);
1376                 freebn(UH2);
1377                 freebn(S1);
1378                 freebn(H);
1379                 freebn(R);
1380                 return NULL;
1381             }
1382             first = modsub(R2, H3, a->curve->p);
1383             freebn(R2);
1384             if (!first) {
1385                 freebn(H3);
1386                 freebn(_2UH2);
1387                 freebn(UH2);
1388                 freebn(S1);
1389                 freebn(H);
1390                 freebn(R);
1391                 return NULL;
1392             }
1393             outx = modsub(first, _2UH2, a->curve->p);
1394             freebn(first);
1395             freebn(_2UH2);
1396             if (!outx) {
1397                 freebn(H3);
1398                 freebn(UH2);
1399                 freebn(S1);
1400                 freebn(H);
1401                 freebn(R);
1402                 return NULL;
1403             }
1404         }
1405
1406         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1407         {
1408             Bignum RUH2mX, UH2mX, SH3;
1409
1410             UH2mX = modsub(UH2, outx, a->curve->p);
1411             freebn(UH2);
1412             if (!UH2mX) {
1413                 freebn(outx);
1414                 freebn(H3);
1415                 freebn(S1);
1416                 freebn(H);
1417                 freebn(R);
1418                 return NULL;
1419             }
1420             RUH2mX = modmul(R, UH2mX, a->curve->p);
1421             freebn(UH2mX);
1422             freebn(R);
1423             if (!RUH2mX) {
1424                 freebn(outx);
1425                 freebn(H3);
1426                 freebn(S1);
1427                 freebn(H);
1428                 return NULL;
1429             }
1430             SH3 = modmul(S1, H3, a->curve->p);
1431             freebn(S1);
1432             freebn(H3);
1433             if (!SH3) {
1434                 freebn(RUH2mX);
1435                 freebn(outx);
1436                 freebn(H);
1437                 return NULL;
1438             }
1439
1440             outy = modsub(RUH2mX, SH3, a->curve->p);
1441             freebn(RUH2mX);
1442             freebn(SH3);
1443             if (!outy) {
1444                 freebn(outx);
1445                 freebn(H);
1446                 return NULL;
1447             }
1448         }
1449
1450         /* Z3 = H*Z1*Z2 */
1451         if (a->z && b->z) {
1452             Bignum ZZ;
1453
1454             ZZ = modmul(a->z, b->z, a->curve->p);
1455             if (!ZZ) {
1456                 freebn(outx);
1457                 freebn(outy);
1458                 freebn(H);
1459                 return NULL;
1460             }
1461             outz = modmul(H, ZZ, a->curve->p);
1462             freebn(H);
1463             freebn(ZZ);
1464             if (!outz) {
1465                 freebn(outx);
1466                 freebn(outy);
1467                 return NULL;
1468             }
1469         } else if (a->z) {
1470             outz = modmul(H, a->z, a->curve->p);
1471             freebn(H);
1472             if (!outz) {
1473                 freebn(outx);
1474                 freebn(outy);
1475                 return NULL;
1476             }
1477         } else if (b->z) {
1478             outz = modmul(H, b->z, a->curve->p);
1479             freebn(H);
1480             if (!outz) {
1481                 freebn(outx);
1482                 freebn(outy);
1483                 return NULL;
1484             }
1485         } else {
1486             outz = H;
1487         }
1488     }
1489
1490     return ec_point_new(a->curve, outx, outy, outz, 0);
1491 }
1492
1493 static struct ec_point *ecp_addm(const struct ec_point *a,
1494                                  const struct ec_point *b,
1495                                  const struct ec_point *base)
1496 {
1497     Bignum outx, outz, az, bz;
1498
1499     az = a->z;
1500     if (!az) {
1501         az = One;
1502     }
1503     bz = b->z;
1504     if (!bz) {
1505         bz = One;
1506     }
1507
1508     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1509     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1510     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1511     {
1512         Bignum tmp, tmp2, tmp3, tmp4;
1513
1514         /* (Xa + Za) * (Xb - Zb) */
1515         tmp = ecf_add(a->x, az, a->curve);
1516         if (!tmp) {
1517             return NULL;
1518         }
1519         tmp2 = modsub(b->x, bz, a->curve->p);
1520         if (!tmp2) {
1521             freebn(tmp);
1522             return NULL;
1523         }
1524         tmp3 = modmul(tmp, tmp2, a->curve->p);
1525         freebn(tmp);
1526         freebn(tmp2);
1527         if (!tmp3) {
1528             return NULL;
1529         }
1530
1531         /* (Xa - Za) * (Xb + Zb) */
1532         tmp = modsub(a->x, az, a->curve->p);
1533         if (!tmp) {
1534             freebn(tmp3);
1535             return NULL;
1536         }
1537         tmp2 = ecf_add(b->x, bz, a->curve);
1538         if (!tmp2) {
1539             freebn(tmp);
1540             freebn(tmp3);
1541             return NULL;
1542         }
1543         tmp4 = modmul(tmp, tmp2, a->curve->p);
1544         freebn(tmp);
1545         freebn(tmp2);
1546         if (!tmp4) {
1547             freebn(tmp3);
1548             return NULL;
1549         }
1550
1551         tmp = ecf_add(tmp3, tmp4, a->curve);
1552         if (!tmp) {
1553             freebn(tmp3);
1554             freebn(tmp4);
1555             return NULL;
1556         }
1557         outx = ecf_square(tmp, a->curve);
1558         freebn(tmp);
1559         if (!outx) {
1560             freebn(tmp3);
1561             freebn(tmp4);
1562             return NULL;
1563         }
1564
1565         tmp = modsub(tmp3, tmp4, a->curve->p);
1566         freebn(tmp3);
1567         freebn(tmp4);
1568         if (!tmp) {
1569             freebn(outx);
1570             return NULL;
1571         }
1572         tmp2 = ecf_square(tmp, a->curve);
1573         freebn(tmp);
1574         if (!tmp2) {
1575             freebn(outx);
1576             return NULL;
1577         }
1578         outz = modmul(base->x, tmp2, a->curve->p);
1579         freebn(tmp2);
1580         if (!outz) {
1581             freebn(outx);
1582             return NULL;
1583         }
1584     }
1585
1586     return ec_point_new(a->curve, outx, NULL, outz, 0);
1587 }
1588
1589 static struct ec_point *ecp_adde(const struct ec_point *a,
1590                                  const struct ec_point *b)
1591 {
1592     Bignum outx, outy, dmul;
1593
1594     /* outx = (a->x * b->y + b->x * a->y) /
1595      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1596     {
1597         Bignum tmp, tmp2, tmp3, tmp4;
1598
1599         tmp = modmul(a->x, b->y, a->curve->p);
1600         if (!tmp)
1601         {
1602             return NULL;
1603         }
1604         tmp2 = modmul(b->x, a->y, a->curve->p);
1605         if (!tmp2)
1606         {
1607             freebn(tmp);
1608             return NULL;
1609         }
1610         tmp3 = ecf_add(tmp, tmp2, a->curve);
1611         if (!tmp3)
1612         {
1613             freebn(tmp);
1614             freebn(tmp2);
1615             return NULL;
1616         }
1617
1618         tmp4 = modmul(tmp, tmp2, a->curve->p);
1619         freebn(tmp);
1620         freebn(tmp2);
1621         if (!tmp4)
1622         {
1623             freebn(tmp3);
1624             return NULL;
1625         }
1626         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1627         freebn(tmp4);
1628         if (!dmul) {
1629             freebn(tmp3);
1630             return NULL;
1631         }
1632
1633         tmp = ecf_add(One, dmul, a->curve);
1634         if (!tmp)
1635         {
1636             freebn(tmp3);
1637             freebn(dmul);
1638             return NULL;
1639         }
1640         tmp2 = modinv(tmp, a->curve->p);
1641         freebn(tmp);
1642         if (!tmp2)
1643         {
1644             freebn(tmp3);
1645             freebn(dmul);
1646             return NULL;
1647         }
1648
1649         outx = modmul(tmp3, tmp2, a->curve->p);
1650         freebn(tmp3);
1651         freebn(tmp2);
1652         if (!outx)
1653         {
1654             freebn(dmul);
1655             return NULL;
1656         }
1657     }
1658
1659     /* outy = (a->y * b->y + a->x * b->x) /
1660      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1661     {
1662         Bignum tmp, tmp2, tmp3, tmp4;
1663
1664         tmp = modsub(One, dmul, a->curve->p);
1665         freebn(dmul);
1666         if (!tmp)
1667         {
1668             freebn(outx);
1669             return NULL;
1670         }
1671
1672         tmp2 = modinv(tmp, a->curve->p);
1673         freebn(tmp);
1674         if (!tmp2)
1675         {
1676             freebn(outx);
1677             return NULL;
1678         }
1679
1680         tmp = modmul(a->y, b->y, a->curve->p);
1681         if (!tmp)
1682         {
1683             freebn(tmp2);
1684             freebn(outx);
1685             return NULL;
1686         }
1687         tmp3 = modmul(a->x, b->x, a->curve->p);
1688         if (!tmp3)
1689         {
1690             freebn(tmp);
1691             freebn(tmp2);
1692             freebn(outx);
1693             return NULL;
1694         }
1695         tmp4 = ecf_add(tmp, tmp3, a->curve);
1696         freebn(tmp);
1697         freebn(tmp3);
1698         if (!tmp4)
1699         {
1700             freebn(tmp2);
1701             freebn(outx);
1702             return NULL;
1703         }
1704
1705         outy = modmul(tmp4, tmp2, a->curve->p);
1706         freebn(tmp4);
1707         freebn(tmp2);
1708         if (!outy)
1709         {
1710             freebn(outx);
1711             return NULL;
1712         }
1713     }
1714
1715     return ec_point_new(a->curve, outx, outy, NULL, 0);
1716 }
1717
1718 static struct ec_point *ecp_add(const struct ec_point *a,
1719                                 const struct ec_point *b,
1720                                 const int aminus3)
1721 {
1722     if (a->curve != b->curve) {
1723         return NULL;
1724     }
1725
1726     /* Check if multiplying by infinity */
1727     if (a->infinity) return ec_point_copy(b);
1728     if (b->infinity) return ec_point_copy(a);
1729
1730     if (a->curve->type == EC_EDWARDS)
1731     {
1732         return ecp_adde(a, b);
1733     }
1734
1735     if (a->curve->type == EC_WEIERSTRASS)
1736     {
1737         return ecp_addw(a, b, aminus3);
1738     }
1739
1740     return NULL;
1741 }
1742
1743 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1744 {
1745     struct ec_point *A, *ret;
1746     int bits, i;
1747
1748     A = ec_point_copy(a);
1749     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1750
1751     bits = bignum_bitcount(b);
1752     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1753     {
1754         if (bignum_bit(b, i))
1755         {
1756             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1757             ec_point_free(ret);
1758             ret = tmp;
1759         }
1760         if (i+1 != bits)
1761         {
1762             struct ec_point *tmp = ecp_double(A, aminus3);
1763             ec_point_free(A);
1764             A = tmp;
1765         }
1766     }
1767
1768     if (!A) {
1769         ec_point_free(ret);
1770         ret = NULL;
1771     } else {
1772         ec_point_free(A);
1773     }
1774
1775     return ret;
1776 }
1777
1778 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1779 {
1780     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1781
1782     if (!ecp_normalise(ret)) {
1783         ec_point_free(ret);
1784         return NULL;
1785     }
1786
1787     return ret;
1788 }
1789
1790 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1791 {
1792     int i;
1793     struct ec_point *ret;
1794
1795     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1796
1797     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1798     {
1799         {
1800             struct ec_point *tmp = ecp_double(ret, 0);
1801             ec_point_free(ret);
1802             ret = tmp;
1803         }
1804         if (ret && bignum_bit(b, i))
1805         {
1806             struct ec_point *tmp = ecp_add(ret, a, 0);
1807             ec_point_free(ret);
1808             ret = tmp;
1809         }
1810     }
1811
1812     return ret;
1813 }
1814
1815 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1816 {
1817     struct ec_point *P1, *P2;
1818     int bits, i;
1819
1820     /* P1 <- P and P2 <- [2]P */
1821     P2 = ecp_double(p, 0);
1822     if (!P2) {
1823         return NULL;
1824     }
1825     P1 = ec_point_copy(p);
1826     if (!P1) {
1827         ec_point_free(P2);
1828         return NULL;
1829     }
1830
1831     /* for i = bits âˆ’ 2 down to 0 */
1832     bits = bignum_bitcount(n);
1833     for (i = bits - 2; P1 != NULL && P2 != NULL && i >= 0; --i)
1834     {
1835         if (!bignum_bit(n, i))
1836         {
1837             /* P2 <- P1 + P2 */
1838             struct ec_point *tmp = ecp_addm(P1, P2, p);
1839             ec_point_free(P2);
1840             P2 = tmp;
1841
1842             /* P1 <- [2]P1 */
1843             tmp = ecp_double(P1, 0);
1844             ec_point_free(P1);
1845             P1 = tmp;
1846         }
1847         else
1848         {
1849             /* P1 <- P1 + P2 */
1850             struct ec_point *tmp = ecp_addm(P1, P2, p);
1851             ec_point_free(P1);
1852             P1 = tmp;
1853
1854             /* P2 <- [2]P2 */
1855             tmp = ecp_double(P2, 0);
1856             ec_point_free(P2);
1857             P2 = tmp;
1858         }
1859     }
1860
1861     if (!P2) {
1862         if (P1) ec_point_free(P1);
1863         P1 = NULL;
1864     } else {
1865         ec_point_free(P2);
1866     }
1867
1868     if (!ecp_normalise(P1)) {
1869         ec_point_free(P1);
1870         return NULL;
1871     }
1872
1873     return P1;
1874 }
1875
1876 /* Not static because it is used by sshecdsag.c to generate a new key */
1877 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1878 {
1879     if (a->curve->type == EC_WEIERSTRASS) {
1880         return ecp_mulw(a, b);
1881     } else if (a->curve->type == EC_EDWARDS) {
1882         return ecp_mule(a, b);
1883     } else {
1884         return ecp_mulm(a, b);
1885     }
1886 }
1887
1888 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1889                                    const struct ec_point *point)
1890 {
1891     struct ec_point *aG, *bP, *ret;
1892     int aminus3;
1893
1894     if (point->curve->type != EC_WEIERSTRASS) {
1895         return NULL;
1896     }
1897
1898     aminus3 = ec_aminus3(point->curve);
1899
1900     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1901     if (!aG) return NULL;
1902     bP = ecp_mul_(point, b, aminus3);
1903     if (!bP) {
1904         ec_point_free(aG);
1905         return NULL;
1906     }
1907
1908     ret = ecp_add(aG, bP, aminus3);
1909
1910     ec_point_free(aG);
1911     ec_point_free(bP);
1912
1913     if (!ecp_normalise(ret)) {
1914         ec_point_free(ret);
1915         return NULL;
1916     }
1917
1918     return ret;
1919 }
1920 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1921 {
1922     /* Get the x value on the given Edwards curve for a given y */
1923     Bignum x, xx;
1924
1925     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1926     {
1927         Bignum tmp, tmp2, tmp3;
1928
1929         tmp = ecf_square(y, curve);
1930         if (!tmp) {
1931             return NULL;
1932         }
1933         tmp2 = modmul(curve->e.d, tmp, curve->p);
1934         if (!tmp2) {
1935             freebn(tmp);
1936             return NULL;
1937         }
1938         tmp3 = ecf_add(tmp2, One, curve);
1939         freebn(tmp2);
1940         if (!tmp3) {
1941             freebn(tmp);
1942             return NULL;
1943         }
1944         tmp2 = modinv(tmp3, curve->p);
1945         freebn(tmp3);
1946         if (!tmp2) {
1947             freebn(tmp);
1948             return NULL;
1949         }
1950
1951         tmp3 = modsub(tmp, One, curve->p);
1952         freebn(tmp);
1953         if (!tmp3) {
1954             freebn(tmp2);
1955             return NULL;
1956         }
1957         xx = modmul(tmp3, tmp2, curve->p);
1958         freebn(tmp3);
1959         freebn(tmp2);
1960         if (!xx) {
1961             return NULL;
1962         }
1963     }
1964
1965     /* x = xx^((p + 3) / 8) */
1966     {
1967         Bignum tmp, tmp2;
1968
1969         tmp = bignum_add_long(curve->p, 3);
1970         if (!tmp) {
1971             freebn(xx);
1972             return NULL;
1973         }
1974         tmp2 = bignum_rshift(tmp, 3);
1975         freebn(tmp);
1976         if (!tmp2) {
1977             freebn(xx);
1978             return NULL;
1979         }
1980         x = modpow(xx, tmp2, curve->p);
1981         freebn(tmp2);
1982         if (!x) {
1983             freebn(xx);
1984             return NULL;
1985         }
1986     }
1987
1988     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1989     {
1990         Bignum tmp, tmp2;
1991
1992         tmp = ecf_square(x, curve);
1993         if (!tmp) {
1994             freebn(x);
1995             freebn(xx);
1996             return NULL;
1997         }
1998         tmp2 = modsub(tmp, xx, curve->p);
1999         freebn(tmp);
2000         freebn(xx);
2001         if (!tmp2) {
2002             freebn(x);
2003             return NULL;
2004         }
2005         if (bignum_cmp(tmp2, Zero)) {
2006             Bignum tmp3;
2007
2008             freebn(tmp2);
2009
2010             tmp = modsub(curve->p, One, curve->p);
2011             if (!tmp) {
2012                 freebn(x);
2013                 return NULL;
2014             }
2015             tmp2 = bignum_rshift(tmp, 2);
2016             freebn(tmp);
2017             if (!tmp2) {
2018                 freebn(x);
2019                 return NULL;
2020             }
2021             tmp = bignum_from_long(2);
2022             if (!tmp) {
2023                 freebn(tmp2);
2024                 freebn(x);
2025                 return NULL;
2026             }
2027             tmp3 = modpow(tmp, tmp2, curve->p);
2028             freebn(tmp);
2029             freebn(tmp2);
2030             if (!tmp3) {
2031                 freebn(x);
2032                 return NULL;
2033             }
2034
2035             tmp = modmul(x, tmp3, curve->p);
2036             freebn(x);
2037             freebn(tmp3);
2038             x = tmp;
2039             if (!tmp) {
2040                 return NULL;
2041             }
2042         } else {
2043             freebn(tmp2);
2044         }
2045     }
2046
2047     /* if x % 2 != 0 then x = p - x */
2048     if (bignum_bit(x, 0)) {
2049         Bignum tmp = modsub(curve->p, x, curve->p);
2050         freebn(x);
2051         x = tmp;
2052         if (!tmp) {
2053             return NULL;
2054         }
2055     }
2056
2057     return x;
2058 }
2059
2060 /* ----------------------------------------------------------------------
2061  * Public point from private
2062  */
2063
2064 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
2065 {
2066     if (curve->type == EC_WEIERSTRASS) {
2067         return ecp_mul(&curve->w.G, privateKey);
2068     } else if (curve->type == EC_EDWARDS) {
2069         /* hash = H(sk) (where hash creates 2 * fieldBits)
2070          * b = fieldBits
2071          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2072          * publicKey = aB */
2073         struct ec_point *ret;
2074         unsigned char hash[512/8];
2075         Bignum a;
2076         int i, keylen;
2077         SHA512_State s;
2078         SHA512_Init(&s);
2079
2080         keylen = curve->fieldBits / 8;
2081         for (i = 0; i < keylen; ++i) {
2082             unsigned char b = bignum_byte(privateKey, i);
2083             SHA512_Bytes(&s, &b, 1);
2084         }
2085         SHA512_Final(&s, hash);
2086
2087         /* The second part is simply turning the hash into a Bignum,
2088          * however the 2^(b-2) bit *must* be set, and the bottom 3
2089          * bits *must* not be */
2090         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2091         hash[31] &= 0x7f; /* Unset above (b-2) */
2092         hash[31] |= 0x40; /* Set 2^(b-2) */
2093         /* Chop off the top part and convert to int */
2094         a = bignum_from_bytes_le(hash, 32);
2095         if (!a) {
2096             return NULL;
2097         }
2098
2099         ret = ecp_mul(&curve->e.B, a);
2100         freebn(a);
2101         return ret;
2102     } else {
2103         return NULL;
2104     }
2105 }
2106
2107 /* ----------------------------------------------------------------------
2108  * Basic sign and verify routines
2109  */
2110
2111 static int _ecdsa_verify(const struct ec_point *publicKey,
2112                          const unsigned char *data, const int dataLen,
2113                          const Bignum r, const Bignum s)
2114 {
2115     int z_bits, n_bits;
2116     Bignum z;
2117     int valid = 0;
2118
2119     if (publicKey->curve->type != EC_WEIERSTRASS) {
2120         return 0;
2121     }
2122
2123     /* Sanity checks */
2124     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
2125         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
2126     {
2127         return 0;
2128     }
2129
2130     /* z = left most bitlen(curve->n) of data */
2131     z = bignum_from_bytes(data, dataLen);
2132     if (!z) return 0;
2133     n_bits = bignum_bitcount(publicKey->curve->w.n);
2134     z_bits = bignum_bitcount(z);
2135     if (z_bits > n_bits)
2136     {
2137         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
2138         freebn(z);
2139         z = tmp;
2140         if (!z) return 0;
2141     }
2142
2143     /* Ensure z in range of n */
2144     {
2145         Bignum tmp = bigmod(z, publicKey->curve->w.n);
2146         freebn(z);
2147         z = tmp;
2148         if (!z) return 0;
2149     }
2150
2151     /* Calculate signature */
2152     {
2153         Bignum w, x, u1, u2;
2154         struct ec_point *tmp;
2155
2156         w = modinv(s, publicKey->curve->w.n);
2157         if (!w) {
2158             freebn(z);
2159             return 0;
2160         }
2161         u1 = modmul(z, w, publicKey->curve->w.n);
2162         if (!u1) {
2163             freebn(z);
2164             freebn(w);
2165             return 0;
2166         }
2167         u2 = modmul(r, w, publicKey->curve->w.n);
2168         freebn(w);
2169         if (!u2) {
2170             freebn(z);
2171             freebn(u1);
2172             return 0;
2173         }
2174
2175         tmp = ecp_summul(u1, u2, publicKey);
2176         freebn(u1);
2177         freebn(u2);
2178         if (!tmp) {
2179             freebn(z);
2180             return 0;
2181         }
2182
2183         x = bigmod(tmp->x, publicKey->curve->w.n);
2184         ec_point_free(tmp);
2185         if (!x) {
2186             freebn(z);
2187             return 0;
2188         }
2189
2190         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
2191         freebn(x);
2192     }
2193
2194     freebn(z);
2195
2196     return valid;
2197 }
2198
2199 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
2200                         const unsigned char *data, const int dataLen,
2201                         Bignum *r, Bignum *s)
2202 {
2203     unsigned char digest[20];
2204     int z_bits, n_bits;
2205     Bignum z, k;
2206     struct ec_point *kG;
2207
2208     *r = NULL;
2209     *s = NULL;
2210
2211     if (curve->type != EC_WEIERSTRASS) {
2212         return;
2213     }
2214
2215     /* z = left most bitlen(curve->n) of data */
2216     z = bignum_from_bytes(data, dataLen);
2217     if (!z) return;
2218     n_bits = bignum_bitcount(curve->w.n);
2219     z_bits = bignum_bitcount(z);
2220     if (z_bits > n_bits)
2221     {
2222         Bignum tmp;
2223         tmp = bignum_rshift(z, z_bits - n_bits);
2224         freebn(z);
2225         z = tmp;
2226         if (!z) return;
2227     }
2228
2229     /* Generate k between 1 and curve->n, using the same deterministic
2230      * k generation system we use for conventional DSA. */
2231     SHA_Simple(data, dataLen, digest);
2232     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
2233                   digest, sizeof(digest));
2234     if (!k) return;
2235
2236     kG = ecp_mul(&curve->w.G, k);
2237     if (!kG) {
2238         freebn(z);
2239         freebn(k);
2240         return;
2241     }
2242
2243     /* r = kG.x mod n */
2244     *r = bigmod(kG->x, curve->w.n);
2245     ec_point_free(kG);
2246     if (!*r) {
2247         freebn(z);
2248         freebn(k);
2249         return;
2250     }
2251
2252     /* s = (z + r * priv)/k mod n */
2253     {
2254         Bignum rPriv, zMod, first, firstMod, kInv;
2255         rPriv = modmul(*r, privateKey, curve->w.n);
2256         if (!rPriv) {
2257             freebn(*r);
2258             freebn(z);
2259             freebn(k);
2260             return;
2261         }
2262         zMod = bigmod(z, curve->w.n);
2263         freebn(z);
2264         if (!zMod) {
2265             freebn(rPriv);
2266             freebn(*r);
2267             freebn(k);
2268             return;
2269         }
2270         first = bigadd(rPriv, zMod);
2271         freebn(rPriv);
2272         freebn(zMod);
2273         if (!first) {
2274             freebn(*r);
2275             freebn(k);
2276             return;
2277         }
2278         firstMod = bigmod(first, curve->w.n);
2279         freebn(first);
2280         if (!firstMod) {
2281             freebn(*r);
2282             freebn(k);
2283             return;
2284         }
2285         kInv = modinv(k, curve->w.n);
2286         freebn(k);
2287         if (!kInv) {
2288             freebn(firstMod);
2289             freebn(*r);
2290             return;
2291         }
2292         *s = modmul(firstMod, kInv, curve->w.n);
2293         freebn(firstMod);
2294         freebn(kInv);
2295         if (!*s) {
2296             freebn(*r);
2297             return;
2298         }
2299     }
2300 }
2301
2302 /* ----------------------------------------------------------------------
2303  * Misc functions
2304  */
2305
2306 static void getstring(const char **data, int *datalen,
2307                       const char **p, int *length)
2308 {
2309     *p = NULL;
2310     if (*datalen < 4)
2311         return;
2312     *length = toint(GET_32BIT(*data));
2313     if (*length < 0)
2314         return;
2315     *datalen -= 4;
2316     *data += 4;
2317     if (*datalen < *length)
2318         return;
2319     *p = *data;
2320     *data += *length;
2321     *datalen -= *length;
2322 }
2323
2324 static Bignum getmp(const char **data, int *datalen)
2325 {
2326     const char *p;
2327     int length;
2328
2329     getstring(data, datalen, &p, &length);
2330     if (!p)
2331         return NULL;
2332     if (p[0] & 0x80)
2333         return NULL;                   /* negative mp */
2334     return bignum_from_bytes((unsigned char *)p, length);
2335 }
2336
2337 static Bignum getmp_le(const char **data, int *datalen)
2338 {
2339     const char *p;
2340     int length;
2341
2342     getstring(data, datalen, &p, &length);
2343     if (!p)
2344         return NULL;
2345     return bignum_from_bytes_le((const unsigned char *)p, length);
2346 }
2347
2348 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
2349 {
2350     /* Got some conversion to do, first read in the y co-ord */
2351     int negative;
2352
2353     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
2354     if (!point->y) {
2355         return 0;
2356     }
2357     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
2358         freebn(point->y);
2359         point->y = NULL;
2360         return 0;
2361     }
2362     /* Read x bit and then reset it */
2363     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
2364     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
2365
2366     /* Get the x from the y */
2367     point->x = ecp_edx(point->curve, point->y);
2368     if (!point->x) {
2369         freebn(point->y);
2370         point->y = NULL;
2371         return 0;
2372     }
2373     if (negative) {
2374         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
2375         freebn(point->x);
2376         point->x = tmp;
2377         if (!tmp) {
2378             freebn(point->y);
2379             point->y = NULL;
2380             return 0;
2381         }
2382     }
2383
2384     /* Verify the point is on the curve */
2385     if (!ec_point_verify(point)) {
2386         freebn(point->x);
2387         point->x = NULL;
2388         freebn(point->y);
2389         point->y = NULL;
2390         return 0;
2391     }
2392
2393     return 1;
2394 }
2395
2396 static int decodepoint(const char *p, int length, struct ec_point *point)
2397 {
2398     if (point->curve->type == EC_EDWARDS) {
2399         return decodepoint_ed(p, length, point);
2400     }
2401
2402     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
2403         return 0;
2404     /* Skip compression flag */
2405     ++p;
2406     --length;
2407     /* The two values must be equal length */
2408     if (length % 2 != 0) {
2409         point->x = NULL;
2410         point->y = NULL;
2411         point->z = NULL;
2412         return 0;
2413     }
2414     length = length / 2;
2415     point->x = bignum_from_bytes((const unsigned char *)p, length);
2416     if (!point->x) return 0;
2417     p += length;
2418     point->y = bignum_from_bytes((const unsigned char *)p, length);
2419     if (!point->y) {
2420         freebn(point->x);
2421         point->x = NULL;
2422         return 0;
2423     }
2424     point->z = NULL;
2425
2426     /* Verify the point is on the curve */
2427     if (!ec_point_verify(point)) {
2428         freebn(point->x);
2429         point->x = NULL;
2430         freebn(point->y);
2431         point->y = NULL;
2432         return 0;
2433     }
2434
2435     return 1;
2436 }
2437
2438 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
2439 {
2440     const char *p;
2441     int length;
2442
2443     getstring(data, datalen, &p, &length);
2444     if (!p) return 0;
2445     return decodepoint(p, length, point);
2446 }
2447
2448 /* ----------------------------------------------------------------------
2449  * Exposed ECDSA interface
2450  */
2451
2452 struct ecsign_extra {
2453     struct ec_curve *(*curve)(void);
2454     const struct ssh_hash *hash;
2455
2456     /* These fields are used by the OpenSSH PEM format importer/exporter */
2457     const unsigned char *oid;
2458     int oidlen;
2459 };
2460
2461 static void ecdsa_freekey(void *key)
2462 {
2463     struct ec_key *ec = (struct ec_key *) key;
2464     if (!ec) return;
2465
2466     if (ec->publicKey.x)
2467         freebn(ec->publicKey.x);
2468     if (ec->publicKey.y)
2469         freebn(ec->publicKey.y);
2470     if (ec->publicKey.z)
2471         freebn(ec->publicKey.z);
2472     if (ec->privateKey)
2473         freebn(ec->privateKey);
2474     sfree(ec);
2475 }
2476
2477 static void *ecdsa_newkey(const struct ssh_signkey *self,
2478                           const char *data, int len)
2479 {
2480     const struct ecsign_extra *extra =
2481         (const struct ecsign_extra *)self->extra;
2482     const char *p;
2483     int slen;
2484     struct ec_key *ec;
2485     struct ec_curve *curve;
2486
2487     getstring(&data, &len, &p, &slen);
2488
2489     if (!p) {
2490         return NULL;
2491     }
2492     curve = extra->curve();
2493     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
2494
2495     /* Curve name is duplicated for Weierstrass form */
2496     if (curve->type == EC_WEIERSTRASS) {
2497         getstring(&data, &len, &p, &slen);
2498         if (!match_ssh_id(slen, p, curve->name)) return NULL;
2499     }
2500
2501     ec = snew(struct ec_key);
2502
2503     ec->signalg = self;
2504     ec->publicKey.curve = curve;
2505     ec->publicKey.infinity = 0;
2506     ec->publicKey.x = NULL;
2507     ec->publicKey.y = NULL;
2508     ec->publicKey.z = NULL;
2509     if (!getmppoint(&data, &len, &ec->publicKey)) {
2510         ecdsa_freekey(ec);
2511         return NULL;
2512     }
2513     ec->privateKey = NULL;
2514
2515     if (!ec->publicKey.x || !ec->publicKey.y ||
2516         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2517         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2518     {
2519         ecdsa_freekey(ec);
2520         ec = NULL;
2521     }
2522
2523     return ec;
2524 }
2525
2526 static char *ecdsa_fmtkey(void *key)
2527 {
2528     struct ec_key *ec = (struct ec_key *) key;
2529     char *p;
2530     int len, i, pos, nibbles;
2531     static const char hex[] = "0123456789abcdef";
2532     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2533         return NULL;
2534
2535     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
2536     if (ec->publicKey.curve->name)
2537         len += strlen(ec->publicKey.curve->name); /* Curve name */
2538     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
2539     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
2540     p = snewn(len, char);
2541
2542     pos = 0;
2543     if (ec->publicKey.curve->name)
2544         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
2545     pos += sprintf(p + pos, "0x");
2546     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
2547     if (nibbles < 1)
2548         nibbles = 1;
2549     for (i = nibbles; i--;) {
2550         p[pos++] =
2551             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
2552     }
2553     pos += sprintf(p + pos, ",0x");
2554     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
2555     if (nibbles < 1)
2556         nibbles = 1;
2557     for (i = nibbles; i--;) {
2558         p[pos++] =
2559             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
2560     }
2561     p[pos] = '\0';
2562     return p;
2563 }
2564
2565 static unsigned char *ecdsa_public_blob(void *key, int *len)
2566 {
2567     struct ec_key *ec = (struct ec_key *) key;
2568     int pointlen, bloblen, fullnamelen, namelen;
2569     int i;
2570     unsigned char *blob, *p;
2571
2572     fullnamelen = strlen(ec->signalg->name);
2573
2574     if (ec->publicKey.curve->type == EC_EDWARDS) {
2575         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
2576
2577         pointlen = ec->publicKey.curve->fieldBits / 8;
2578
2579         /* Can't handle this in our loop */
2580         if (pointlen < 2) return NULL;
2581
2582         bloblen = 4 + fullnamelen + 4 + pointlen;
2583         blob = snewn(bloblen, unsigned char);
2584         if (!blob) return NULL;
2585
2586         p = blob;
2587         PUT_32BIT(p, fullnamelen);
2588         p += 4;
2589         memcpy(p, ec->signalg->name, fullnamelen);
2590         p += fullnamelen;
2591         PUT_32BIT(p, pointlen);
2592         p += 4;
2593
2594         /* Unset last bit of y and set first bit of x in its place */
2595         for (i = 0; i < pointlen - 1; ++i) {
2596             *p++ = bignum_byte(ec->publicKey.y, i);
2597         }
2598         /* Unset last bit of y and set first bit of x in its place */
2599         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
2600         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2601     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2602         assert(ec->publicKey.curve->name);
2603         namelen = strlen(ec->publicKey.curve->name);
2604
2605         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2606
2607         /*
2608          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
2609          */
2610         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
2611         blob = snewn(bloblen, unsigned char);
2612
2613         p = blob;
2614         PUT_32BIT(p, fullnamelen);
2615         p += 4;
2616         memcpy(p, ec->signalg->name, fullnamelen);
2617         p += fullnamelen;
2618         PUT_32BIT(p, namelen);
2619         p += 4;
2620         memcpy(p, ec->publicKey.curve->name, namelen);
2621         p += namelen;
2622         PUT_32BIT(p, (2 * pointlen) + 1);
2623         p += 4;
2624         *p++ = 0x04;
2625         for (i = pointlen; i--;) {
2626             *p++ = bignum_byte(ec->publicKey.x, i);
2627         }
2628         for (i = pointlen; i--;) {
2629             *p++ = bignum_byte(ec->publicKey.y, i);
2630         }
2631     } else {
2632         return NULL;
2633     }
2634
2635     assert(p == blob + bloblen);
2636     *len = bloblen;
2637
2638     return blob;
2639 }
2640
2641 static unsigned char *ecdsa_private_blob(void *key, int *len)
2642 {
2643     struct ec_key *ec = (struct ec_key *) key;
2644     int keylen, bloblen;
2645     int i;
2646     unsigned char *blob, *p;
2647
2648     if (!ec->privateKey) return NULL;
2649
2650     if (ec->publicKey.curve->type == EC_EDWARDS) {
2651         /* Unsigned */
2652         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2653     } else {
2654         /* Signed */
2655         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2656     }
2657
2658     /*
2659      * mpint privateKey. Total 4 + keylen.
2660      */
2661     bloblen = 4 + keylen;
2662     blob = snewn(bloblen, unsigned char);
2663
2664     p = blob;
2665     PUT_32BIT(p, keylen);
2666     p += 4;
2667     if (ec->publicKey.curve->type == EC_EDWARDS) {
2668         /* Little endian */
2669         for (i = 0; i < keylen; ++i)
2670             *p++ = bignum_byte(ec->privateKey, i);
2671     } else {
2672         for (i = keylen; i--;)
2673             *p++ = bignum_byte(ec->privateKey, i);
2674     }
2675
2676     assert(p == blob + bloblen);
2677     *len = bloblen;
2678     return blob;
2679 }
2680
2681 static void *ecdsa_createkey(const struct ssh_signkey *self,
2682                              const unsigned char *pub_blob, int pub_len,
2683                              const unsigned char *priv_blob, int priv_len)
2684 {
2685     struct ec_key *ec;
2686     struct ec_point *publicKey;
2687     const char *pb = (const char *) priv_blob;
2688
2689     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
2690     if (!ec) {
2691         return NULL;
2692     }
2693
2694     if (ec->publicKey.curve->type != EC_WEIERSTRASS
2695         && ec->publicKey.curve->type != EC_EDWARDS) {
2696         ecdsa_freekey(ec);
2697         return NULL;
2698     }
2699
2700     if (ec->publicKey.curve->type == EC_EDWARDS) {
2701         ec->privateKey = getmp_le(&pb, &priv_len);
2702     } else {
2703         ec->privateKey = getmp(&pb, &priv_len);
2704     }
2705     if (!ec->privateKey) {
2706         ecdsa_freekey(ec);
2707         return NULL;
2708     }
2709
2710     /* Check that private key generates public key */
2711     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2712
2713     if (!publicKey ||
2714         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2715         bignum_cmp(publicKey->y, ec->publicKey.y))
2716     {
2717         ecdsa_freekey(ec);
2718         ec = NULL;
2719     }
2720     ec_point_free(publicKey);
2721
2722     return ec;
2723 }
2724
2725 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2726                                        const unsigned char **blob, int *len)
2727 {
2728     struct ec_key *ec;
2729     struct ec_point *publicKey;
2730     const char *p, *q;
2731     int plen, qlen;
2732
2733     getstring((const char**)blob, len, &p, &plen);
2734     if (!p)
2735     {
2736         return NULL;
2737     }
2738
2739     ec = snew(struct ec_key);
2740     if (!ec)
2741     {
2742         return NULL;
2743     }
2744
2745     ec->signalg = self;
2746     ec->publicKey.curve = ec_ed25519();
2747     ec->publicKey.infinity = 0;
2748     ec->privateKey = NULL;
2749     ec->publicKey.x = NULL;
2750     ec->publicKey.z = NULL;
2751     ec->publicKey.y = NULL;
2752
2753     if (!decodepoint_ed(p, plen, &ec->publicKey))
2754     {
2755         ecdsa_freekey(ec);
2756         return NULL;
2757     }
2758
2759     getstring((const char**)blob, len, &q, &qlen);
2760     if (!q)
2761         return NULL;
2762     if (qlen != 64)
2763         return NULL;
2764
2765     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2766     if (!ec->privateKey) {
2767         ecdsa_freekey(ec);
2768         return NULL;
2769     }
2770
2771     /* Check that private key generates public key */
2772     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2773
2774     if (!publicKey ||
2775         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2776         bignum_cmp(publicKey->y, ec->publicKey.y))
2777     {
2778         ecdsa_freekey(ec);
2779         ec = NULL;
2780     }
2781     ec_point_free(publicKey);
2782
2783     /* The OpenSSH format for ed25519 private keys also for some
2784      * reason encodes an extra copy of the public key in the second
2785      * half of the secret-key string. Check that that's present and
2786      * correct as well, otherwise the key we think we've imported
2787      * won't behave identically to the way OpenSSH would have treated
2788      * it. */
2789     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2790         ecdsa_freekey(ec);
2791         return NULL;
2792     }
2793
2794     return ec;
2795 }
2796
2797 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2798 {
2799     struct ec_key *ec = (struct ec_key *) key;
2800
2801     int pointlen;
2802     int keylen;
2803     int bloblen;
2804     int i;
2805
2806     if (ec->publicKey.curve->type != EC_EDWARDS) {
2807         return 0;
2808     }
2809
2810     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2811     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2812     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2813
2814     if (bloblen > len)
2815         return bloblen;
2816
2817     /* Encode the public point */
2818     PUT_32BIT(blob, pointlen);
2819     blob += 4;
2820
2821     for (i = 0; i < pointlen - 1; ++i) {
2822          *blob++ = bignum_byte(ec->publicKey.y, i);
2823     }
2824     /* Unset last bit of y and set first bit of x in its place */
2825     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2826     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2827
2828     PUT_32BIT(blob, keylen + pointlen);
2829     blob += 4;
2830     for (i = 0; i < keylen; ++i) {
2831          *blob++ = bignum_byte(ec->privateKey, i);
2832     }
2833     /* Now encode an extra copy of the public point as the second half
2834      * of the private key string, as the OpenSSH format for some
2835      * reason requires */
2836     for (i = 0; i < pointlen - 1; ++i) {
2837          *blob++ = bignum_byte(ec->publicKey.y, i);
2838     }
2839     /* Unset last bit of y and set first bit of x in its place */
2840     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2841     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2842
2843     return bloblen;
2844 }
2845
2846 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2847                                      const unsigned char **blob, int *len)
2848 {
2849     const struct ecsign_extra *extra =
2850         (const struct ecsign_extra *)self->extra;
2851     const char **b = (const char **) blob;
2852     const char *p;
2853     int slen;
2854     struct ec_key *ec;
2855     struct ec_curve *curve;
2856     struct ec_point *publicKey;
2857
2858     getstring(b, len, &p, &slen);
2859
2860     if (!p) {
2861         return NULL;
2862     }
2863     curve = extra->curve();
2864     assert(curve->type == EC_WEIERSTRASS);
2865
2866     ec = snew(struct ec_key);
2867
2868     ec->signalg = self;
2869     ec->publicKey.curve = curve;
2870     ec->publicKey.infinity = 0;
2871     ec->publicKey.x = NULL;
2872     ec->publicKey.y = NULL;
2873     ec->publicKey.z = NULL;
2874     if (!getmppoint(b, len, &ec->publicKey)) {
2875         ecdsa_freekey(ec);
2876         return NULL;
2877     }
2878     ec->privateKey = NULL;
2879
2880     if (!ec->publicKey.x || !ec->publicKey.y ||
2881         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2882         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2883     {
2884         ecdsa_freekey(ec);
2885         return NULL;
2886     }
2887
2888     ec->privateKey = getmp(b, len);
2889     if (ec->privateKey == NULL)
2890     {
2891         ecdsa_freekey(ec);
2892         return NULL;
2893     }
2894
2895     /* Now check that the private key makes the public key */
2896     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2897     if (!publicKey)
2898     {
2899         ecdsa_freekey(ec);
2900         return NULL;
2901     }
2902
2903     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2904         bignum_cmp(ec->publicKey.y, publicKey->y))
2905     {
2906         /* Private key doesn't make the public key on the given curve */
2907         ecdsa_freekey(ec);
2908         ec_point_free(publicKey);
2909         return NULL;
2910     }
2911
2912     ec_point_free(publicKey);
2913
2914     return ec;
2915 }
2916
2917 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2918 {
2919     struct ec_key *ec = (struct ec_key *) key;
2920
2921     int pointlen;
2922     int namelen;
2923     int bloblen;
2924     int i;
2925
2926     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2927         return 0;
2928     }
2929
2930     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2931     namelen = strlen(ec->publicKey.curve->name);
2932     bloblen =
2933         4 + namelen /* <LEN> nistpXXX */
2934         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2935         + ssh2_bignum_length(ec->privateKey);
2936
2937     if (bloblen > len)
2938         return bloblen;
2939
2940     bloblen = 0;
2941
2942     PUT_32BIT(blob+bloblen, namelen);
2943     bloblen += 4;
2944     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2945     bloblen += namelen;
2946
2947     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2948     bloblen += 4;
2949     blob[bloblen++] = 0x04;
2950     for (i = pointlen; i--; )
2951         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2952     for (i = pointlen; i--; )
2953         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2954
2955     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2956     PUT_32BIT(blob+bloblen, pointlen);
2957     bloblen += 4;
2958     for (i = pointlen; i--; )
2959         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2960
2961     return bloblen;
2962 }
2963
2964 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2965                              const void *blob, int len)
2966 {
2967     struct ec_key *ec;
2968     int ret;
2969
2970     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2971     if (!ec)
2972         return -1;
2973     ret = ec->publicKey.curve->fieldBits;
2974     ecdsa_freekey(ec);
2975
2976     return ret;
2977 }
2978
2979 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2980                            const char *data, int datalen)
2981 {
2982     struct ec_key *ec = (struct ec_key *) key;
2983     const struct ecsign_extra *extra =
2984         (const struct ecsign_extra *)ec->signalg->extra;
2985     const char *p;
2986     int slen;
2987     int digestLen;
2988     int ret;
2989
2990     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2991         return 0;
2992
2993     /* Check the signature starts with the algorithm name */
2994     getstring(&sig, &siglen, &p, &slen);
2995     if (!p) {
2996         return 0;
2997     }
2998     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2999         return 0;
3000     }
3001
3002     getstring(&sig, &siglen, &p, &slen);
3003     if (ec->publicKey.curve->type == EC_EDWARDS) {
3004         struct ec_point *r;
3005         Bignum s, h;
3006
3007         /* Check that the signature is two times the length of a point */
3008         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
3009             return 0;
3010         }
3011
3012         /* Check it's the 256 bit field so that SHA512 is the correct hash */
3013         if (ec->publicKey.curve->fieldBits != 256) {
3014             return 0;
3015         }
3016
3017         /* Get the signature */
3018         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
3019         if (!r) {
3020             return 0;
3021         }
3022         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
3023             ec_point_free(r);
3024             return 0;
3025         }
3026         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
3027                                  ec->publicKey.curve->fieldBits / 8);
3028         if (!s) {
3029             ec_point_free(r);
3030             return 0;
3031         }
3032
3033         /* Get the hash of the encoded value of R + encoded value of pk + message */
3034         {
3035             int i, pointlen;
3036             unsigned char b;
3037             unsigned char digest[512 / 8];
3038             SHA512_State hs;
3039             SHA512_Init(&hs);
3040
3041             /* Add encoded r (no need to encode it again, it was in the signature) */
3042             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
3043
3044             /* Encode pk and add it */
3045             pointlen = ec->publicKey.curve->fieldBits / 8;
3046             for (i = 0; i < pointlen - 1; ++i) {
3047                 b = bignum_byte(ec->publicKey.y, i);
3048                 SHA512_Bytes(&hs, &b, 1);
3049             }
3050             /* Unset last bit of y and set first bit of x in its place */
3051             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3052             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3053             SHA512_Bytes(&hs, &b, 1);
3054
3055             /* Add the message itself */
3056             SHA512_Bytes(&hs, data, datalen);
3057
3058             /* Get the hash */
3059             SHA512_Final(&hs, digest);
3060
3061             /* Convert to Bignum */
3062             h = bignum_from_bytes_le(digest, sizeof(digest));
3063             if (!h) {
3064                 ec_point_free(r);
3065                 freebn(s);
3066                 return 0;
3067             }
3068         }
3069
3070         /* Verify sB == r + h*publicKey */
3071         {
3072             struct ec_point *lhs, *rhs, *tmp;
3073
3074             /* lhs = sB */
3075             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
3076             freebn(s);
3077             if (!lhs) {
3078                 ec_point_free(r);
3079                 freebn(h);
3080                 return 0;
3081             }
3082
3083             /* rhs = r + h*publicKey */
3084             tmp = ecp_mul(&ec->publicKey, h);
3085             freebn(h);
3086             if (!tmp) {
3087                 ec_point_free(lhs);
3088                 ec_point_free(r);
3089                 return 0;
3090             }
3091             rhs = ecp_add(r, tmp, 0);
3092             ec_point_free(r);
3093             ec_point_free(tmp);
3094             if (!rhs) {
3095                 ec_point_free(lhs);
3096                 return 0;
3097             }
3098
3099             /* Check the point is the same */
3100             ret = !bignum_cmp(lhs->x, rhs->x);
3101             if (ret) {
3102                 ret = !bignum_cmp(lhs->y, rhs->y);
3103                 if (ret) {
3104                     ret = 1;
3105                 }
3106             }
3107             ec_point_free(lhs);
3108             ec_point_free(rhs);
3109         }
3110     } else {
3111         Bignum r, s;
3112         unsigned char digest[512 / 8];
3113         void *hashctx;
3114
3115         r = getmp(&p, &slen);
3116         if (!r) return 0;
3117         s = getmp(&p, &slen);
3118         if (!s) {
3119             freebn(r);
3120             return 0;
3121         }
3122
3123         digestLen = extra->hash->hlen;
3124         assert(digestLen <= sizeof(digest));
3125         hashctx = extra->hash->init();
3126         extra->hash->bytes(hashctx, data, datalen);
3127         extra->hash->final(hashctx, digest);
3128
3129         /* Verify the signature */
3130         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
3131
3132         freebn(r);
3133         freebn(s);
3134     }
3135
3136     return ret;
3137 }
3138
3139 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
3140                                  int *siglen)
3141 {
3142     struct ec_key *ec = (struct ec_key *) key;
3143     const struct ecsign_extra *extra =
3144         (const struct ecsign_extra *)ec->signalg->extra;
3145     unsigned char digest[512 / 8];
3146     int digestLen;
3147     Bignum r = NULL, s = NULL;
3148     unsigned char *buf, *p;
3149     int rlen, slen, namelen;
3150     int i;
3151
3152     if (!ec->privateKey || !ec->publicKey.curve) {
3153         return NULL;
3154     }
3155
3156     if (ec->publicKey.curve->type == EC_EDWARDS) {
3157         struct ec_point *rp;
3158         int pointlen = ec->publicKey.curve->fieldBits / 8;
3159
3160         /* hash = H(sk) (where hash creates 2 * fieldBits)
3161          * b = fieldBits
3162          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
3163          * r = H(h[b/8:b/4] + m)
3164          * R = rB
3165          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
3166         {
3167             unsigned char hash[512/8];
3168             unsigned char b;
3169             Bignum a;
3170             SHA512_State hs;
3171             SHA512_Init(&hs);
3172
3173             for (i = 0; i < pointlen; ++i) {
3174                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
3175                 SHA512_Bytes(&hs, &b, 1);
3176             }
3177
3178             SHA512_Final(&hs, hash);
3179
3180             /* The second part is simply turning the hash into a
3181              * Bignum, however the 2^(b-2) bit *must* be set, and the
3182              * bottom 3 bits *must* not be */
3183             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
3184             hash[31] &= 0x7f; /* Unset above (b-2) */
3185             hash[31] |= 0x40; /* Set 2^(b-2) */
3186             /* Chop off the top part and convert to int */
3187             a = bignum_from_bytes_le(hash, 32);
3188             if (!a) {
3189                 return NULL;
3190             }
3191
3192             SHA512_Init(&hs);
3193             SHA512_Bytes(&hs,
3194                          hash+(ec->publicKey.curve->fieldBits / 8),
3195                          (ec->publicKey.curve->fieldBits / 4)
3196                          - (ec->publicKey.curve->fieldBits / 8));
3197             SHA512_Bytes(&hs, data, datalen);
3198             SHA512_Final(&hs, hash);
3199
3200             r = bignum_from_bytes_le(hash, 512/8);
3201             if (!r) {
3202                 freebn(a);
3203                 return NULL;
3204             }
3205             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
3206             if (!rp) {
3207                 freebn(r);
3208                 freebn(a);
3209                 return NULL;
3210             }
3211
3212             /* Now calculate s */
3213             SHA512_Init(&hs);
3214             /* Encode the point R */
3215             for (i = 0; i < pointlen - 1; ++i) {
3216                 b = bignum_byte(rp->y, i);
3217                 SHA512_Bytes(&hs, &b, 1);
3218             }
3219             /* Unset last bit of y and set first bit of x in its place */
3220             b = bignum_byte(rp->y, i) & 0x7f;
3221             b |= bignum_bit(rp->x, 0) << 7;
3222             SHA512_Bytes(&hs, &b, 1);
3223
3224             /* Encode the point pk */
3225             for (i = 0; i < pointlen - 1; ++i) {
3226                 b = bignum_byte(ec->publicKey.y, i);
3227                 SHA512_Bytes(&hs, &b, 1);
3228             }
3229             /* Unset last bit of y and set first bit of x in its place */
3230             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3231             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3232             SHA512_Bytes(&hs, &b, 1);
3233
3234             /* Add the message */
3235             SHA512_Bytes(&hs, data, datalen);
3236             SHA512_Final(&hs, hash);
3237
3238             {
3239                 Bignum tmp, tmp2;
3240
3241                 tmp = bignum_from_bytes_le(hash, 512/8);
3242                 if (!tmp) {
3243                     ec_point_free(rp);
3244                     freebn(r);
3245                     freebn(a);
3246                     return NULL;
3247                 }
3248                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
3249                 freebn(a);
3250                 freebn(tmp);
3251                 if (!tmp2) {
3252                     ec_point_free(rp);
3253                     freebn(r);
3254                     return NULL;
3255                 }
3256                 tmp = bigadd(r, tmp2);
3257                 freebn(r);
3258                 freebn(tmp2);
3259                 if (!tmp) {
3260                     ec_point_free(rp);
3261                     return NULL;
3262                 }
3263                 s = bigmod(tmp, ec->publicKey.curve->e.l);
3264                 freebn(tmp);
3265                 if (!s) {
3266                     ec_point_free(rp);
3267                     return NULL;
3268                 }
3269             }
3270         }
3271
3272         /* Format the output */
3273         namelen = strlen(ec->signalg->name);
3274         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
3275         buf = snewn(*siglen, unsigned char);
3276         p = buf;
3277         PUT_32BIT(p, namelen);
3278         p += 4;
3279         memcpy(p, ec->signalg->name, namelen);
3280         p += namelen;
3281         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
3282         p += 4;
3283
3284         /* Encode the point */
3285         pointlen = ec->publicKey.curve->fieldBits / 8;
3286         for (i = 0; i < pointlen - 1; ++i) {
3287             *p++ = bignum_byte(rp->y, i);
3288         }
3289         /* Unset last bit of y and set first bit of x in its place */
3290         *p = bignum_byte(rp->y, i) & 0x7f;
3291         *p++ |= bignum_bit(rp->x, 0) << 7;
3292         ec_point_free(rp);
3293
3294         /* Encode the int */
3295         for (i = 0; i < pointlen; ++i) {
3296             *p++ = bignum_byte(s, i);
3297         }
3298         freebn(s);
3299     } else {
3300         void *hashctx;
3301
3302         digestLen = extra->hash->hlen;
3303         assert(digestLen <= sizeof(digest));
3304         hashctx = extra->hash->init();
3305         extra->hash->bytes(hashctx, data, datalen);
3306         extra->hash->final(hashctx, digest);
3307
3308         /* Do the signature */
3309         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
3310         if (!r || !s) {
3311             if (r) freebn(r);
3312             if (s) freebn(s);
3313             return NULL;
3314         }
3315
3316         rlen = (bignum_bitcount(r) + 8) / 8;
3317         slen = (bignum_bitcount(s) + 8) / 8;
3318
3319         namelen = strlen(ec->signalg->name);
3320
3321         /* Format the output */
3322         *siglen = 8+namelen+rlen+slen+8;
3323         buf = snewn(*siglen, unsigned char);
3324         p = buf;
3325         PUT_32BIT(p, namelen);
3326         p += 4;
3327         memcpy(p, ec->signalg->name, namelen);
3328         p += namelen;
3329         PUT_32BIT(p, rlen + slen + 8);
3330         p += 4;
3331         PUT_32BIT(p, rlen);
3332         p += 4;
3333         for (i = rlen; i--;)
3334             *p++ = bignum_byte(r, i);
3335         PUT_32BIT(p, slen);
3336         p += 4;
3337         for (i = slen; i--;)
3338             *p++ = bignum_byte(s, i);
3339
3340         freebn(r);
3341         freebn(s);
3342     }
3343
3344     return buf;
3345 }
3346
3347 const struct ecsign_extra sign_extra_ed25519 = {
3348     ec_ed25519, NULL,
3349     NULL, 0,
3350 };
3351 const struct ssh_signkey ssh_ecdsa_ed25519 = {
3352     ecdsa_newkey,
3353     ecdsa_freekey,
3354     ecdsa_fmtkey,
3355     ecdsa_public_blob,
3356     ecdsa_private_blob,
3357     ecdsa_createkey,
3358     ed25519_openssh_createkey,
3359     ed25519_openssh_fmtkey,
3360     2 /* point, private exponent */,
3361     ecdsa_pubkey_bits,
3362     ecdsa_verifysig,
3363     ecdsa_sign,
3364     "ssh-ed25519",
3365     "ssh-ed25519",
3366     &sign_extra_ed25519,
3367 };
3368
3369 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
3370 static const unsigned char nistp256_oid[] = {
3371     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
3372 };
3373 const struct ecsign_extra sign_extra_nistp256 = {
3374     ec_p256, &ssh_sha256,
3375     nistp256_oid, lenof(nistp256_oid),
3376 };
3377 const struct ssh_signkey ssh_ecdsa_nistp256 = {
3378     ecdsa_newkey,
3379     ecdsa_freekey,
3380     ecdsa_fmtkey,
3381     ecdsa_public_blob,
3382     ecdsa_private_blob,
3383     ecdsa_createkey,
3384     ecdsa_openssh_createkey,
3385     ecdsa_openssh_fmtkey,
3386     3 /* curve name, point, private exponent */,
3387     ecdsa_pubkey_bits,
3388     ecdsa_verifysig,
3389     ecdsa_sign,
3390     "ecdsa-sha2-nistp256",
3391     "ecdsa-sha2-nistp256",
3392     &sign_extra_nistp256,
3393 };
3394
3395 /* OID: 1.3.132.0.34 (secp384r1) */
3396 static const unsigned char nistp384_oid[] = {
3397     0x2b, 0x81, 0x04, 0x00, 0x22
3398 };
3399 const struct ecsign_extra sign_extra_nistp384 = {
3400     ec_p384, &ssh_sha384,
3401     nistp384_oid, lenof(nistp384_oid),
3402 };
3403 const struct ssh_signkey ssh_ecdsa_nistp384 = {
3404     ecdsa_newkey,
3405     ecdsa_freekey,
3406     ecdsa_fmtkey,
3407     ecdsa_public_blob,
3408     ecdsa_private_blob,
3409     ecdsa_createkey,
3410     ecdsa_openssh_createkey,
3411     ecdsa_openssh_fmtkey,
3412     3 /* curve name, point, private exponent */,
3413     ecdsa_pubkey_bits,
3414     ecdsa_verifysig,
3415     ecdsa_sign,
3416     "ecdsa-sha2-nistp384",
3417     "ecdsa-sha2-nistp384",
3418     &sign_extra_nistp384,
3419 };
3420
3421 /* OID: 1.3.132.0.35 (secp521r1) */
3422 static const unsigned char nistp521_oid[] = {
3423     0x2b, 0x81, 0x04, 0x00, 0x23
3424 };
3425 const struct ecsign_extra sign_extra_nistp521 = {
3426     ec_p521, &ssh_sha512,
3427     nistp521_oid, lenof(nistp521_oid),
3428 };
3429 const struct ssh_signkey ssh_ecdsa_nistp521 = {
3430     ecdsa_newkey,
3431     ecdsa_freekey,
3432     ecdsa_fmtkey,
3433     ecdsa_public_blob,
3434     ecdsa_private_blob,
3435     ecdsa_createkey,
3436     ecdsa_openssh_createkey,
3437     ecdsa_openssh_fmtkey,
3438     3 /* curve name, point, private exponent */,
3439     ecdsa_pubkey_bits,
3440     ecdsa_verifysig,
3441     ecdsa_sign,
3442     "ecdsa-sha2-nistp521",
3443     "ecdsa-sha2-nistp521",
3444     &sign_extra_nistp521,
3445 };
3446
3447 /* ----------------------------------------------------------------------
3448  * Exposed ECDH interface
3449  */
3450
3451 struct eckex_extra {
3452     struct ec_curve *(*curve)(void);
3453 };
3454
3455 static Bignum ecdh_calculate(const Bignum private,
3456                              const struct ec_point *public)
3457 {
3458     struct ec_point *p;
3459     Bignum ret;
3460     p = ecp_mul(public, private);
3461     if (!p) return NULL;
3462     ret = p->x;
3463     p->x = NULL;
3464
3465     if (p->curve->type == EC_MONTGOMERY) {
3466         /* Do conversion in network byte order */
3467         int i;
3468         int bytes = (bignum_bitcount(ret)+7) / 8;
3469         unsigned char *byteorder = snewn(bytes, unsigned char);
3470         if (!byteorder) {
3471             ec_point_free(p);
3472             freebn(ret);
3473             return NULL;
3474         }
3475         for (i = 0; i < bytes; ++i) {
3476             byteorder[i] = bignum_byte(ret, i);
3477         }
3478         freebn(ret);
3479         ret = bignum_from_bytes(byteorder, bytes);
3480         sfree(byteorder);
3481     }
3482
3483     ec_point_free(p);
3484     return ret;
3485 }
3486
3487 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
3488 {
3489     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
3490     struct ec_curve *curve;
3491     struct ec_key *key;
3492     struct ec_point *publicKey;
3493
3494     curve = extra->curve();
3495
3496     key = snew(struct ec_key);
3497     if (!key) {
3498         return NULL;
3499     }
3500
3501     key->signalg = NULL;
3502     key->publicKey.curve = curve;
3503
3504     if (curve->type == EC_MONTGOMERY) {
3505         unsigned char bytes[32] = {0};
3506         int i;
3507
3508         for (i = 0; i < sizeof(bytes); ++i)
3509         {
3510             bytes[i] = (unsigned char)random_byte();
3511         }
3512         bytes[0] &= 248;
3513         bytes[31] &= 127;
3514         bytes[31] |= 64;
3515         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
3516         for (i = 0; i < sizeof(bytes); ++i)
3517         {
3518             ((volatile char*)bytes)[i] = 0;
3519         }
3520         if (!key->privateKey) {
3521             sfree(key);
3522             return NULL;
3523         }
3524         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
3525         if (!publicKey) {
3526             freebn(key->privateKey);
3527             sfree(key);
3528             return NULL;
3529         }
3530         key->publicKey.x = publicKey->x;
3531         key->publicKey.y = publicKey->y;
3532         key->publicKey.z = NULL;
3533         sfree(publicKey);
3534     } else {
3535         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
3536         if (!key->privateKey) {
3537             sfree(key);
3538             return NULL;
3539         }
3540         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
3541         if (!publicKey) {
3542             freebn(key->privateKey);
3543             sfree(key);
3544             return NULL;
3545         }
3546         key->publicKey.x = publicKey->x;
3547         key->publicKey.y = publicKey->y;
3548         key->publicKey.z = NULL;
3549         sfree(publicKey);
3550     }
3551     return key;
3552 }
3553
3554 char *ssh_ecdhkex_getpublic(void *key, int *len)
3555 {
3556     struct ec_key *ec = (struct ec_key*)key;
3557     char *point, *p;
3558     int i;
3559     int pointlen;
3560
3561     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3562
3563     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3564         *len = 1 + pointlen * 2;
3565     } else {
3566         *len = pointlen;
3567     }
3568     point = (char*)snewn(*len, char);
3569     if (!point) {
3570         return NULL;
3571     }
3572
3573     p = point;
3574     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3575         *p++ = 0x04;
3576         for (i = pointlen; i--;) {
3577             *p++ = bignum_byte(ec->publicKey.x, i);
3578         }
3579         for (i = pointlen; i--;) {
3580             *p++ = bignum_byte(ec->publicKey.y, i);
3581         }
3582     } else {
3583         for (i = 0; i < pointlen; ++i) {
3584             *p++ = bignum_byte(ec->publicKey.x, i);
3585         }
3586     }
3587
3588     return point;
3589 }
3590
3591 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
3592 {
3593     struct ec_key *ec = (struct ec_key*) key;
3594     struct ec_point remote;
3595     Bignum ret;
3596
3597     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3598         remote.curve = ec->publicKey.curve;
3599         remote.infinity = 0;
3600         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
3601             return NULL;
3602         }
3603     } else {
3604         /* Point length has to be the same length */
3605         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
3606             return NULL;
3607         }
3608
3609         remote.curve = ec->publicKey.curve;
3610         remote.infinity = 0;
3611         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
3612         remote.y = NULL;
3613         remote.z = NULL;
3614     }
3615
3616     ret = ecdh_calculate(ec->privateKey, &remote);
3617     if (remote.x) freebn(remote.x);
3618     if (remote.y) freebn(remote.y);
3619     return ret;
3620 }
3621
3622 void ssh_ecdhkex_freekey(void *key)
3623 {
3624     ecdsa_freekey(key);
3625 }
3626
3627 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
3628 static const struct ssh_kex ssh_ec_kex_curve25519 = {
3629     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
3630     &ssh_sha256, &kex_extra_curve25519,
3631 };
3632
3633 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
3634 static const struct ssh_kex ssh_ec_kex_nistp256 = {
3635     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
3636     &ssh_sha256, &kex_extra_nistp256,
3637 };
3638
3639 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
3640 static const struct ssh_kex ssh_ec_kex_nistp384 = {
3641     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
3642     &ssh_sha384, &kex_extra_nistp384,
3643 };
3644
3645 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
3646 static const struct ssh_kex ssh_ec_kex_nistp521 = {
3647     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
3648     &ssh_sha512, &kex_extra_nistp521,
3649 };
3650
3651 static const struct ssh_kex *const ec_kex_list[] = {
3652     &ssh_ec_kex_curve25519,
3653     &ssh_ec_kex_nistp256,
3654     &ssh_ec_kex_nistp384,
3655     &ssh_ec_kex_nistp521,
3656 };
3657
3658 const struct ssh_kexes ssh_ecdh_kex = {
3659     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
3660     ec_kex_list
3661 };
3662
3663 /* ----------------------------------------------------------------------
3664  * Helper functions for finding key algorithms and returning auxiliary
3665  * data.
3666  */
3667
3668 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
3669                                         const struct ec_curve **curve)
3670 {
3671     static const struct ssh_signkey *algs_with_oid[] = {
3672         &ssh_ecdsa_nistp256,
3673         &ssh_ecdsa_nistp384,
3674         &ssh_ecdsa_nistp521,
3675     };
3676     int i;
3677
3678     for (i = 0; i < lenof(algs_with_oid); i++) {
3679         const struct ssh_signkey *alg = algs_with_oid[i];
3680         const struct ecsign_extra *extra =
3681             (const struct ecsign_extra *)alg->extra;
3682         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
3683             *curve = extra->curve();
3684             return alg;
3685         }
3686     }
3687     return NULL;
3688 }
3689
3690 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
3691                                 int *oidlen)
3692 {
3693     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
3694     *oidlen = extra->oidlen;
3695     return extra->oid;
3696 }
3697
3698 const int ec_nist_alg_and_curve_by_bits(int bits,
3699                                         const struct ec_curve **curve,
3700                                         const struct ssh_signkey **alg)
3701 {
3702     switch (bits) {
3703       case 256: *alg = &ssh_ecdsa_nistp256; break;
3704       case 384: *alg = &ssh_ecdsa_nistp384; break;
3705       case 521: *alg = &ssh_ecdsa_nistp521; break;
3706       default: return FALSE;
3707     }
3708     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3709     return TRUE;
3710 }
3711
3712 const int ec_ed_alg_and_curve_by_bits(int bits,
3713                                       const struct ec_curve **curve,
3714                                       const struct ssh_signkey **alg)
3715 {
3716     switch (bits) {
3717       case 256: *alg = &ssh_ecdsa_ed25519; break;
3718       default: return FALSE;
3719     }
3720     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3721     return TRUE;
3722 }