]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Clear an extra low bit in EdDSA exponent calculation.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static int initialise_wcurve(struct ec_curve *curve, int bits, unsigned char *p,
41                              unsigned char *a, unsigned char *b,
42                              unsigned char *n, unsigned char *Gx,
43                              unsigned char *Gy)
44 {
45     int length = bits / 8;
46     if (bits % 8) ++length;
47
48     curve->type = EC_WEIERSTRASS;
49
50     curve->fieldBits = bits;
51     curve->p = bignum_from_bytes(p, length);
52     if (!curve->p) goto error;
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     if (!curve->w.a) goto error;
57     curve->w.b = bignum_from_bytes(b, length);
58     if (!curve->w.b) goto error;
59
60     /* Group order and generator */
61     curve->w.n = bignum_from_bytes(n, length);
62     if (!curve->w.n) goto error;
63     curve->w.G.x = bignum_from_bytes(Gx, length);
64     if (!curve->w.G.x) goto error;
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     if (!curve->w.G.y) goto error;
67     curve->w.G.curve = curve;
68     curve->w.G.infinity = 0;
69
70     return 1;
71   error:
72     if (curve->p) freebn(curve->p);
73     if (curve->w.a) freebn(curve->w.a);
74     if (curve->w.b) freebn(curve->w.b);
75     if (curve->w.n) freebn(curve->w.n);
76     if (curve->w.G.x) freebn(curve->w.G.x);
77     return 0;
78 }
79
80 static int initialise_mcurve(struct ec_curve *curve, int bits, unsigned char *p,
81                              unsigned char *a, unsigned char *b,
82                              unsigned char *Gx)
83 {
84     int length = bits / 8;
85     if (bits % 8) ++length;
86
87     curve->type = EC_MONTGOMERY;
88
89     curve->fieldBits = bits;
90     curve->p = bignum_from_bytes(p, length);
91     if (!curve->p) goto error;
92
93     /* Curve co-efficients */
94     curve->m.a = bignum_from_bytes(a, length);
95     if (!curve->m.a) goto error;
96     curve->m.b = bignum_from_bytes(b, length);
97     if (!curve->m.b) goto error;
98
99     /* Generator */
100     curve->m.G.x = bignum_from_bytes(Gx, length);
101     if (!curve->m.G.x) goto error;
102     curve->m.G.y = NULL;
103     curve->m.G.z = NULL;
104     curve->m.G.curve = curve;
105     curve->m.G.infinity = 0;
106
107     return 1;
108   error:
109     if (curve->p) freebn(curve->p);
110     if (curve->m.a) freebn(curve->m.a);
111     if (curve->m.b) freebn(curve->m.b);
112     return 0;
113 }
114
115 static int initialise_ecurve(struct ec_curve *curve, int bits, unsigned char *p,
116                              unsigned char *l, unsigned char *d,
117                              unsigned char *Bx, unsigned char *By)
118 {
119     int length = bits / 8;
120     if (bits % 8) ++length;
121
122     curve->type = EC_EDWARDS;
123
124     curve->fieldBits = bits;
125     curve->p = bignum_from_bytes(p, length);
126     if (!curve->p) goto error;
127
128     /* Curve co-efficients */
129     curve->e.l = bignum_from_bytes(l, length);
130     if (!curve->e.l) goto error;
131     curve->e.d = bignum_from_bytes(d, length);
132     if (!curve->e.d) goto error;
133
134     /* Group order and generator */
135     curve->e.B.x = bignum_from_bytes(Bx, length);
136     if (!curve->e.B.x) goto error;
137     curve->e.B.y = bignum_from_bytes(By, length);
138     if (!curve->e.B.y) goto error;
139     curve->e.B.curve = curve;
140     curve->e.B.infinity = 0;
141
142     return 1;
143   error:
144     if (curve->p) freebn(curve->p);
145     if (curve->e.l) freebn(curve->e.l);
146     if (curve->e.d) freebn(curve->e.d);
147     if (curve->e.B.x) freebn(curve->e.B.x);
148     return 0;
149 }
150
151 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
152 int nistp256_oid_len = 8;
153 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
154 int nistp384_oid_len = 5;
155 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
156 int nistp521_oid_len = 5;
157 unsigned char curve25519_oid[] = {0x06, 0x0A, 0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01};
158 int curve25519_oid_len = 12;
159
160 struct ec_curve *ec_p256(void)
161 {
162     static struct ec_curve curve = { 0 };
163     static unsigned char initialised = 0;
164
165     if (!initialised)
166     {
167         unsigned char p[] = {
168             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
169             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
170             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
171             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
172         };
173         unsigned char a[] = {
174             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
175             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
176             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
177             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
178         };
179         unsigned char b[] = {
180             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
181             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
182             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
183             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
184         };
185         unsigned char n[] = {
186             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
187             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
188             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
189             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
190         };
191         unsigned char Gx[] = {
192             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
193             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
194             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
195             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
196         };
197         unsigned char Gy[] = {
198             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
199             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
200             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
201             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
202         };
203
204         if (!initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy)) {
205             return NULL;
206         }
207
208         /* Now initialised, no need to do it again */
209         initialised = 1;
210     }
211
212     return &curve;
213 }
214
215 struct ec_curve *ec_p384(void)
216 {
217     static struct ec_curve curve = { 0 };
218     static unsigned char initialised = 0;
219
220     if (!initialised)
221     {
222         unsigned char p[] = {
223             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
224             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
225             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
226             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
227             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
228             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
229         };
230         unsigned char a[] = {
231             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
232             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
233             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
234             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
235             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
236             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
237         };
238         unsigned char b[] = {
239             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
240             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
241             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
242             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
243             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
244             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
245         };
246         unsigned char n[] = {
247             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
251             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
252             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
253         };
254         unsigned char Gx[] = {
255             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
256             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
257             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
258             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
259             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
260             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
261         };
262         unsigned char Gy[] = {
263             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
264             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
265             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
266             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
267             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
268             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
269         };
270
271         if (!initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy)) {
272             return NULL;
273         }
274
275         /* Now initialised, no need to do it again */
276         initialised = 1;
277     }
278
279     return &curve;
280 }
281
282 struct ec_curve *ec_p521(void)
283 {
284     static struct ec_curve curve = { 0 };
285     static unsigned char initialised = 0;
286
287     if (!initialised)
288     {
289         unsigned char p[] = {
290             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
291             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
292             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
293             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
294             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
295             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
296             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
297             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
298             0xff, 0xff
299         };
300         unsigned char a[] = {
301             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
302             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
303             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
304             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
305             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
306             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
307             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
308             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
309             0xff, 0xfc
310         };
311         unsigned char b[] = {
312             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
313             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
314             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
315             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
316             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
317             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
318             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
319             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
320             0x3f, 0x00
321         };
322         unsigned char n[] = {
323             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
324             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
325             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
326             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
327             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
328             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
329             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
330             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
331             0x64, 0x09
332         };
333         unsigned char Gx[] = {
334             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
335             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
336             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
337             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
338             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
339             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
340             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
341             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
342             0xbd, 0x66
343         };
344         unsigned char Gy[] = {
345             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
346             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
347             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
348             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
349             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
350             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
351             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
352             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
353             0x66, 0x50
354         };
355
356         if (!initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy)) {
357             return NULL;
358         }
359
360         /* Now initialised, no need to do it again */
361         initialised = 1;
362     }
363
364     return &curve;
365 }
366
367 struct ec_curve *ec_curve25519(void)
368 {
369     static struct ec_curve curve = { 0 };
370     static unsigned char initialised = 0;
371
372     if (!initialised)
373     {
374         unsigned char p[] = {
375             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
376             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
379         };
380         unsigned char a[] = {
381             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
382             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
385         };
386         unsigned char b[] = {
387             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
388             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
389             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
390             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
391         };
392         unsigned char gx[32] = {
393             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
394             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
395             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
396             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
397         };
398
399         if (!initialise_mcurve(&curve, 256, p, a, b, gx)) {
400             return NULL;
401         }
402
403         /* Now initialised, no need to do it again */
404         initialised = 1;
405     }
406
407     return &curve;
408 }
409 struct ec_curve *ec_ed25519(void)
410 {
411     static struct ec_curve curve = { 0 };
412     static unsigned char initialised = 0;
413
414     if (!initialised)
415     {
416         unsigned char q[] = {
417             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
418             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
419             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
420             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
421         };
422         unsigned char l[32] = {
423             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
424             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
425             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
426             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
427         };
428         unsigned char d[32] = {
429             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
430             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
431             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
432             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
433         };
434         unsigned char Bx[32] = {
435             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
436             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
437             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
438             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
439         };
440         unsigned char By[32] = {
441             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
442             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
443             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
444             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
445         };
446
447
448         if (!initialise_ecurve(&curve, 256, q, l, d, Bx, By)) {
449             return NULL;
450         }
451
452         /* Now initialised, no need to do it again */
453         initialised = 1;
454     }
455
456     return &curve;
457 }
458
459 static struct ec_curve *ec_name_to_curve(const char *name, int len) {
460     if (len > 11 && !memcmp(name, "ecdsa-sha2-", 11)) {
461         name += 11;
462         len -= 11;
463     } else if (len > 10 && !memcmp(name, "ecdh-sha2-", 10)) {
464         name += 10;
465         len -= 10;
466     } else if (len == 11 && !memcmp(name, "ssh-ed25519", 11)) {
467         return ec_ed25519();
468     }
469
470     if (len == 8 && !memcmp(name, "nistp", 5)) {
471         name += 5;
472         if (!memcmp(name, "256", 3)) {
473             return ec_p256();
474         } else if (!memcmp(name, "384", 3)) {
475             return ec_p384();
476         } else if (!memcmp(name, "521", 3)) {
477             return ec_p521();
478         }
479     }
480
481     if (len == 28 && !memcmp(name, "curve25519-sha256@libssh.org", 28)) {
482         return ec_curve25519();
483     }
484
485     return NULL;
486 }
487
488 /* Type enumeration for specifying the curve name */
489 enum ec_name_type { EC_TYPE_DSA, EC_TYPE_DH, EC_TYPE_CURVE };
490
491 static int ec_curve_to_name(enum ec_name_type type, const struct ec_curve *curve,
492                             unsigned char *name, int len) {
493     if (curve->type == EC_WEIERSTRASS) {
494         int length, loc;
495         if (type == EC_TYPE_DSA) {
496             length = 19;
497             loc = 16;
498         } else if (type == EC_TYPE_DH) {
499             length = 18;
500             loc = 15;
501         } else {
502             length = 8;
503             loc = 5;
504         }
505
506         /* Return length of string */
507         if (name == NULL) return length;
508
509         /* Not enough space for the name */
510         if (len < length) return 0;
511
512         /* Put the name in the buffer */
513         switch (curve->fieldBits) {
514           case 256:
515             memcpy(name+loc, "256", 3);
516             break;
517           case 384:
518             memcpy(name+loc, "384", 3);
519             break;
520           case 521:
521             memcpy(name+loc, "521", 3);
522             break;
523           default:
524             return 0;
525         }
526
527         if (type == EC_TYPE_DSA) {
528             memcpy(name, "ecdsa-sha2-nistp", 16);
529         } else if (type == EC_TYPE_DH) {
530             memcpy(name, "ecdh-sha2-nistp", 15);
531         } else {
532             memcpy(name, "nistp", 5);
533         }
534
535         return length;
536     } else if (curve->type == EC_EDWARDS) {
537         /* No DH for ed25519 - use Montgomery instead */
538         if (type == EC_TYPE_DH) return 0;
539
540         if (type == EC_TYPE_CURVE) {
541             /* Return length of string */
542             if (name == NULL) return 7;
543
544             /* Not enough space for the name */
545             if (len < 7) return 0;
546
547             /* Unknown curve field */
548             if (curve->fieldBits != 256) return 0;
549
550             memcpy(name, "ed25519", 7);
551             return 7;
552
553         } else {
554             /* Return length of string */
555             if (name == NULL) return 11;
556
557             /* Not enough space for the name */
558             if (len < 11) return 0;
559
560             /* Unknown curve field */
561             if (curve->fieldBits != 256) return 0;
562
563             memcpy(name, "ssh-ed25519", 11);
564             return 11;
565         }
566     } else {
567         /* No DSA for curve25519 */
568         if (type == EC_TYPE_DSA || type == EC_TYPE_CURVE) return 0;
569
570         /* Return length of string */
571         if (name == NULL) return 28;
572
573         /* Not enough space for the name */
574         if (len < 28) return 0;
575
576         /* Unknown curve field */
577         if (curve->fieldBits != 256) return 0;
578
579         memcpy(name, "curve25519-sha256@libssh.org", 28);
580         return 28;
581     }
582 }
583
584 /* Return 1 if a is -3 % p, otherwise return 0
585  * This is used because there are some maths optimisations */
586 static int ec_aminus3(const struct ec_curve *curve)
587 {
588     int ret;
589     Bignum _p;
590
591     if (curve->type != EC_WEIERSTRASS) {
592         return 0;
593     }
594
595     _p = bignum_add_long(curve->w.a, 3);
596     if (!_p) return 0;
597
598     ret = !bignum_cmp(curve->p, _p);
599     freebn(_p);
600     return ret;
601 }
602
603 /* ----------------------------------------------------------------------
604  * Elliptic curve field maths
605  */
606
607 static Bignum ecf_add(const Bignum a, const Bignum b,
608                       const struct ec_curve *curve)
609 {
610     Bignum a1, b1, ab, ret;
611
612     a1 = bigmod(a, curve->p);
613     if (!a1) return NULL;
614     b1 = bigmod(b, curve->p);
615     if (!b1)
616     {
617         freebn(a1);
618         return NULL;
619     }
620
621     ab = bigadd(a1, b1);
622     freebn(a1);
623     freebn(b1);
624     if (!ab) return NULL;
625
626     ret = bigmod(ab, curve->p);
627     freebn(ab);
628
629     return ret;
630 }
631
632 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
633 {
634     return modmul(a, a, curve->p);
635 }
636
637 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
638 {
639     Bignum ret, tmp;
640
641     /* Double */
642     tmp = bignum_lshift(a, 1);
643     if (!tmp) return NULL;
644
645     /* Add itself (i.e. treble) */
646     ret = bigadd(tmp, a);
647     freebn(tmp);
648
649     /* Normalise */
650     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
651     {
652         tmp = bigsub(ret, curve->p);
653         freebn(ret);
654         ret = tmp;
655     }
656
657     return ret;
658 }
659
660 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
661 {
662     Bignum ret = bignum_lshift(a, 1);
663     if (!ret) return NULL;
664     if (bignum_cmp(ret, curve->p) >= 0)
665     {
666         Bignum tmp = bigsub(ret, curve->p);
667         freebn(ret);
668         return tmp;
669     }
670     else
671     {
672         return ret;
673     }
674 }
675
676 /* ----------------------------------------------------------------------
677  * Memory functions
678  */
679
680 void ec_point_free(struct ec_point *point)
681 {
682     if (point == NULL) return;
683     point->curve = 0;
684     if (point->x) freebn(point->x);
685     if (point->y) freebn(point->y);
686     if (point->z) freebn(point->z);
687     point->infinity = 0;
688     sfree(point);
689 }
690
691 static struct ec_point *ec_point_new(const struct ec_curve *curve,
692                                      const Bignum x, const Bignum y, const Bignum z,
693                                      unsigned char infinity)
694 {
695     struct ec_point *point = snewn(1, struct ec_point);
696     point->curve = curve;
697     point->x = x;
698     point->y = y;
699     point->z = z;
700     point->infinity = infinity ? 1 : 0;
701     return point;
702 }
703
704 static struct ec_point *ec_point_copy(const struct ec_point *a)
705 {
706     if (a == NULL) return NULL;
707     return ec_point_new(a->curve,
708                         a->x ? copybn(a->x) : NULL,
709                         a->y ? copybn(a->y) : NULL,
710                         a->z ? copybn(a->z) : NULL,
711                         a->infinity);
712 }
713
714 static int ec_point_verify(const struct ec_point *a)
715 {
716     if (a->infinity) {
717         return 1;
718     } else if (a->curve->type == EC_EDWARDS) {
719         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
720         Bignum y2, x2, tmp, tmp2, tmp3;
721         int ret;
722
723         y2 = ecf_square(a->y, a->curve);
724         if (!y2) {
725             return 0;
726         }
727         x2 = ecf_square(a->x, a->curve);
728         if (!x2) {
729             freebn(y2);
730             return 0;
731         }
732         tmp = modmul(a->curve->e.d, x2, a->curve->p);
733         if (!tmp) {
734             freebn(x2);
735             freebn(y2);
736             return 0;
737         }
738         tmp2 = modmul(tmp, y2, a->curve->p);
739         freebn(tmp);
740         if (!tmp2) {
741             freebn(x2);
742             freebn(y2);
743             return 0;
744         }
745         tmp = modsub(y2, x2, a->curve->p);
746         freebn(y2);
747         freebn(x2);
748         if (!tmp) {
749             freebn(tmp2);
750             return 0;
751         }
752         tmp3 = modsub(tmp, tmp2, a->curve->p);
753         freebn(tmp);
754         freebn(tmp2);
755         if (!tmp3) {
756             return 0;
757         }
758         ret = !bignum_cmp(tmp3, One);
759         freebn(tmp3);
760         return ret;
761     } else if (a->curve->type == EC_WEIERSTRASS) {
762         /* Verify y^2 = x^3 + ax + b */
763         int ret = 0;
764
765         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
766
767         Bignum Three = bignum_from_long(3);
768         if (!Three) return 0;
769
770         lhs = modmul(a->y, a->y, a->curve->p);
771         if (!lhs) goto error;
772
773         /* This uses montgomery multiplication to optimise */
774         x3 = modpow(a->x, Three, a->curve->p);
775         freebn(Three);
776         if (!x3) goto error;
777         ax = modmul(a->curve->w.a, a->x, a->curve->p);
778         if (!ax) goto error;
779         x3ax = bigadd(x3, ax);
780         if (!x3ax) goto error;
781         freebn(x3); x3 = NULL;
782         freebn(ax); ax = NULL;
783         x3axm = bigmod(x3ax, a->curve->p);
784         if (!x3axm) goto error;
785         freebn(x3ax); x3ax = NULL;
786         x3axb = bigadd(x3axm, a->curve->w.b);
787         if (!x3axb) goto error;
788         freebn(x3axm); x3axm = NULL;
789         rhs = bigmod(x3axb, a->curve->p);
790         if (!rhs) goto error;
791         freebn(x3axb);
792
793         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
794         freebn(lhs);
795         freebn(rhs);
796
797         return ret;
798
799       error:
800         if (x3) freebn(x3);
801         if (ax) freebn(ax);
802         if (x3ax) freebn(x3ax);
803         if (x3axm) freebn(x3axm);
804         if (x3axb) freebn(x3axb);
805         if (lhs) freebn(lhs);
806         return 0;
807     } else {
808         return 0;
809     }
810 }
811
812 /* ----------------------------------------------------------------------
813  * Elliptic curve point maths
814  */
815
816 /* Returns 1 on success and 0 on memory error */
817 static int ecp_normalise(struct ec_point *a)
818 {
819     if (!a) {
820         /* No point */
821         return 0;
822     }
823
824     if (a->infinity) {
825         /* Point is at infinity - i.e. normalised */
826         return 1;
827     }
828
829     if (a->curve->type == EC_WEIERSTRASS) {
830         /* In Jacobian Coordinates the triple (X, Y, Z) represents
831            the affine point (X / Z^2, Y / Z^3) */
832
833         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
834
835         if (!a->x || !a->y) {
836             /* No point defined */
837             return 0;
838         } else if (!a->z) {
839             /* Already normalised */
840             return 1;
841         }
842
843         Z2 = ecf_square(a->z, a->curve);
844         if (!Z2) {
845             return 0;
846         }
847         Z2inv = modinv(Z2, a->curve->p);
848         if (!Z2inv) {
849             freebn(Z2);
850             return 0;
851         }
852         tx = modmul(a->x, Z2inv, a->curve->p);
853         freebn(Z2inv);
854         if (!tx) {
855             freebn(Z2);
856             return 0;
857         }
858
859         Z3 = modmul(Z2, a->z, a->curve->p);
860         freebn(Z2);
861         if (!Z3) {
862             freebn(tx);
863             return 0;
864         }
865         Z3inv = modinv(Z3, a->curve->p);
866         freebn(Z3);
867         if (!Z3inv) {
868             freebn(tx);
869             return 0;
870         }
871         ty = modmul(a->y, Z3inv, a->curve->p);
872         freebn(Z3inv);
873         if (!ty) {
874             freebn(tx);
875             return 0;
876         }
877
878         freebn(a->x);
879         a->x = tx;
880         freebn(a->y);
881         a->y = ty;
882         freebn(a->z);
883         a->z = NULL;
884         return 1;
885     } else if (a->curve->type == EC_MONTGOMERY) {
886         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
887
888         Bignum tmp, tmp2;
889
890         if (!a->x) {
891             /* No point defined */
892             return 0;
893         } else if (!a->z) {
894             /* Already normalised */
895             return 1;
896         }
897
898         tmp = modinv(a->z, a->curve->p);
899         if (!tmp) {
900             return 0;
901         }
902         tmp2 = modmul(a->x, tmp, a->curve->p);
903         freebn(tmp);
904         if (!tmp2) {
905             return 0;
906         }
907
908         freebn(a->z);
909         a->z = NULL;
910         freebn(a->x);
911         a->x = tmp2;
912         return 1;
913     } else if (a->curve->type == EC_EDWARDS) {
914         /* Always normalised */
915         return 1;
916     } else {
917         return 0;
918     }
919 }
920
921 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
922 {
923     Bignum S, M, outx, outy, outz;
924
925     if (bignum_cmp(a->y, Zero) == 0)
926     {
927         /* Identity */
928         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
929     }
930
931     /* S = 4*X*Y^2 */
932     {
933         Bignum Y2, XY2, _2XY2;
934
935         Y2 = ecf_square(a->y, a->curve);
936         if (!Y2) {
937             return NULL;
938         }
939         XY2 = modmul(a->x, Y2, a->curve->p);
940         freebn(Y2);
941         if (!XY2) {
942             return NULL;
943         }
944
945         _2XY2 = ecf_double(XY2, a->curve);
946         freebn(XY2);
947         if (!_2XY2) {
948             return NULL;
949         }
950         S = ecf_double(_2XY2, a->curve);
951         freebn(_2XY2);
952         if (!S) {
953             return NULL;
954         }
955     }
956
957     /* Faster calculation if a = -3 */
958     if (aminus3) {
959         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
960         Bignum Z2, XpZ2, XmZ2, second;
961
962         if (a->z == NULL) {
963             Z2 = copybn(One);
964         } else {
965             Z2 = ecf_square(a->z, a->curve);
966         }
967         if (!Z2) {
968             freebn(S);
969             return NULL;
970         }
971
972         XpZ2 = ecf_add(a->x, Z2, a->curve);
973         if (!XpZ2) {
974             freebn(S);
975             freebn(Z2);
976             return NULL;
977         }
978         XmZ2 = modsub(a->x, Z2, a->curve->p);
979         freebn(Z2);
980         if (!XmZ2) {
981             freebn(S);
982             freebn(XpZ2);
983             return NULL;
984         }
985
986         second = modmul(XpZ2, XmZ2, a->curve->p);
987         freebn(XpZ2);
988         freebn(XmZ2);
989         if (!second) {
990             freebn(S);
991             return NULL;
992         }
993
994         M = ecf_treble(second, a->curve);
995         freebn(second);
996         if (!M) {
997             freebn(S);
998             return NULL;
999         }
1000     } else {
1001         /* M = 3*X^2 + a*Z^4 */
1002         Bignum _3X2, X2, aZ4;
1003
1004         if (a->z == NULL) {
1005             aZ4 = copybn(a->curve->w.a);
1006         } else {
1007             Bignum Z2, Z4;
1008
1009             Z2 = ecf_square(a->z, a->curve);
1010             if (!Z2) {
1011                 freebn(S);
1012                 return NULL;
1013             }
1014             Z4 = ecf_square(Z2, a->curve);
1015             freebn(Z2);
1016             if (!Z4) {
1017                 freebn(S);
1018                 return NULL;
1019             }
1020             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
1021             freebn(Z4);
1022         }
1023         if (!aZ4) {
1024             freebn(S);
1025             return NULL;
1026         }
1027
1028         X2 = modmul(a->x, a->x, a->curve->p);
1029         if (!X2) {
1030             freebn(S);
1031             freebn(aZ4);
1032             return NULL;
1033         }
1034         _3X2 = ecf_treble(X2, a->curve);
1035         freebn(X2);
1036         if (!_3X2) {
1037             freebn(S);
1038             freebn(aZ4);
1039             return NULL;
1040         }
1041         M = ecf_add(_3X2, aZ4, a->curve);
1042         freebn(_3X2);
1043         freebn(aZ4);
1044         if (!M) {
1045             freebn(S);
1046             return NULL;
1047         }
1048     }
1049
1050     /* X' = M^2 - 2*S */
1051     {
1052         Bignum M2, _2S;
1053
1054         M2 = ecf_square(M, a->curve);
1055         if (!M2) {
1056             freebn(S);
1057             freebn(M);
1058             return NULL;
1059         }
1060
1061         _2S = ecf_double(S, a->curve);
1062         if (!_2S) {
1063             freebn(M2);
1064             freebn(S);
1065             freebn(M);
1066             return NULL;
1067         }
1068
1069         outx = modsub(M2, _2S, a->curve->p);
1070         freebn(M2);
1071         freebn(_2S);
1072         if (!outx) {
1073             freebn(S);
1074             freebn(M);
1075             return NULL;
1076         }
1077     }
1078
1079     /* Y' = M*(S - X') - 8*Y^4 */
1080     {
1081         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
1082
1083         SX = modsub(S, outx, a->curve->p);
1084         freebn(S);
1085         if (!SX) {
1086             freebn(M);
1087             freebn(outx);
1088             return NULL;
1089         }
1090         MSX = modmul(M, SX, a->curve->p);
1091         freebn(SX);
1092         freebn(M);
1093         if (!MSX) {
1094             freebn(outx);
1095             return NULL;
1096         }
1097         Y2 = ecf_square(a->y, a->curve);
1098         if (!Y2) {
1099             freebn(outx);
1100             freebn(MSX);
1101             return NULL;
1102         }
1103         Y4 = ecf_square(Y2, a->curve);
1104         freebn(Y2);
1105         if (!Y4) {
1106             freebn(outx);
1107             freebn(MSX);
1108             return NULL;
1109         }
1110         Eight = bignum_from_long(8);
1111         if (!Eight) {
1112             freebn(outx);
1113             freebn(MSX);
1114             freebn(Y4);
1115             return NULL;
1116         }
1117         _8Y4 = modmul(Eight, Y4, a->curve->p);
1118         freebn(Eight);
1119         freebn(Y4);
1120         if (!_8Y4) {
1121             freebn(outx);
1122             freebn(MSX);
1123             return NULL;
1124         }
1125         outy = modsub(MSX, _8Y4, a->curve->p);
1126         freebn(MSX);
1127         freebn(_8Y4);
1128         if (!outy) {
1129             freebn(outx);
1130             return NULL;
1131         }
1132     }
1133
1134     /* Z' = 2*Y*Z */
1135     {
1136         Bignum YZ;
1137
1138         if (a->z == NULL) {
1139             YZ = copybn(a->y);
1140         } else {
1141             YZ = modmul(a->y, a->z, a->curve->p);
1142         }
1143         if (!YZ) {
1144             freebn(outx);
1145             freebn(outy);
1146             return NULL;
1147         }
1148
1149         outz = ecf_double(YZ, a->curve);
1150         freebn(YZ);
1151         if (!outz) {
1152             freebn(outx);
1153             freebn(outy);
1154             return NULL;
1155         }
1156     }
1157
1158     return ec_point_new(a->curve, outx, outy, outz, 0);
1159 }
1160
1161 static struct ec_point *ecp_doublem(const struct ec_point *a)
1162 {
1163     Bignum z, outx, outz, xpz, xmz;
1164
1165     z = a->z;
1166     if (!z) {
1167         z = One;
1168     }
1169
1170     /* 4xz = (x + z)^2 - (x - z)^2 */
1171     {
1172         Bignum tmp;
1173
1174         tmp = ecf_add(a->x, z, a->curve);
1175         if (!tmp) {
1176             return NULL;
1177         }
1178         xpz = ecf_square(tmp, a->curve);
1179         freebn(tmp);
1180         if (!xpz) {
1181             return NULL;
1182         }
1183
1184         tmp = modsub(a->x, z, a->curve->p);
1185         if (!tmp) {
1186             freebn(xpz);
1187             return NULL;
1188         }
1189         xmz = ecf_square(tmp, a->curve);
1190         freebn(tmp);
1191         if (!xmz) {
1192             freebn(xpz);
1193             return NULL;
1194         }
1195     }
1196
1197     /* outx = (x + z)^2 * (x - z)^2 */
1198     outx = modmul(xpz, xmz, a->curve->p);
1199     if (!outx) {
1200         freebn(xpz);
1201         freebn(xmz);
1202         return NULL;
1203     }
1204
1205     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
1206     {
1207         Bignum _4xz, tmp, tmp2, tmp3;
1208
1209         tmp = bignum_from_long(2);
1210         if (!tmp) {
1211             freebn(xpz);
1212             freebn(outx);
1213             freebn(xmz);
1214             return NULL;
1215         }
1216         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
1217         freebn(tmp);
1218         if (!tmp2) {
1219             freebn(xpz);
1220             freebn(outx);
1221             freebn(xmz);
1222             return NULL;
1223         }
1224
1225         _4xz = modsub(xpz, xmz, a->curve->p);
1226         freebn(xpz);
1227         if (!_4xz) {
1228             freebn(tmp2);
1229             freebn(outx);
1230             freebn(xmz);
1231             return NULL;
1232         }
1233         tmp = modmul(tmp2, _4xz, a->curve->p);
1234         freebn(tmp2);
1235         if (!tmp) {
1236             freebn(_4xz);
1237             freebn(outx);
1238             freebn(xmz);
1239             return NULL;
1240         }
1241
1242         tmp2 = bignum_from_long(4);
1243         if (!tmp2) {
1244             freebn(tmp);
1245             freebn(_4xz);
1246             freebn(outx);
1247             freebn(xmz);
1248             return NULL;
1249         }
1250         tmp3 = modinv(tmp2, a->curve->p);
1251         freebn(tmp2);
1252         if (!tmp3) {
1253             freebn(tmp);
1254             freebn(_4xz);
1255             freebn(outx);
1256             freebn(xmz);
1257             return NULL;
1258         }
1259         tmp2 = modmul(tmp, tmp3, a->curve->p);
1260         freebn(tmp);
1261         freebn(tmp3);
1262         if (!tmp2) {
1263             freebn(_4xz);
1264             freebn(outx);
1265             freebn(xmz);
1266             return NULL;
1267         }
1268
1269         tmp = ecf_add(xmz, tmp2, a->curve);
1270         freebn(xmz);
1271         freebn(tmp2);
1272         if (!tmp) {
1273             freebn(_4xz);
1274             freebn(outx);
1275             return NULL;
1276         }
1277         outz = modmul(_4xz, tmp, a->curve->p);
1278         freebn(_4xz);
1279         freebn(tmp);
1280         if (!outz) {
1281             freebn(outx);
1282             return NULL;
1283         }
1284     }
1285
1286     return ec_point_new(a->curve, outx, NULL, outz, 0);
1287 }
1288
1289 /* Forward declaration for Edwards curve doubling */
1290 static struct ec_point *ecp_add(const struct ec_point *a,
1291                                 const struct ec_point *b,
1292                                 const int aminus3);
1293
1294 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
1295 {
1296     if (a->infinity)
1297     {
1298         /* Identity */
1299         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1300     }
1301
1302     if (a->curve->type == EC_EDWARDS)
1303     {
1304         return ecp_add(a, a, aminus3);
1305     }
1306     else if (a->curve->type == EC_WEIERSTRASS)
1307     {
1308         return ecp_doublew(a, aminus3);
1309     }
1310     else
1311     {
1312         return ecp_doublem(a);
1313     }
1314 }
1315
1316 static struct ec_point *ecp_addw(const struct ec_point *a,
1317                                  const struct ec_point *b,
1318                                  const int aminus3)
1319 {
1320     Bignum U1, U2, S1, S2, outx, outy, outz;
1321
1322     /* U1 = X1*Z2^2 */
1323     /* S1 = Y1*Z2^3 */
1324     if (b->z) {
1325         Bignum Z2, Z3;
1326
1327         Z2 = ecf_square(b->z, a->curve);
1328         if (!Z2) {
1329             return NULL;
1330         }
1331         U1 = modmul(a->x, Z2, a->curve->p);
1332         if (!U1) {
1333             freebn(Z2);
1334             return NULL;
1335         }
1336         Z3 = modmul(Z2, b->z, a->curve->p);
1337         freebn(Z2);
1338         if (!Z3) {
1339             freebn(U1);
1340             return NULL;
1341         }
1342         S1 = modmul(a->y, Z3, a->curve->p);
1343         freebn(Z3);
1344         if (!S1) {
1345             freebn(U1);
1346             return NULL;
1347         }
1348     } else {
1349         U1 = copybn(a->x);
1350         if (!U1) {
1351             return NULL;
1352         }
1353         S1 = copybn(a->y);
1354         if (!S1) {
1355             freebn(U1);
1356             return NULL;
1357         }
1358     }
1359
1360     /* U2 = X2*Z1^2 */
1361     /* S2 = Y2*Z1^3 */
1362     if (a->z) {
1363         Bignum Z2, Z3;
1364
1365         Z2 = ecf_square(a->z, b->curve);
1366         if (!Z2) {
1367             freebn(U1);
1368             freebn(S1);
1369             return NULL;
1370         }
1371         U2 = modmul(b->x, Z2, b->curve->p);
1372         if (!U2) {
1373             freebn(U1);
1374             freebn(S1);
1375             freebn(Z2);
1376             return NULL;
1377         }
1378         Z3 = modmul(Z2, a->z, b->curve->p);
1379         freebn(Z2);
1380         if (!Z3) {
1381             freebn(U1);
1382             freebn(S1);
1383             freebn(U2);
1384             return NULL;
1385         }
1386         S2 = modmul(b->y, Z3, b->curve->p);
1387         freebn(Z3);
1388         if (!S2) {
1389             freebn(U1);
1390             freebn(S1);
1391             freebn(U2);
1392             return NULL;
1393         }
1394     } else {
1395         U2 = copybn(b->x);
1396         if (!U2) {
1397             freebn(U1);
1398             freebn(S1);
1399             return NULL;
1400         }
1401         S2 = copybn(b->y);
1402         if (!S2) {
1403             freebn(U1);
1404             freebn(S1);
1405             freebn(U2);
1406             return NULL;
1407         }
1408     }
1409
1410     /* Check if multiplying by self */
1411     if (bignum_cmp(U1, U2) == 0)
1412     {
1413         freebn(U1);
1414         freebn(U2);
1415         if (bignum_cmp(S1, S2) == 0)
1416         {
1417             freebn(S1);
1418             freebn(S2);
1419             return ecp_double(a, aminus3);
1420         }
1421         else
1422         {
1423             freebn(S1);
1424             freebn(S2);
1425             /* Infinity */
1426             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1427         }
1428     }
1429
1430     {
1431         Bignum H, R, UH2, H3;
1432
1433         /* H = U2 - U1 */
1434         H = modsub(U2, U1, a->curve->p);
1435         freebn(U2);
1436         if (!H) {
1437             freebn(U1);
1438             freebn(S1);
1439             freebn(S2);
1440             return NULL;
1441         }
1442
1443         /* R = S2 - S1 */
1444         R = modsub(S2, S1, a->curve->p);
1445         freebn(S2);
1446         if (!R) {
1447             freebn(H);
1448             freebn(S1);
1449             freebn(U1);
1450             return NULL;
1451         }
1452
1453         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1454         {
1455             Bignum R2, H2, _2UH2, first;
1456
1457             H2 = ecf_square(H, a->curve);
1458             if (!H2) {
1459                 freebn(U1);
1460                 freebn(S1);
1461                 freebn(H);
1462                 freebn(R);
1463                 return NULL;
1464             }
1465             UH2 = modmul(U1, H2, a->curve->p);
1466             freebn(U1);
1467             if (!UH2) {
1468                 freebn(H2);
1469                 freebn(S1);
1470                 freebn(H);
1471                 freebn(R);
1472                 return NULL;
1473             }
1474             H3 = modmul(H2, H, a->curve->p);
1475             freebn(H2);
1476             if (!H3) {
1477                 freebn(UH2);
1478                 freebn(S1);
1479                 freebn(H);
1480                 freebn(R);
1481                 return NULL;
1482             }
1483             R2 = ecf_square(R, a->curve);
1484             if (!R2) {
1485                 freebn(H3);
1486                 freebn(UH2);
1487                 freebn(S1);
1488                 freebn(H);
1489                 freebn(R);
1490                 return NULL;
1491             }
1492             _2UH2 = ecf_double(UH2, a->curve);
1493             if (!_2UH2) {
1494                 freebn(R2);
1495                 freebn(H3);
1496                 freebn(UH2);
1497                 freebn(S1);
1498                 freebn(H);
1499                 freebn(R);
1500                 return NULL;
1501             }
1502             first = modsub(R2, H3, a->curve->p);
1503             freebn(R2);
1504             if (!first) {
1505                 freebn(H3);
1506                 freebn(_2UH2);
1507                 freebn(UH2);
1508                 freebn(S1);
1509                 freebn(H);
1510                 freebn(R);
1511                 return NULL;
1512             }
1513             outx = modsub(first, _2UH2, a->curve->p);
1514             freebn(first);
1515             freebn(_2UH2);
1516             if (!outx) {
1517                 freebn(H3);
1518                 freebn(UH2);
1519                 freebn(S1);
1520                 freebn(H);
1521                 freebn(R);
1522                 return NULL;
1523             }
1524         }
1525
1526         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1527         {
1528             Bignum RUH2mX, UH2mX, SH3;
1529
1530             UH2mX = modsub(UH2, outx, a->curve->p);
1531             freebn(UH2);
1532             if (!UH2mX) {
1533                 freebn(outx);
1534                 freebn(H3);
1535                 freebn(S1);
1536                 freebn(H);
1537                 freebn(R);
1538                 return NULL;
1539             }
1540             RUH2mX = modmul(R, UH2mX, a->curve->p);
1541             freebn(UH2mX);
1542             freebn(R);
1543             if (!RUH2mX) {
1544                 freebn(outx);
1545                 freebn(H3);
1546                 freebn(S1);
1547                 freebn(H);
1548                 return NULL;
1549             }
1550             SH3 = modmul(S1, H3, a->curve->p);
1551             freebn(S1);
1552             freebn(H3);
1553             if (!SH3) {
1554                 freebn(RUH2mX);
1555                 freebn(outx);
1556                 freebn(H);
1557                 return NULL;
1558             }
1559
1560             outy = modsub(RUH2mX, SH3, a->curve->p);
1561             freebn(RUH2mX);
1562             freebn(SH3);
1563             if (!outy) {
1564                 freebn(outx);
1565                 freebn(H);
1566                 return NULL;
1567             }
1568         }
1569
1570         /* Z3 = H*Z1*Z2 */
1571         if (a->z && b->z) {
1572             Bignum ZZ;
1573
1574             ZZ = modmul(a->z, b->z, a->curve->p);
1575             if (!ZZ) {
1576                 freebn(outx);
1577                 freebn(outy);
1578                 freebn(H);
1579                 return NULL;
1580             }
1581             outz = modmul(H, ZZ, a->curve->p);
1582             freebn(H);
1583             freebn(ZZ);
1584             if (!outz) {
1585                 freebn(outx);
1586                 freebn(outy);
1587                 return NULL;
1588             }
1589         } else if (a->z) {
1590             outz = modmul(H, a->z, a->curve->p);
1591             freebn(H);
1592             if (!outz) {
1593                 freebn(outx);
1594                 freebn(outy);
1595                 return NULL;
1596             }
1597         } else if (b->z) {
1598             outz = modmul(H, b->z, a->curve->p);
1599             freebn(H);
1600             if (!outz) {
1601                 freebn(outx);
1602                 freebn(outy);
1603                 return NULL;
1604             }
1605         } else {
1606             outz = H;
1607         }
1608     }
1609
1610     return ec_point_new(a->curve, outx, outy, outz, 0);
1611 }
1612
1613 static struct ec_point *ecp_addm(const struct ec_point *a,
1614                                  const struct ec_point *b,
1615                                  const struct ec_point *base)
1616 {
1617     Bignum outx, outz, az, bz;
1618
1619     az = a->z;
1620     if (!az) {
1621         az = One;
1622     }
1623     bz = b->z;
1624     if (!bz) {
1625         bz = One;
1626     }
1627
1628     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1629     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1630     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1631     {
1632         Bignum tmp, tmp2, tmp3, tmp4;
1633
1634         /* (Xa + Za) * (Xb - Zb) */
1635         tmp = ecf_add(a->x, az, a->curve);
1636         if (!tmp) {
1637             return NULL;
1638         }
1639         tmp2 = modsub(b->x, bz, a->curve->p);
1640         if (!tmp2) {
1641             freebn(tmp);
1642             return NULL;
1643         }
1644         tmp3 = modmul(tmp, tmp2, a->curve->p);
1645         freebn(tmp);
1646         freebn(tmp2);
1647         if (!tmp3) {
1648             return NULL;
1649         }
1650
1651         /* (Xa - Za) * (Xb + Zb) */
1652         tmp = modsub(a->x, az, a->curve->p);
1653         if (!tmp) {
1654             freebn(tmp3);
1655             return NULL;
1656         }
1657         tmp2 = ecf_add(b->x, bz, a->curve);
1658         if (!tmp2) {
1659             freebn(tmp);
1660             freebn(tmp3);
1661             return NULL;
1662         }
1663         tmp4 = modmul(tmp, tmp2, a->curve->p);
1664         freebn(tmp);
1665         freebn(tmp2);
1666         if (!tmp4) {
1667             freebn(tmp3);
1668             return NULL;
1669         }
1670
1671         tmp = ecf_add(tmp3, tmp4, a->curve);
1672         if (!tmp) {
1673             freebn(tmp3);
1674             freebn(tmp4);
1675             return NULL;
1676         }
1677         outx = ecf_square(tmp, a->curve);
1678         freebn(tmp);
1679         if (!outx) {
1680             freebn(tmp3);
1681             freebn(tmp4);
1682             return NULL;
1683         }
1684
1685         tmp = modsub(tmp3, tmp4, a->curve->p);
1686         freebn(tmp3);
1687         freebn(tmp4);
1688         if (!tmp) {
1689             freebn(outx);
1690             return NULL;
1691         }
1692         tmp2 = ecf_square(tmp, a->curve);
1693         freebn(tmp);
1694         if (!tmp2) {
1695             freebn(outx);
1696             return NULL;
1697         }
1698         outz = modmul(base->x, tmp2, a->curve->p);
1699         freebn(tmp2);
1700         if (!outz) {
1701             freebn(outx);
1702             return NULL;
1703         }
1704     }
1705
1706     return ec_point_new(a->curve, outx, NULL, outz, 0);
1707 }
1708
1709 static struct ec_point *ecp_adde(const struct ec_point *a,
1710                                  const struct ec_point *b)
1711 {
1712     Bignum outx, outy, dmul;
1713
1714     /* outx = (a->x * b->y + b->x * a->y) /
1715      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1716     {
1717         Bignum tmp, tmp2, tmp3, tmp4;
1718
1719         tmp = modmul(a->x, b->y, a->curve->p);
1720         if (!tmp)
1721         {
1722             return NULL;
1723         }
1724         tmp2 = modmul(b->x, a->y, a->curve->p);
1725         if (!tmp2)
1726         {
1727             freebn(tmp);
1728             return NULL;
1729         }
1730         tmp3 = ecf_add(tmp, tmp2, a->curve);
1731         if (!tmp3)
1732         {
1733             freebn(tmp);
1734             freebn(tmp2);
1735             return NULL;
1736         }
1737
1738         tmp4 = modmul(tmp, tmp2, a->curve->p);
1739         freebn(tmp);
1740         freebn(tmp2);
1741         if (!tmp4)
1742         {
1743             freebn(tmp3);
1744             return NULL;
1745         }
1746         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1747         freebn(tmp4);
1748         if (!dmul) {
1749             freebn(tmp3);
1750             return NULL;
1751         }
1752
1753         tmp = ecf_add(One, dmul, a->curve);
1754         if (!tmp)
1755         {
1756             freebn(tmp3);
1757             freebn(dmul);
1758             return NULL;
1759         }
1760         tmp2 = modinv(tmp, a->curve->p);
1761         freebn(tmp);
1762         if (!tmp2)
1763         {
1764             freebn(tmp3);
1765             freebn(dmul);
1766             return NULL;
1767         }
1768
1769         outx = modmul(tmp3, tmp2, a->curve->p);
1770         freebn(tmp3);
1771         freebn(tmp2);
1772         if (!outx)
1773         {
1774             freebn(dmul);
1775             return NULL;
1776         }
1777     }
1778
1779     /* outy = (a->y * b->y + a->x * b->x) /
1780      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1781     {
1782         Bignum tmp, tmp2, tmp3, tmp4;
1783
1784         tmp = modsub(One, dmul, a->curve->p);
1785         freebn(dmul);
1786         if (!tmp)
1787         {
1788             freebn(outx);
1789             return NULL;
1790         }
1791
1792         tmp2 = modinv(tmp, a->curve->p);
1793         freebn(tmp);
1794         if (!tmp2)
1795         {
1796             freebn(outx);
1797             return NULL;
1798         }
1799
1800         tmp = modmul(a->y, b->y, a->curve->p);
1801         if (!tmp)
1802         {
1803             freebn(tmp2);
1804             freebn(outx);
1805             return NULL;
1806         }
1807         tmp3 = modmul(a->x, b->x, a->curve->p);
1808         if (!tmp3)
1809         {
1810             freebn(tmp);
1811             freebn(tmp2);
1812             freebn(outx);
1813             return NULL;
1814         }
1815         tmp4 = ecf_add(tmp, tmp3, a->curve);
1816         freebn(tmp);
1817         freebn(tmp3);
1818         if (!tmp4)
1819         {
1820             freebn(tmp2);
1821             freebn(outx);
1822             return NULL;
1823         }
1824
1825         outy = modmul(tmp4, tmp2, a->curve->p);
1826         freebn(tmp4);
1827         freebn(tmp2);
1828         if (!outy)
1829         {
1830             freebn(outx);
1831             return NULL;
1832         }
1833     }
1834
1835     return ec_point_new(a->curve, outx, outy, NULL, 0);
1836 }
1837
1838 static struct ec_point *ecp_add(const struct ec_point *a,
1839                                 const struct ec_point *b,
1840                                 const int aminus3)
1841 {
1842     if (a->curve != b->curve) {
1843         return NULL;
1844     }
1845
1846     /* Check if multiplying by infinity */
1847     if (a->infinity) return ec_point_copy(b);
1848     if (b->infinity) return ec_point_copy(a);
1849
1850     if (a->curve->type == EC_EDWARDS)
1851     {
1852         return ecp_adde(a, b);
1853     }
1854
1855     if (a->curve->type == EC_WEIERSTRASS)
1856     {
1857         return ecp_addw(a, b, aminus3);
1858     }
1859
1860     return NULL;
1861 }
1862
1863 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1864 {
1865     struct ec_point *A, *ret;
1866     int bits, i;
1867
1868     A = ec_point_copy(a);
1869     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1870
1871     bits = bignum_bitcount(b);
1872     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1873     {
1874         if (bignum_bit(b, i))
1875         {
1876             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1877             ec_point_free(ret);
1878             ret = tmp;
1879         }
1880         if (i+1 != bits)
1881         {
1882             struct ec_point *tmp = ecp_double(A, aminus3);
1883             ec_point_free(A);
1884             A = tmp;
1885         }
1886     }
1887
1888     if (!A) {
1889         ec_point_free(ret);
1890         ret = NULL;
1891     } else {
1892         ec_point_free(A);
1893     }
1894
1895     return ret;
1896 }
1897
1898 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1899 {
1900     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1901
1902     if (!ecp_normalise(ret)) {
1903         ec_point_free(ret);
1904         return NULL;
1905     }
1906
1907     return ret;
1908 }
1909
1910 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1911 {
1912     int i;
1913     struct ec_point *ret;
1914
1915     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1916
1917     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1918     {
1919         {
1920             struct ec_point *tmp = ecp_double(ret, 0);
1921             ec_point_free(ret);
1922             ret = tmp;
1923         }
1924         if (ret && bignum_bit(b, i))
1925         {
1926             struct ec_point *tmp = ecp_add(ret, a, 0);
1927             ec_point_free(ret);
1928             ret = tmp;
1929         }
1930     }
1931
1932     return ret;
1933 }
1934
1935 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1936 {
1937     struct ec_point *P1, *P2;
1938     int bits, i;
1939
1940     /* P1 <- P and P2 <- [2]P */
1941     P2 = ecp_double(p, 0);
1942     if (!P2) {
1943         return NULL;
1944     }
1945     P1 = ec_point_copy(p);
1946     if (!P1) {
1947         ec_point_free(P2);
1948         return NULL;
1949     }
1950
1951     /* for i = bits âˆ’ 2 down to 0 */
1952     bits = bignum_bitcount(n);
1953     for (i = bits - 2; P1 != NULL && P2 != NULL && i >= 0; --i)
1954     {
1955         if (!bignum_bit(n, i))
1956         {
1957             /* P2 <- P1 + P2 */
1958             struct ec_point *tmp = ecp_addm(P1, P2, p);
1959             ec_point_free(P2);
1960             P2 = tmp;
1961
1962             /* P1 <- [2]P1 */
1963             tmp = ecp_double(P1, 0);
1964             ec_point_free(P1);
1965             P1 = tmp;
1966         }
1967         else
1968         {
1969             /* P1 <- P1 + P2 */
1970             struct ec_point *tmp = ecp_addm(P1, P2, p);
1971             ec_point_free(P1);
1972             P1 = tmp;
1973
1974             /* P2 <- [2]P2 */
1975             tmp = ecp_double(P2, 0);
1976             ec_point_free(P2);
1977             P2 = tmp;
1978         }
1979     }
1980
1981     if (!P2) {
1982         if (P1) ec_point_free(P1);
1983         P1 = NULL;
1984     } else {
1985         ec_point_free(P2);
1986     }
1987
1988     if (!ecp_normalise(P1)) {
1989         ec_point_free(P1);
1990         return NULL;
1991     }
1992
1993     return P1;
1994 }
1995
1996 /* Not static because it is used by sshecdsag.c to generate a new key */
1997 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1998 {
1999     if (a->curve->type == EC_WEIERSTRASS) {
2000         return ecp_mulw(a, b);
2001     } else if (a->curve->type == EC_EDWARDS) {
2002         return ecp_mule(a, b);
2003     } else {
2004         return ecp_mulm(a, b);
2005     }
2006 }
2007
2008 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
2009                                    const struct ec_point *point)
2010 {
2011     struct ec_point *aG, *bP, *ret;
2012     int aminus3;
2013
2014     if (point->curve->type != EC_WEIERSTRASS) {
2015         return NULL;
2016     }
2017
2018     aminus3 = ec_aminus3(point->curve);
2019
2020     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
2021     if (!aG) return NULL;
2022     bP = ecp_mul_(point, b, aminus3);
2023     if (!bP) {
2024         ec_point_free(aG);
2025         return NULL;
2026     }
2027
2028     ret = ecp_add(aG, bP, aminus3);
2029
2030     ec_point_free(aG);
2031     ec_point_free(bP);
2032
2033     if (!ecp_normalise(ret)) {
2034         ec_point_free(ret);
2035         return NULL;
2036     }
2037
2038     return ret;
2039 }
2040 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
2041 {
2042     /* Get the x value on the given Edwards curve for a given y */
2043     Bignum x, xx;
2044
2045     /* xx = (y^2 - 1) / (d * y^2 + 1) */
2046     {
2047         Bignum tmp, tmp2, tmp3;
2048
2049         tmp = ecf_square(y, curve);
2050         if (!tmp) {
2051             return NULL;
2052         }
2053         tmp2 = modmul(curve->e.d, tmp, curve->p);
2054         if (!tmp2) {
2055             freebn(tmp);
2056             return NULL;
2057         }
2058         tmp3 = ecf_add(tmp2, One, curve);
2059         freebn(tmp2);
2060         if (!tmp3) {
2061             freebn(tmp);
2062             return NULL;
2063         }
2064         tmp2 = modinv(tmp3, curve->p);
2065         freebn(tmp3);
2066         if (!tmp2) {
2067             freebn(tmp);
2068             return NULL;
2069         }
2070
2071         tmp3 = modsub(tmp, One, curve->p);
2072         freebn(tmp);
2073         if (!tmp3) {
2074             freebn(tmp2);
2075             return NULL;
2076         }
2077         xx = modmul(tmp3, tmp2, curve->p);
2078         freebn(tmp3);
2079         freebn(tmp2);
2080         if (!xx) {
2081             return NULL;
2082         }
2083     }
2084
2085     /* x = xx^((p + 3) / 8) */
2086     {
2087         Bignum tmp, tmp2;
2088
2089         tmp = bignum_add_long(curve->p, 3);
2090         if (!tmp) {
2091             freebn(xx);
2092             return NULL;
2093         }
2094         tmp2 = bignum_rshift(tmp, 3);
2095         freebn(tmp);
2096         if (!tmp2) {
2097             freebn(xx);
2098             return NULL;
2099         }
2100         x = modpow(xx, tmp2, curve->p);
2101         freebn(tmp2);
2102         if (!x) {
2103             freebn(xx);
2104             return NULL;
2105         }
2106     }
2107
2108     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
2109     {
2110         Bignum tmp, tmp2;
2111
2112         tmp = ecf_square(x, curve);
2113         if (!tmp) {
2114             freebn(x);
2115             freebn(xx);
2116             return NULL;
2117         }
2118         tmp2 = modsub(tmp, xx, curve->p);
2119         freebn(tmp);
2120         freebn(xx);
2121         if (!tmp2) {
2122             freebn(x);
2123             return NULL;
2124         }
2125         if (bignum_cmp(tmp2, Zero)) {
2126             Bignum tmp3;
2127
2128             freebn(tmp2);
2129
2130             tmp = modsub(curve->p, One, curve->p);
2131             if (!tmp) {
2132                 freebn(x);
2133                 return NULL;
2134             }
2135             tmp2 = bignum_rshift(tmp, 2);
2136             freebn(tmp);
2137             if (!tmp2) {
2138                 freebn(x);
2139                 return NULL;
2140             }
2141             tmp = bignum_from_long(2);
2142             if (!tmp) {
2143                 freebn(tmp2);
2144                 freebn(x);
2145                 return NULL;
2146             }
2147             tmp3 = modpow(tmp, tmp2, curve->p);
2148             freebn(tmp);
2149             freebn(tmp2);
2150             if (!tmp3) {
2151                 freebn(x);
2152                 return NULL;
2153             }
2154
2155             tmp = modmul(x, tmp3, curve->p);
2156             freebn(x);
2157             freebn(tmp3);
2158             x = tmp;
2159             if (!tmp) {
2160                 return NULL;
2161             }
2162         } else {
2163             freebn(tmp2);
2164         }
2165     }
2166
2167     /* if x % 2 != 0 then x = p - x */
2168     if (bignum_bit(x, 0)) {
2169         Bignum tmp = modsub(curve->p, x, curve->p);
2170         freebn(x);
2171         x = tmp;
2172         if (!tmp) {
2173             return NULL;
2174         }
2175     }
2176
2177     return x;
2178 }
2179
2180 /* ----------------------------------------------------------------------
2181  * Public point from private
2182  */
2183
2184 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
2185 {
2186     if (curve->type == EC_WEIERSTRASS) {
2187         return ecp_mul(&curve->w.G, privateKey);
2188     } else if (curve->type == EC_EDWARDS) {
2189         /* hash = H(sk) (where hash creates 2 * fieldBits)
2190          * b = fieldBits
2191          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2192          * publicKey = aB */
2193         struct ec_point *ret;
2194         unsigned char hash[512/8];
2195         Bignum a;
2196         int i, keylen;
2197         SHA512_State s;
2198         SHA512_Init(&s);
2199
2200         keylen = curve->fieldBits / 8;
2201         for (i = 0; i < keylen; ++i) {
2202             unsigned char b = bignum_byte(privateKey, i);
2203             SHA512_Bytes(&s, &b, 1);
2204         }
2205         SHA512_Final(&s, hash);
2206
2207         /* The second part is simply turning the hash into a Bignum,
2208          * however the 2^(b-2) bit *must* be set, and the bottom 3
2209          * bits *must* not be */
2210         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2211         hash[31] &= 0x7f; /* Unset above (b-2) */
2212         hash[31] |= 0x40; /* Set 2^(b-2) */
2213         /* Chop off the top part and convert to int */
2214         a = bignum_from_bytes_le(hash, 32);
2215         if (!a) {
2216             return NULL;
2217         }
2218
2219         ret = ecp_mul(&curve->e.B, a);
2220         freebn(a);
2221         return ret;
2222     } else {
2223         return NULL;
2224     }
2225 }
2226
2227 /* ----------------------------------------------------------------------
2228  * Basic sign and verify routines
2229  */
2230
2231 static int _ecdsa_verify(const struct ec_point *publicKey,
2232                          const unsigned char *data, const int dataLen,
2233                          const Bignum r, const Bignum s)
2234 {
2235     int z_bits, n_bits;
2236     Bignum z;
2237     int valid = 0;
2238
2239     if (publicKey->curve->type != EC_WEIERSTRASS) {
2240         return 0;
2241     }
2242
2243     /* Sanity checks */
2244     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
2245         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
2246     {
2247         return 0;
2248     }
2249
2250     /* z = left most bitlen(curve->n) of data */
2251     z = bignum_from_bytes(data, dataLen);
2252     if (!z) return 0;
2253     n_bits = bignum_bitcount(publicKey->curve->w.n);
2254     z_bits = bignum_bitcount(z);
2255     if (z_bits > n_bits)
2256     {
2257         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
2258         freebn(z);
2259         z = tmp;
2260         if (!z) return 0;
2261     }
2262
2263     /* Ensure z in range of n */
2264     {
2265         Bignum tmp = bigmod(z, publicKey->curve->w.n);
2266         freebn(z);
2267         z = tmp;
2268         if (!z) return 0;
2269     }
2270
2271     /* Calculate signature */
2272     {
2273         Bignum w, x, u1, u2;
2274         struct ec_point *tmp;
2275
2276         w = modinv(s, publicKey->curve->w.n);
2277         if (!w) {
2278             freebn(z);
2279             return 0;
2280         }
2281         u1 = modmul(z, w, publicKey->curve->w.n);
2282         if (!u1) {
2283             freebn(z);
2284             freebn(w);
2285             return 0;
2286         }
2287         u2 = modmul(r, w, publicKey->curve->w.n);
2288         freebn(w);
2289         if (!u2) {
2290             freebn(z);
2291             freebn(u1);
2292             return 0;
2293         }
2294
2295         tmp = ecp_summul(u1, u2, publicKey);
2296         freebn(u1);
2297         freebn(u2);
2298         if (!tmp) {
2299             freebn(z);
2300             return 0;
2301         }
2302
2303         x = bigmod(tmp->x, publicKey->curve->w.n);
2304         ec_point_free(tmp);
2305         if (!x) {
2306             freebn(z);
2307             return 0;
2308         }
2309
2310         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
2311         freebn(x);
2312     }
2313
2314     freebn(z);
2315
2316     return valid;
2317 }
2318
2319 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
2320                         const unsigned char *data, const int dataLen,
2321                         Bignum *r, Bignum *s)
2322 {
2323     unsigned char digest[20];
2324     int z_bits, n_bits;
2325     Bignum z, k;
2326     struct ec_point *kG;
2327
2328     *r = NULL;
2329     *s = NULL;
2330
2331     if (curve->type != EC_WEIERSTRASS) {
2332         return;
2333     }
2334
2335     /* z = left most bitlen(curve->n) of data */
2336     z = bignum_from_bytes(data, dataLen);
2337     if (!z) return;
2338     n_bits = bignum_bitcount(curve->w.n);
2339     z_bits = bignum_bitcount(z);
2340     if (z_bits > n_bits)
2341     {
2342         Bignum tmp;
2343         tmp = bignum_rshift(z, z_bits - n_bits);
2344         freebn(z);
2345         z = tmp;
2346         if (!z) return;
2347     }
2348
2349     /* Generate k between 1 and curve->n, using the same deterministic
2350      * k generation system we use for conventional DSA. */
2351     SHA_Simple(data, dataLen, digest);
2352     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
2353                   digest, sizeof(digest));
2354     if (!k) return;
2355
2356     kG = ecp_mul(&curve->w.G, k);
2357     if (!kG) {
2358         freebn(z);
2359         freebn(k);
2360         return;
2361     }
2362
2363     /* r = kG.x mod n */
2364     *r = bigmod(kG->x, curve->w.n);
2365     ec_point_free(kG);
2366     if (!*r) {
2367         freebn(z);
2368         freebn(k);
2369         return;
2370     }
2371
2372     /* s = (z + r * priv)/k mod n */
2373     {
2374         Bignum rPriv, zMod, first, firstMod, kInv;
2375         rPriv = modmul(*r, privateKey, curve->w.n);
2376         if (!rPriv) {
2377             freebn(*r);
2378             freebn(z);
2379             freebn(k);
2380             return;
2381         }
2382         zMod = bigmod(z, curve->w.n);
2383         freebn(z);
2384         if (!zMod) {
2385             freebn(rPriv);
2386             freebn(*r);
2387             freebn(k);
2388             return;
2389         }
2390         first = bigadd(rPriv, zMod);
2391         freebn(rPriv);
2392         freebn(zMod);
2393         if (!first) {
2394             freebn(*r);
2395             freebn(k);
2396             return;
2397         }
2398         firstMod = bigmod(first, curve->w.n);
2399         freebn(first);
2400         if (!firstMod) {
2401             freebn(*r);
2402             freebn(k);
2403             return;
2404         }
2405         kInv = modinv(k, curve->w.n);
2406         freebn(k);
2407         if (!kInv) {
2408             freebn(firstMod);
2409             freebn(*r);
2410             return;
2411         }
2412         *s = modmul(firstMod, kInv, curve->w.n);
2413         freebn(firstMod);
2414         freebn(kInv);
2415         if (!*s) {
2416             freebn(*r);
2417             return;
2418         }
2419     }
2420 }
2421
2422 /* ----------------------------------------------------------------------
2423  * Misc functions
2424  */
2425
2426 static void getstring(const char **data, int *datalen,
2427                       const char **p, int *length)
2428 {
2429     *p = NULL;
2430     if (*datalen < 4)
2431         return;
2432     *length = toint(GET_32BIT(*data));
2433     if (*length < 0)
2434         return;
2435     *datalen -= 4;
2436     *data += 4;
2437     if (*datalen < *length)
2438         return;
2439     *p = *data;
2440     *data += *length;
2441     *datalen -= *length;
2442 }
2443
2444 static Bignum getmp(const char **data, int *datalen)
2445 {
2446     const char *p;
2447     int length;
2448
2449     getstring(data, datalen, &p, &length);
2450     if (!p)
2451         return NULL;
2452     if (p[0] & 0x80)
2453         return NULL;                   /* negative mp */
2454     return bignum_from_bytes((unsigned char *)p, length);
2455 }
2456
2457 static Bignum getmp_le(const char **data, int *datalen)
2458 {
2459     const char *p;
2460     int length;
2461
2462     getstring(data, datalen, &p, &length);
2463     if (!p)
2464         return NULL;
2465     return bignum_from_bytes_le((const unsigned char *)p, length);
2466 }
2467
2468 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
2469 {
2470     /* Got some conversion to do, first read in the y co-ord */
2471     int negative;
2472
2473     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
2474     if (!point->y) {
2475         return 0;
2476     }
2477     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
2478         freebn(point->y);
2479         point->y = NULL;
2480         return 0;
2481     }
2482     /* Read x bit and then reset it */
2483     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
2484     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
2485
2486     /* Get the x from the y */
2487     point->x = ecp_edx(point->curve, point->y);
2488     if (!point->x) {
2489         freebn(point->y);
2490         point->y = NULL;
2491         return 0;
2492     }
2493     if (negative) {
2494         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
2495         freebn(point->x);
2496         point->x = tmp;
2497         if (!tmp) {
2498             freebn(point->y);
2499             point->y = NULL;
2500             return 0;
2501         }
2502     }
2503
2504     /* Verify the point is on the curve */
2505     if (!ec_point_verify(point)) {
2506         freebn(point->x);
2507         point->x = NULL;
2508         freebn(point->y);
2509         point->y = NULL;
2510         return 0;
2511     }
2512
2513     return 1;
2514 }
2515
2516 static int decodepoint(const char *p, int length, struct ec_point *point)
2517 {
2518     if (point->curve->type == EC_EDWARDS) {
2519         return decodepoint_ed(p, length, point);
2520     }
2521
2522     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
2523         return 0;
2524     /* Skip compression flag */
2525     ++p;
2526     --length;
2527     /* The two values must be equal length */
2528     if (length % 2 != 0) {
2529         point->x = NULL;
2530         point->y = NULL;
2531         point->z = NULL;
2532         return 0;
2533     }
2534     length = length / 2;
2535     point->x = bignum_from_bytes((const unsigned char *)p, length);
2536     if (!point->x) return 0;
2537     p += length;
2538     point->y = bignum_from_bytes((const unsigned char *)p, length);
2539     if (!point->y) {
2540         freebn(point->x);
2541         point->x = NULL;
2542         return 0;
2543     }
2544     point->z = NULL;
2545
2546     /* Verify the point is on the curve */
2547     if (!ec_point_verify(point)) {
2548         freebn(point->x);
2549         point->x = NULL;
2550         freebn(point->y);
2551         point->y = NULL;
2552         return 0;
2553     }
2554
2555     return 1;
2556 }
2557
2558 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
2559 {
2560     const char *p;
2561     int length;
2562
2563     getstring(data, datalen, &p, &length);
2564     if (!p) return 0;
2565     return decodepoint(p, length, point);
2566 }
2567
2568 /* ----------------------------------------------------------------------
2569  * Exposed ECDSA interface
2570  */
2571
2572 static void ecdsa_freekey(void *key)
2573 {
2574     struct ec_key *ec = (struct ec_key *) key;
2575     if (!ec) return;
2576
2577     if (ec->publicKey.x)
2578         freebn(ec->publicKey.x);
2579     if (ec->publicKey.y)
2580         freebn(ec->publicKey.y);
2581     if (ec->publicKey.z)
2582         freebn(ec->publicKey.z);
2583     if (ec->privateKey)
2584         freebn(ec->privateKey);
2585     sfree(ec);
2586 }
2587
2588 static void *ecdsa_newkey(const char *data, int len)
2589 {
2590     const char *p;
2591     int slen;
2592     struct ec_key *ec;
2593     struct ec_curve *curve;
2594
2595     getstring(&data, &len, &p, &slen);
2596
2597     if (!p) {
2598         return NULL;
2599     }
2600     curve = ec_name_to_curve(p, slen);
2601     if (!curve) return NULL;
2602
2603     if (curve->type != EC_WEIERSTRASS && curve->type != EC_EDWARDS) {
2604         return NULL;
2605     }
2606
2607     /* Curve name is duplicated for Weierstrass form */
2608     if (curve->type == EC_WEIERSTRASS) {
2609         getstring(&data, &len, &p, &slen);
2610         if (curve != ec_name_to_curve(p, slen)) return NULL;
2611     }
2612
2613     ec = snew(struct ec_key);
2614
2615     ec->publicKey.curve = curve;
2616     ec->publicKey.infinity = 0;
2617     ec->publicKey.x = NULL;
2618     ec->publicKey.y = NULL;
2619     ec->publicKey.z = NULL;
2620     if (!getmppoint(&data, &len, &ec->publicKey)) {
2621         ecdsa_freekey(ec);
2622         return NULL;
2623     }
2624     ec->privateKey = NULL;
2625
2626     if (!ec->publicKey.x || !ec->publicKey.y ||
2627         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2628         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2629     {
2630         ecdsa_freekey(ec);
2631         ec = NULL;
2632     }
2633
2634     return ec;
2635 }
2636
2637 static char *ecdsa_fmtkey(void *key)
2638 {
2639     struct ec_key *ec = (struct ec_key *) key;
2640     char *p;
2641     int len, i, pos, nibbles;
2642     static const char hex[] = "0123456789abcdef";
2643     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2644         return NULL;
2645
2646     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2647     if (pos == 0) return NULL;
2648
2649     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
2650     len += pos; /* Curve name */
2651     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
2652     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
2653     p = snewn(len, char);
2654
2655     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, (unsigned char*)p, pos);
2656     pos += sprintf(p + pos, ",0x");
2657     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
2658     if (nibbles < 1)
2659         nibbles = 1;
2660     for (i = nibbles; i--;) {
2661         p[pos++] =
2662             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
2663     }
2664     pos += sprintf(p + pos, ",0x");
2665     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
2666     if (nibbles < 1)
2667         nibbles = 1;
2668     for (i = nibbles; i--;) {
2669         p[pos++] =
2670             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
2671     }
2672     p[pos] = '\0';
2673     return p;
2674 }
2675
2676 static unsigned char *ecdsa_public_blob(void *key, int *len)
2677 {
2678     struct ec_key *ec = (struct ec_key *) key;
2679     int pointlen, bloblen, fullnamelen, namelen;
2680     int i;
2681     unsigned char *blob, *p;
2682
2683     if (ec->publicKey.curve->type == EC_EDWARDS) {
2684         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
2685         fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2686         if (fullnamelen == 0) return NULL;
2687
2688         pointlen = ec->publicKey.curve->fieldBits / 8;
2689
2690         /* Can't handle this in our loop */
2691         if (pointlen < 2) return NULL;
2692
2693         bloblen = 4 + fullnamelen + 4 + pointlen;
2694         blob = snewn(bloblen, unsigned char);
2695         if (!blob) return NULL;
2696
2697         p = blob;
2698         PUT_32BIT(p, fullnamelen);
2699         p += 4;
2700         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
2701         PUT_32BIT(p, pointlen);
2702         p += 4;
2703
2704         /* Unset last bit of y and set first bit of x in its place */
2705         for (i = 0; i < pointlen - 1; ++i) {
2706             *p++ = bignum_byte(ec->publicKey.y, i);
2707         }
2708         /* Unset last bit of y and set first bit of x in its place */
2709         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
2710         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2711     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2712         fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2713         if (fullnamelen == 0) return NULL;
2714         namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2715         if (namelen == 0) return NULL;
2716
2717         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2718
2719         /*
2720          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
2721          */
2722         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
2723         blob = snewn(bloblen, unsigned char);
2724
2725         p = blob;
2726         PUT_32BIT(p, fullnamelen);
2727         p += 4;
2728         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
2729         PUT_32BIT(p, namelen);
2730         p += 4;
2731         p += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, p, namelen);
2732         PUT_32BIT(p, (2 * pointlen) + 1);
2733         p += 4;
2734         *p++ = 0x04;
2735         for (i = pointlen; i--;) {
2736             *p++ = bignum_byte(ec->publicKey.x, i);
2737         }
2738         for (i = pointlen; i--;) {
2739             *p++ = bignum_byte(ec->publicKey.y, i);
2740         }
2741     } else {
2742         return NULL;
2743     }
2744
2745     assert(p == blob + bloblen);
2746     *len = bloblen;
2747
2748     return blob;
2749 }
2750
2751 static unsigned char *ecdsa_private_blob(void *key, int *len)
2752 {
2753     struct ec_key *ec = (struct ec_key *) key;
2754     int keylen, bloblen;
2755     int i;
2756     unsigned char *blob, *p;
2757
2758     if (!ec->privateKey) return NULL;
2759
2760     if (ec->publicKey.curve->type == EC_EDWARDS) {
2761         /* Unsigned */
2762         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2763     } else {
2764         /* Signed */
2765         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2766     }
2767
2768     /*
2769      * mpint privateKey. Total 4 + keylen.
2770      */
2771     bloblen = 4 + keylen;
2772     blob = snewn(bloblen, unsigned char);
2773
2774     p = blob;
2775     PUT_32BIT(p, keylen);
2776     p += 4;
2777     if (ec->publicKey.curve->type == EC_EDWARDS) {
2778         /* Little endian */
2779         for (i = 0; i < keylen; ++i)
2780             *p++ = bignum_byte(ec->privateKey, i);
2781     } else {
2782         for (i = keylen; i--;)
2783             *p++ = bignum_byte(ec->privateKey, i);
2784     }
2785
2786     assert(p == blob + bloblen);
2787     *len = bloblen;
2788     return blob;
2789 }
2790
2791 static void *ecdsa_createkey(const unsigned char *pub_blob, int pub_len,
2792                              const unsigned char *priv_blob, int priv_len)
2793 {
2794     struct ec_key *ec;
2795     struct ec_point *publicKey;
2796     const char *pb = (const char *) priv_blob;
2797
2798     ec = (struct ec_key*)ecdsa_newkey((const char *) pub_blob, pub_len);
2799     if (!ec) {
2800         return NULL;
2801     }
2802
2803     if (ec->publicKey.curve->type != EC_WEIERSTRASS
2804         && ec->publicKey.curve->type != EC_EDWARDS) {
2805         ecdsa_freekey(ec);
2806         return NULL;
2807     }
2808
2809     if (ec->publicKey.curve->type == EC_EDWARDS) {
2810         ec->privateKey = getmp_le(&pb, &priv_len);
2811     } else {
2812         ec->privateKey = getmp(&pb, &priv_len);
2813     }
2814     if (!ec->privateKey) {
2815         ecdsa_freekey(ec);
2816         return NULL;
2817     }
2818
2819     /* Check that private key generates public key */
2820     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2821
2822     if (!publicKey ||
2823         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2824         bignum_cmp(publicKey->y, ec->publicKey.y))
2825     {
2826         ecdsa_freekey(ec);
2827         ec = NULL;
2828     }
2829     ec_point_free(publicKey);
2830
2831     return ec;
2832 }
2833
2834 static void *ed25519_openssh_createkey(const unsigned char **blob, int *len)
2835 {
2836     struct ec_key *ec;
2837     struct ec_point *publicKey;
2838     const char *p, *q;
2839     int plen, qlen;
2840
2841     getstring((const char**)blob, len, &p, &plen);
2842     if (!p)
2843     {
2844         return NULL;
2845     }
2846
2847     ec = snew(struct ec_key);
2848     if (!ec)
2849     {
2850         return NULL;
2851     }
2852
2853     ec->publicKey.curve = ec_ed25519();
2854     ec->publicKey.infinity = 0;
2855     ec->privateKey = NULL;
2856     ec->publicKey.x = NULL;
2857     ec->publicKey.z = NULL;
2858     ec->publicKey.y = NULL;
2859
2860     if (!decodepoint_ed(p, plen, &ec->publicKey))
2861     {
2862         ecdsa_freekey(ec);
2863         return NULL;
2864     }
2865
2866     getstring((const char**)blob, len, &q, &qlen);
2867     if (!q)
2868         return NULL;
2869     if (qlen != 64)
2870         return NULL;
2871
2872     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2873     if (!ec->privateKey) {
2874         ecdsa_freekey(ec);
2875         return NULL;
2876     }
2877
2878     /* Check that private key generates public key */
2879     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2880
2881     if (!publicKey ||
2882         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2883         bignum_cmp(publicKey->y, ec->publicKey.y))
2884     {
2885         ecdsa_freekey(ec);
2886         ec = NULL;
2887     }
2888     ec_point_free(publicKey);
2889
2890     /* The OpenSSH format for ed25519 private keys also for some
2891      * reason encodes an extra copy of the public key in the second
2892      * half of the secret-key string. Check that that's present and
2893      * correct as well, otherwise the key we think we've imported
2894      * won't behave identically to the way OpenSSH would have treated
2895      * it. */
2896     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2897         ecdsa_freekey(ec);
2898         return NULL;
2899     }
2900
2901     return ec;
2902 }
2903
2904 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2905 {
2906     struct ec_key *ec = (struct ec_key *) key;
2907
2908     int pointlen;
2909     int keylen;
2910     int bloblen;
2911     int i;
2912
2913     if (ec->publicKey.curve->type != EC_EDWARDS) {
2914         return 0;
2915     }
2916
2917     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2918     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2919     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2920
2921     if (bloblen > len)
2922         return bloblen;
2923
2924     /* Encode the public point */
2925     PUT_32BIT(blob, pointlen);
2926     blob += 4;
2927
2928     for (i = 0; i < pointlen - 1; ++i) {
2929          *blob++ = bignum_byte(ec->publicKey.y, i);
2930     }
2931     /* Unset last bit of y and set first bit of x in its place */
2932     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2933     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2934
2935     PUT_32BIT(blob, keylen + pointlen);
2936     blob += 4;
2937     for (i = 0; i < keylen; ++i) {
2938          *blob++ = bignum_byte(ec->privateKey, i);
2939     }
2940     /* Now encode an extra copy of the public point as the second half
2941      * of the private key string, as the OpenSSH format for some
2942      * reason requires */
2943     for (i = 0; i < pointlen - 1; ++i) {
2944          *blob++ = bignum_byte(ec->publicKey.y, i);
2945     }
2946     /* Unset last bit of y and set first bit of x in its place */
2947     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2948     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2949
2950     return bloblen;
2951 }
2952
2953 static void *ecdsa_openssh_createkey(const unsigned char **blob, int *len)
2954 {
2955     const char **b = (const char **) blob;
2956     const char *p;
2957     int slen;
2958     struct ec_key *ec;
2959     struct ec_curve *curve;
2960     struct ec_point *publicKey;
2961
2962     getstring(b, len, &p, &slen);
2963
2964     if (!p) {
2965         return NULL;
2966     }
2967     curve = ec_name_to_curve(p, slen);
2968     if (!curve) return NULL;
2969
2970     if (curve->type != EC_WEIERSTRASS) {
2971         return NULL;
2972     }
2973
2974     ec = snew(struct ec_key);
2975
2976     ec->publicKey.curve = curve;
2977     ec->publicKey.infinity = 0;
2978     ec->publicKey.x = NULL;
2979     ec->publicKey.y = NULL;
2980     ec->publicKey.z = NULL;
2981     if (!getmppoint(b, len, &ec->publicKey)) {
2982         ecdsa_freekey(ec);
2983         return NULL;
2984     }
2985     ec->privateKey = NULL;
2986
2987     if (!ec->publicKey.x || !ec->publicKey.y ||
2988         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2989         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2990     {
2991         ecdsa_freekey(ec);
2992         return NULL;
2993     }
2994
2995     ec->privateKey = getmp(b, len);
2996     if (ec->privateKey == NULL)
2997     {
2998         ecdsa_freekey(ec);
2999         return NULL;
3000     }
3001
3002     /* Now check that the private key makes the public key */
3003     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
3004     if (!publicKey)
3005     {
3006         ecdsa_freekey(ec);
3007         return NULL;
3008     }
3009
3010     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
3011         bignum_cmp(ec->publicKey.y, publicKey->y))
3012     {
3013         /* Private key doesn't make the public key on the given curve */
3014         ecdsa_freekey(ec);
3015         ec_point_free(publicKey);
3016         return NULL;
3017     }
3018
3019     ec_point_free(publicKey);
3020
3021     return ec;
3022 }
3023
3024 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
3025 {
3026     struct ec_key *ec = (struct ec_key *) key;
3027
3028     int pointlen;
3029     int namelen;
3030     int bloblen;
3031     int i;
3032
3033     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
3034         return 0;
3035     }
3036
3037     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3038     namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
3039     bloblen =
3040         4 + namelen /* <LEN> nistpXXX */
3041         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
3042         + ssh2_bignum_length(ec->privateKey);
3043
3044     if (bloblen > len)
3045         return bloblen;
3046
3047     bloblen = 0;
3048
3049     PUT_32BIT(blob+bloblen, namelen);
3050     bloblen += 4;
3051
3052     bloblen += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, blob+bloblen, namelen);
3053
3054     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
3055     bloblen += 4;
3056     blob[bloblen++] = 0x04;
3057     for (i = pointlen; i--; )
3058         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
3059     for (i = pointlen; i--; )
3060         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
3061
3062     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
3063     PUT_32BIT(blob+bloblen, pointlen);
3064     bloblen += 4;
3065     for (i = pointlen; i--; )
3066         blob[bloblen++] = bignum_byte(ec->privateKey, i);
3067
3068     return bloblen;
3069 }
3070
3071 static int ecdsa_pubkey_bits(const void *blob, int len)
3072 {
3073     struct ec_key *ec;
3074     int ret;
3075
3076     ec = (struct ec_key*)ecdsa_newkey((const char *) blob, len);
3077     if (!ec)
3078         return -1;
3079     ret = ec->publicKey.curve->fieldBits;
3080     ecdsa_freekey(ec);
3081
3082     return ret;
3083 }
3084
3085 static char *ecdsa_fingerprint(void *key)
3086 {
3087     struct ec_key *ec = (struct ec_key *) key;
3088     struct MD5Context md5c;
3089     unsigned char digest[16], lenbuf[4];
3090     char *ret;
3091     unsigned char *name, *fullname;
3092     int pointlen, namelen, fullnamelen, i, j;
3093
3094     MD5Init(&md5c);
3095
3096     namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3097     name = snewn(namelen, unsigned char);
3098     ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, name, namelen);
3099
3100     if (ec->publicKey.curve->type == EC_EDWARDS) {
3101         unsigned char b;
3102
3103         /* Do it with the weird encoding */
3104         PUT_32BIT(lenbuf, namelen);
3105         MD5Update(&md5c, lenbuf, 4);
3106         MD5Update(&md5c, name, namelen);
3107
3108         pointlen = ec->publicKey.curve->fieldBits / 8;
3109         PUT_32BIT(lenbuf, pointlen);
3110         MD5Update(&md5c, lenbuf, 4);
3111         for (i = 0; i < pointlen - 1; ++i) {
3112             b = bignum_byte(ec->publicKey.y, i);
3113             MD5Update(&md5c, &b, 1);
3114         }
3115         /* Unset last bit of y and set first bit of x in its place */
3116         b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3117         b |= bignum_bit(ec->publicKey.x, 0) << 7;
3118         MD5Update(&md5c, &b, 1);
3119     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3120         fullnamelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
3121         fullname = snewn(namelen, unsigned char);
3122         ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, fullname, fullnamelen);
3123
3124         PUT_32BIT(lenbuf, fullnamelen);
3125         MD5Update(&md5c, lenbuf, 4);
3126         MD5Update(&md5c, fullname, fullnamelen);
3127         sfree(fullname);
3128
3129         PUT_32BIT(lenbuf, namelen);
3130         MD5Update(&md5c, lenbuf, 4);
3131         MD5Update(&md5c, name, namelen);
3132
3133         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3134         PUT_32BIT(lenbuf, 1 + (pointlen * 2));
3135         MD5Update(&md5c, lenbuf, 4);
3136         MD5Update(&md5c, (const unsigned char *)"\x04", 1);
3137         for (i = pointlen; i--; ) {
3138             unsigned char c = bignum_byte(ec->publicKey.x, i);
3139             MD5Update(&md5c, &c, 1);
3140         }
3141         for (i = pointlen; i--; ) {
3142             unsigned char c = bignum_byte(ec->publicKey.y, i);
3143             MD5Update(&md5c, &c, 1);
3144         }
3145     } else {
3146         sfree(name);
3147         return NULL;
3148     }
3149
3150     MD5Final(digest, &md5c);
3151
3152     ret = snewn(namelen + 1 + (16 * 3), char);
3153
3154     i = 0;
3155     memcpy(ret, name, namelen);
3156     i += namelen;
3157     sfree(name);
3158     ret[i++] = ' ';
3159     for (j = 0; j < 16; j++) {
3160         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
3161     }
3162
3163     return ret;
3164 }
3165
3166 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
3167                            const char *data, int datalen)
3168 {
3169     struct ec_key *ec = (struct ec_key *) key;
3170     const char *p;
3171     int slen;
3172     int digestLen;
3173     int ret;
3174
3175     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
3176         return 0;
3177
3178     /* Check the signature curve matches the key curve */
3179     getstring(&sig, &siglen, &p, &slen);
3180     if (!p) {
3181         return 0;
3182     }
3183     if (ec->publicKey.curve != ec_name_to_curve(p, slen)) {
3184         return 0;
3185     }
3186
3187     getstring(&sig, &siglen, &p, &slen);
3188     if (ec->publicKey.curve->type == EC_EDWARDS) {
3189         struct ec_point *r;
3190         Bignum s, h;
3191
3192         /* Check that the signature is two times the length of a point */
3193         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
3194             return 0;
3195         }
3196
3197         /* Check it's the 256 bit field so that SHA512 is the correct hash */
3198         if (ec->publicKey.curve->fieldBits != 256) {
3199             return 0;
3200         }
3201
3202         /* Get the signature */
3203         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
3204         if (!r) {
3205             return 0;
3206         }
3207         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
3208             ec_point_free(r);
3209             return 0;
3210         }
3211         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
3212                                  ec->publicKey.curve->fieldBits / 8);
3213         if (!s) {
3214             ec_point_free(r);
3215             return 0;
3216         }
3217
3218         /* Get the hash of the encoded value of R + encoded value of pk + message */
3219         {
3220             int i, pointlen;
3221             unsigned char b;
3222             unsigned char digest[512 / 8];
3223             SHA512_State hs;
3224             SHA512_Init(&hs);
3225
3226             /* Add encoded r (no need to encode it again, it was in the signature) */
3227             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
3228
3229             /* Encode pk and add it */
3230             pointlen = ec->publicKey.curve->fieldBits / 8;
3231             for (i = 0; i < pointlen - 1; ++i) {
3232                 b = bignum_byte(ec->publicKey.y, i);
3233                 SHA512_Bytes(&hs, &b, 1);
3234             }
3235             /* Unset last bit of y and set first bit of x in its place */
3236             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3237             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3238             SHA512_Bytes(&hs, &b, 1);
3239
3240             /* Add the message itself */
3241             SHA512_Bytes(&hs, data, datalen);
3242
3243             /* Get the hash */
3244             SHA512_Final(&hs, digest);
3245
3246             /* Convert to Bignum */
3247             h = bignum_from_bytes_le(digest, sizeof(digest));
3248             if (!h) {
3249                 ec_point_free(r);
3250                 freebn(s);
3251                 return 0;
3252             }
3253         }
3254
3255         /* Verify sB == r + h*publicKey */
3256         {
3257             struct ec_point *lhs, *rhs, *tmp;
3258
3259             /* lhs = sB */
3260             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
3261             freebn(s);
3262             if (!lhs) {
3263                 ec_point_free(r);
3264                 freebn(h);
3265                 return 0;
3266             }
3267
3268             /* rhs = r + h*publicKey */
3269             tmp = ecp_mul(&ec->publicKey, h);
3270             freebn(h);
3271             if (!tmp) {
3272                 ec_point_free(lhs);
3273                 ec_point_free(r);
3274                 return 0;
3275             }
3276             rhs = ecp_add(r, tmp, 0);
3277             ec_point_free(r);
3278             ec_point_free(tmp);
3279             if (!rhs) {
3280                 ec_point_free(lhs);
3281                 return 0;
3282             }
3283
3284             /* Check the point is the same */
3285             ret = !bignum_cmp(lhs->x, rhs->x);
3286             if (ret) {
3287                 ret = !bignum_cmp(lhs->y, rhs->y);
3288                 if (ret) {
3289                     ret = 1;
3290                 }
3291             }
3292             ec_point_free(lhs);
3293             ec_point_free(rhs);
3294         }
3295     } else {
3296         Bignum r, s;
3297         unsigned char digest[512 / 8];
3298
3299         r = getmp(&p, &slen);
3300         if (!r) return 0;
3301         s = getmp(&p, &slen);
3302         if (!s) {
3303             freebn(r);
3304             return 0;
3305         }
3306
3307         /* Perform correct hash function depending on curve size */
3308         if (ec->publicKey.curve->fieldBits <= 256) {
3309             SHA256_Simple(data, datalen, digest);
3310             digestLen = 256 / 8;
3311         } else if (ec->publicKey.curve->fieldBits <= 384) {
3312             SHA384_Simple(data, datalen, digest);
3313             digestLen = 384 / 8;
3314         } else {
3315             SHA512_Simple(data, datalen, digest);
3316             digestLen = 512 / 8;
3317         }
3318
3319         /* Verify the signature */
3320         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
3321
3322         freebn(r);
3323         freebn(s);
3324     }
3325
3326     return ret;
3327 }
3328
3329 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
3330                                  int *siglen)
3331 {
3332     struct ec_key *ec = (struct ec_key *) key;
3333     unsigned char digest[512 / 8];
3334     int digestLen;
3335     Bignum r = NULL, s = NULL;
3336     unsigned char *buf, *p;
3337     int rlen, slen, namelen;
3338     int i;
3339
3340     if (!ec->privateKey || !ec->publicKey.curve) {
3341         return NULL;
3342     }
3343
3344     if (ec->publicKey.curve->type == EC_EDWARDS) {
3345         struct ec_point *rp;
3346         int pointlen = ec->publicKey.curve->fieldBits / 8;
3347
3348         /* hash = H(sk) (where hash creates 2 * fieldBits)
3349          * b = fieldBits
3350          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
3351          * r = H(h[b/8:b/4] + m)
3352          * R = rB
3353          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
3354         {
3355             unsigned char hash[512/8];
3356             unsigned char b;
3357             Bignum a;
3358             SHA512_State hs;
3359             SHA512_Init(&hs);
3360
3361             for (i = 0; i < pointlen; ++i) {
3362                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
3363                 SHA512_Bytes(&hs, &b, 1);
3364             }
3365
3366             SHA512_Final(&hs, hash);
3367
3368             /* The second part is simply turning the hash into a
3369              * Bignum, however the 2^(b-2) bit *must* be set, and the
3370              * bottom 3 bits *must* not be */
3371             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
3372             hash[31] &= 0x7f; /* Unset above (b-2) */
3373             hash[31] |= 0x40; /* Set 2^(b-2) */
3374             /* Chop off the top part and convert to int */
3375             a = bignum_from_bytes_le(hash, 32);
3376             if (!a) {
3377                 return NULL;
3378             }
3379
3380             SHA512_Init(&hs);
3381             SHA512_Bytes(&hs,
3382                          hash+(ec->publicKey.curve->fieldBits / 8),
3383                          (ec->publicKey.curve->fieldBits / 4)
3384                          - (ec->publicKey.curve->fieldBits / 8));
3385             SHA512_Bytes(&hs, data, datalen);
3386             SHA512_Final(&hs, hash);
3387
3388             r = bignum_from_bytes_le(hash, 512/8);
3389             if (!r) {
3390                 freebn(a);
3391                 return NULL;
3392             }
3393             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
3394             if (!rp) {
3395                 freebn(r);
3396                 freebn(a);
3397                 return NULL;
3398             }
3399
3400             /* Now calculate s */
3401             SHA512_Init(&hs);
3402             /* Encode the point R */
3403             for (i = 0; i < pointlen - 1; ++i) {
3404                 b = bignum_byte(rp->y, i);
3405                 SHA512_Bytes(&hs, &b, 1);
3406             }
3407             /* Unset last bit of y and set first bit of x in its place */
3408             b = bignum_byte(rp->y, i) & 0x7f;
3409             b |= bignum_bit(rp->x, 0) << 7;
3410             SHA512_Bytes(&hs, &b, 1);
3411
3412             /* Encode the point pk */
3413             for (i = 0; i < pointlen - 1; ++i) {
3414                 b = bignum_byte(ec->publicKey.y, i);
3415                 SHA512_Bytes(&hs, &b, 1);
3416             }
3417             /* Unset last bit of y and set first bit of x in its place */
3418             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3419             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3420             SHA512_Bytes(&hs, &b, 1);
3421
3422             /* Add the message */
3423             SHA512_Bytes(&hs, data, datalen);
3424             SHA512_Final(&hs, hash);
3425
3426             {
3427                 Bignum tmp, tmp2;
3428
3429                 tmp = bignum_from_bytes_le(hash, 512/8);
3430                 if (!tmp) {
3431                     ec_point_free(rp);
3432                     freebn(r);
3433                     freebn(a);
3434                     return NULL;
3435                 }
3436                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
3437                 freebn(a);
3438                 freebn(tmp);
3439                 if (!tmp2) {
3440                     ec_point_free(rp);
3441                     freebn(r);
3442                     return NULL;
3443                 }
3444                 tmp = bigadd(r, tmp2);
3445                 freebn(r);
3446                 freebn(tmp2);
3447                 if (!tmp) {
3448                     ec_point_free(rp);
3449                     return NULL;
3450                 }
3451                 s = bigmod(tmp, ec->publicKey.curve->e.l);
3452                 freebn(tmp);
3453                 if (!s) {
3454                     ec_point_free(rp);
3455                     return NULL;
3456                 }
3457             }
3458         }
3459
3460         /* Format the output */
3461         namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3462         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
3463         buf = snewn(*siglen, unsigned char);
3464         p = buf;
3465         PUT_32BIT(p, namelen);
3466         p += 4;
3467         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
3468         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
3469         p += 4;
3470
3471         /* Encode the point */
3472         pointlen = ec->publicKey.curve->fieldBits / 8;
3473         for (i = 0; i < pointlen - 1; ++i) {
3474             *p++ = bignum_byte(rp->y, i);
3475         }
3476         /* Unset last bit of y and set first bit of x in its place */
3477         *p = bignum_byte(rp->y, i) & 0x7f;
3478         *p++ |= bignum_bit(rp->x, 0) << 7;
3479         ec_point_free(rp);
3480
3481         /* Encode the int */
3482         for (i = 0; i < pointlen; ++i) {
3483             *p++ = bignum_byte(s, i);
3484         }
3485         freebn(s);
3486     } else {
3487         /* Perform correct hash function depending on curve size */
3488         if (ec->publicKey.curve->fieldBits <= 256) {
3489             SHA256_Simple(data, datalen, digest);
3490             digestLen = 256 / 8;
3491         } else if (ec->publicKey.curve->fieldBits <= 384) {
3492             SHA384_Simple(data, datalen, digest);
3493             digestLen = 384 / 8;
3494         } else {
3495             SHA512_Simple(data, datalen, digest);
3496             digestLen = 512 / 8;
3497         }
3498
3499         /* Do the signature */
3500         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
3501         if (!r || !s) {
3502             if (r) freebn(r);
3503             if (s) freebn(s);
3504             return NULL;
3505         }
3506
3507         rlen = (bignum_bitcount(r) + 8) / 8;
3508         slen = (bignum_bitcount(s) + 8) / 8;
3509
3510         namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3511
3512         /* Format the output */
3513         *siglen = 8+namelen+rlen+slen+8;
3514         buf = snewn(*siglen, unsigned char);
3515         p = buf;
3516         PUT_32BIT(p, namelen);
3517         p += 4;
3518         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
3519         PUT_32BIT(p, rlen + slen + 8);
3520         p += 4;
3521         PUT_32BIT(p, rlen);
3522         p += 4;
3523         for (i = rlen; i--;)
3524             *p++ = bignum_byte(r, i);
3525         PUT_32BIT(p, slen);
3526         p += 4;
3527         for (i = slen; i--;)
3528             *p++ = bignum_byte(s, i);
3529
3530         freebn(r);
3531         freebn(s);
3532     }
3533
3534     return buf;
3535 }
3536
3537 const struct ssh_signkey ssh_ecdsa_ed25519 = {
3538     ecdsa_newkey,
3539     ecdsa_freekey,
3540     ecdsa_fmtkey,
3541     ecdsa_public_blob,
3542     ecdsa_private_blob,
3543     ecdsa_createkey,
3544     ed25519_openssh_createkey,
3545     ed25519_openssh_fmtkey,
3546     2 /* point, private exponent */,
3547     ecdsa_pubkey_bits,
3548     ecdsa_fingerprint,
3549     ecdsa_verifysig,
3550     ecdsa_sign,
3551     "ssh-ed25519",
3552     "ssh-ed25519",
3553 };
3554
3555 const struct ssh_signkey ssh_ecdsa_nistp256 = {
3556     ecdsa_newkey,
3557     ecdsa_freekey,
3558     ecdsa_fmtkey,
3559     ecdsa_public_blob,
3560     ecdsa_private_blob,
3561     ecdsa_createkey,
3562     ecdsa_openssh_createkey,
3563     ecdsa_openssh_fmtkey,
3564     3 /* curve name, point, private exponent */,
3565     ecdsa_pubkey_bits,
3566     ecdsa_fingerprint,
3567     ecdsa_verifysig,
3568     ecdsa_sign,
3569     "ecdsa-sha2-nistp256",
3570     "ecdsa-sha2-nistp256",
3571 };
3572
3573 const struct ssh_signkey ssh_ecdsa_nistp384 = {
3574     ecdsa_newkey,
3575     ecdsa_freekey,
3576     ecdsa_fmtkey,
3577     ecdsa_public_blob,
3578     ecdsa_private_blob,
3579     ecdsa_createkey,
3580     ecdsa_openssh_createkey,
3581     ecdsa_openssh_fmtkey,
3582     3 /* curve name, point, private exponent */,
3583     ecdsa_pubkey_bits,
3584     ecdsa_fingerprint,
3585     ecdsa_verifysig,
3586     ecdsa_sign,
3587     "ecdsa-sha2-nistp384",
3588     "ecdsa-sha2-nistp384",
3589 };
3590
3591 const struct ssh_signkey ssh_ecdsa_nistp521 = {
3592     ecdsa_newkey,
3593     ecdsa_freekey,
3594     ecdsa_fmtkey,
3595     ecdsa_public_blob,
3596     ecdsa_private_blob,
3597     ecdsa_createkey,
3598     ecdsa_openssh_createkey,
3599     ecdsa_openssh_fmtkey,
3600     3 /* curve name, point, private exponent */,
3601     ecdsa_pubkey_bits,
3602     ecdsa_fingerprint,
3603     ecdsa_verifysig,
3604     ecdsa_sign,
3605     "ecdsa-sha2-nistp521",
3606     "ecdsa-sha2-nistp521",
3607 };
3608
3609 /* ----------------------------------------------------------------------
3610  * Exposed ECDH interface
3611  */
3612
3613 static Bignum ecdh_calculate(const Bignum private,
3614                              const struct ec_point *public)
3615 {
3616     struct ec_point *p;
3617     Bignum ret;
3618     p = ecp_mul(public, private);
3619     if (!p) return NULL;
3620     ret = p->x;
3621     p->x = NULL;
3622
3623     if (p->curve->type == EC_MONTGOMERY) {
3624         /* Do conversion in network byte order */
3625         int i;
3626         int bytes = (bignum_bitcount(ret)+7) / 8;
3627         unsigned char *byteorder = snewn(bytes, unsigned char);
3628         if (!byteorder) {
3629             ec_point_free(p);
3630             freebn(ret);
3631             return NULL;
3632         }
3633         for (i = 0; i < bytes; ++i) {
3634             byteorder[i] = bignum_byte(ret, i);
3635         }
3636         freebn(ret);
3637         ret = bignum_from_bytes(byteorder, bytes);
3638         sfree(byteorder);
3639     }
3640
3641     ec_point_free(p);
3642     return ret;
3643 }
3644
3645 void *ssh_ecdhkex_newkey(const char *name)
3646 {
3647     struct ec_curve *curve;
3648     struct ec_key *key;
3649     struct ec_point *publicKey;
3650
3651     curve = ec_name_to_curve(name, strlen(name));
3652
3653     key = snew(struct ec_key);
3654     if (!key) {
3655         return NULL;
3656     }
3657
3658     key->publicKey.curve = curve;
3659
3660     if (curve->type == EC_MONTGOMERY) {
3661         unsigned char bytes[32] = {0};
3662         int i;
3663
3664         for (i = 0; i < sizeof(bytes); ++i)
3665         {
3666             bytes[i] = (unsigned char)random_byte();
3667         }
3668         bytes[0] &= 248;
3669         bytes[31] &= 127;
3670         bytes[31] |= 64;
3671         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
3672         for (i = 0; i < sizeof(bytes); ++i)
3673         {
3674             ((volatile char*)bytes)[i] = 0;
3675         }
3676         if (!key->privateKey) {
3677             sfree(key);
3678             return NULL;
3679         }
3680         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
3681         if (!publicKey) {
3682             freebn(key->privateKey);
3683             sfree(key);
3684             return NULL;
3685         }
3686         key->publicKey.x = publicKey->x;
3687         key->publicKey.y = publicKey->y;
3688         key->publicKey.z = NULL;
3689         sfree(publicKey);
3690     } else {
3691         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
3692         if (!key->privateKey) {
3693             sfree(key);
3694             return NULL;
3695         }
3696         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
3697         if (!publicKey) {
3698             freebn(key->privateKey);
3699             sfree(key);
3700             return NULL;
3701         }
3702         key->publicKey.x = publicKey->x;
3703         key->publicKey.y = publicKey->y;
3704         key->publicKey.z = NULL;
3705         sfree(publicKey);
3706     }
3707     return key;
3708 }
3709
3710 char *ssh_ecdhkex_getpublic(void *key, int *len)
3711 {
3712     struct ec_key *ec = (struct ec_key*)key;
3713     char *point, *p;
3714     int i;
3715     int pointlen;
3716
3717     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3718
3719     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3720         *len = 1 + pointlen * 2;
3721     } else {
3722         *len = pointlen;
3723     }
3724     point = (char*)snewn(*len, char);
3725     if (!point) {
3726         return NULL;
3727     }
3728
3729     p = point;
3730     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3731         *p++ = 0x04;
3732         for (i = pointlen; i--;) {
3733             *p++ = bignum_byte(ec->publicKey.x, i);
3734         }
3735         for (i = pointlen; i--;) {
3736             *p++ = bignum_byte(ec->publicKey.y, i);
3737         }
3738     } else {
3739         for (i = 0; i < pointlen; ++i) {
3740             *p++ = bignum_byte(ec->publicKey.x, i);
3741         }
3742     }
3743
3744     return point;
3745 }
3746
3747 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
3748 {
3749     struct ec_key *ec = (struct ec_key*) key;
3750     struct ec_point remote;
3751     Bignum ret;
3752
3753     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3754         remote.curve = ec->publicKey.curve;
3755         remote.infinity = 0;
3756         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
3757             return NULL;
3758         }
3759     } else {
3760         /* Point length has to be the same length */
3761         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
3762             return NULL;
3763         }
3764
3765         remote.curve = ec->publicKey.curve;
3766         remote.infinity = 0;
3767         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
3768         remote.y = NULL;
3769         remote.z = NULL;
3770     }
3771
3772     ret = ecdh_calculate(ec->privateKey, &remote);
3773     if (remote.x) freebn(remote.x);
3774     if (remote.y) freebn(remote.y);
3775     return ret;
3776 }
3777
3778 void ssh_ecdhkex_freekey(void *key)
3779 {
3780     ecdsa_freekey(key);
3781 }
3782
3783 static const struct ssh_kex ssh_ec_kex_curve25519 = {
3784     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
3785 };
3786
3787 static const struct ssh_kex ssh_ec_kex_nistp256 = {
3788     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
3789 };
3790
3791 static const struct ssh_kex ssh_ec_kex_nistp384 = {
3792     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
3793 };
3794
3795 static const struct ssh_kex ssh_ec_kex_nistp521 = {
3796     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
3797 };
3798
3799 static const struct ssh_kex *const ec_kex_list[] = {
3800     &ssh_ec_kex_curve25519,
3801     &ssh_ec_kex_nistp256,
3802     &ssh_ec_kex_nistp384,
3803     &ssh_ec_kex_nistp521
3804 };
3805
3806 const struct ssh_kexes ssh_ecdh_kex = {
3807     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
3808     ec_kex_list
3809 };