]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Support public keys using the "ssh-ed25519" method.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static int initialise_wcurve(struct ec_curve *curve, int bits, unsigned char *p,
41                              unsigned char *a, unsigned char *b,
42                              unsigned char *n, unsigned char *Gx,
43                              unsigned char *Gy)
44 {
45     int length = bits / 8;
46     if (bits % 8) ++length;
47
48     curve->type = EC_WEIERSTRASS;
49
50     curve->fieldBits = bits;
51     curve->p = bignum_from_bytes(p, length);
52     if (!curve->p) goto error;
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     if (!curve->w.a) goto error;
57     curve->w.b = bignum_from_bytes(b, length);
58     if (!curve->w.b) goto error;
59
60     /* Group order and generator */
61     curve->w.n = bignum_from_bytes(n, length);
62     if (!curve->w.n) goto error;
63     curve->w.G.x = bignum_from_bytes(Gx, length);
64     if (!curve->w.G.x) goto error;
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     if (!curve->w.G.y) goto error;
67     curve->w.G.curve = curve;
68     curve->w.G.infinity = 0;
69
70     return 1;
71   error:
72     if (curve->p) freebn(curve->p);
73     if (curve->w.a) freebn(curve->w.a);
74     if (curve->w.b) freebn(curve->w.b);
75     if (curve->w.n) freebn(curve->w.n);
76     if (curve->w.G.x) freebn(curve->w.G.x);
77     return 0;
78 }
79
80 static int initialise_mcurve(struct ec_curve *curve, int bits, unsigned char *p,
81                              unsigned char *a, unsigned char *b,
82                              unsigned char *Gx)
83 {
84     int length = bits / 8;
85     if (bits % 8) ++length;
86
87     curve->type = EC_MONTGOMERY;
88
89     curve->fieldBits = bits;
90     curve->p = bignum_from_bytes(p, length);
91     if (!curve->p) goto error;
92
93     /* Curve co-efficients */
94     curve->m.a = bignum_from_bytes(a, length);
95     if (!curve->m.a) goto error;
96     curve->m.b = bignum_from_bytes(b, length);
97     if (!curve->m.b) goto error;
98
99     /* Generator */
100     curve->m.G.x = bignum_from_bytes(Gx, length);
101     if (!curve->m.G.x) goto error;
102     curve->m.G.y = NULL;
103     curve->m.G.z = NULL;
104     curve->m.G.curve = curve;
105     curve->m.G.infinity = 0;
106
107     return 1;
108   error:
109     if (curve->p) freebn(curve->p);
110     if (curve->m.a) freebn(curve->m.a);
111     if (curve->m.b) freebn(curve->m.b);
112     return 0;
113 }
114
115 static int initialise_ecurve(struct ec_curve *curve, int bits, unsigned char *p,
116                              unsigned char *l, unsigned char *d,
117                              unsigned char *Bx, unsigned char *By)
118 {
119     int length = bits / 8;
120     if (bits % 8) ++length;
121
122     curve->type = EC_EDWARDS;
123
124     curve->fieldBits = bits;
125     curve->p = bignum_from_bytes(p, length);
126     if (!curve->p) goto error;
127
128     /* Curve co-efficients */
129     curve->e.l = bignum_from_bytes(l, length);
130     if (!curve->e.l) goto error;
131     curve->e.d = bignum_from_bytes(d, length);
132     if (!curve->e.d) goto error;
133
134     /* Group order and generator */
135     curve->e.B.x = bignum_from_bytes(Bx, length);
136     if (!curve->e.B.x) goto error;
137     curve->e.B.y = bignum_from_bytes(By, length);
138     if (!curve->e.B.y) goto error;
139     curve->e.B.curve = curve;
140     curve->e.B.infinity = 0;
141
142     return 1;
143   error:
144     if (curve->p) freebn(curve->p);
145     if (curve->e.l) freebn(curve->e.l);
146     if (curve->e.d) freebn(curve->e.d);
147     if (curve->e.B.x) freebn(curve->e.B.x);
148     return 0;
149 }
150
151 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
152 int nistp256_oid_len = 8;
153 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
154 int nistp384_oid_len = 5;
155 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
156 int nistp521_oid_len = 5;
157 unsigned char curve25519_oid[] = {0x06, 0x0A, 0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01};
158 int curve25519_oid_len = 12;
159
160 struct ec_curve *ec_p256(void)
161 {
162     static struct ec_curve curve = { 0 };
163     static unsigned char initialised = 0;
164
165     if (!initialised)
166     {
167         unsigned char p[] = {
168             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
169             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
170             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
171             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
172         };
173         unsigned char a[] = {
174             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
175             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
176             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
177             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
178         };
179         unsigned char b[] = {
180             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
181             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
182             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
183             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
184         };
185         unsigned char n[] = {
186             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
187             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
188             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
189             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
190         };
191         unsigned char Gx[] = {
192             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
193             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
194             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
195             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
196         };
197         unsigned char Gy[] = {
198             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
199             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
200             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
201             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
202         };
203
204         if (!initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy)) {
205             return NULL;
206         }
207
208         /* Now initialised, no need to do it again */
209         initialised = 1;
210     }
211
212     return &curve;
213 }
214
215 struct ec_curve *ec_p384(void)
216 {
217     static struct ec_curve curve = { 0 };
218     static unsigned char initialised = 0;
219
220     if (!initialised)
221     {
222         unsigned char p[] = {
223             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
224             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
225             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
226             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
227             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
228             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
229         };
230         unsigned char a[] = {
231             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
232             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
233             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
234             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
235             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
236             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
237         };
238         unsigned char b[] = {
239             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
240             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
241             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
242             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
243             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
244             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
245         };
246         unsigned char n[] = {
247             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
251             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
252             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
253         };
254         unsigned char Gx[] = {
255             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
256             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
257             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
258             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
259             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
260             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
261         };
262         unsigned char Gy[] = {
263             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
264             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
265             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
266             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
267             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
268             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
269         };
270
271         if (!initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy)) {
272             return NULL;
273         }
274
275         /* Now initialised, no need to do it again */
276         initialised = 1;
277     }
278
279     return &curve;
280 }
281
282 struct ec_curve *ec_p521(void)
283 {
284     static struct ec_curve curve = { 0 };
285     static unsigned char initialised = 0;
286
287     if (!initialised)
288     {
289         unsigned char p[] = {
290             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
291             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
292             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
293             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
294             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
295             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
296             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
297             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
298             0xff, 0xff
299         };
300         unsigned char a[] = {
301             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
302             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
303             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
304             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
305             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
306             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
307             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
308             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
309             0xff, 0xfc
310         };
311         unsigned char b[] = {
312             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
313             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
314             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
315             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
316             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
317             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
318             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
319             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
320             0x3f, 0x00
321         };
322         unsigned char n[] = {
323             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
324             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
325             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
326             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
327             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
328             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
329             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
330             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
331             0x64, 0x09
332         };
333         unsigned char Gx[] = {
334             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
335             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
336             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
337             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
338             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
339             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
340             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
341             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
342             0xbd, 0x66
343         };
344         unsigned char Gy[] = {
345             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
346             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
347             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
348             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
349             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
350             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
351             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
352             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
353             0x66, 0x50
354         };
355
356         if (!initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy)) {
357             return NULL;
358         }
359
360         /* Now initialised, no need to do it again */
361         initialised = 1;
362     }
363
364     return &curve;
365 }
366
367 struct ec_curve *ec_curve25519(void)
368 {
369     static struct ec_curve curve = { 0 };
370     static unsigned char initialised = 0;
371
372     if (!initialised)
373     {
374         unsigned char p[] = {
375             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
376             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
379         };
380         unsigned char a[] = {
381             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
382             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
385         };
386         unsigned char b[] = {
387             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
388             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
389             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
390             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
391         };
392         unsigned char gx[32] = {
393             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
394             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
395             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
396             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
397         };
398
399         if (!initialise_mcurve(&curve, 256, p, a, b, gx)) {
400             return NULL;
401         }
402
403         /* Now initialised, no need to do it again */
404         initialised = 1;
405     }
406
407     return &curve;
408 }
409 struct ec_curve *ec_ed25519(void)
410 {
411     static struct ec_curve curve = { 0 };
412     static unsigned char initialised = 0;
413
414     if (!initialised)
415     {
416         unsigned char q[] = {
417             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
418             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
419             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
420             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
421         };
422         unsigned char l[32] = {
423             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
424             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
425             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
426             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
427         };
428         unsigned char d[32] = {
429             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
430             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
431             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
432             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
433         };
434         unsigned char Bx[32] = {
435             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
436             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
437             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
438             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
439         };
440         unsigned char By[32] = {
441             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
442             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
443             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
444             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
445         };
446
447
448         if (!initialise_ecurve(&curve, 256, q, l, d, Bx, By)) {
449             return NULL;
450         }
451
452         /* Now initialised, no need to do it again */
453         initialised = 1;
454     }
455
456     return &curve;
457 }
458
459 static struct ec_curve *ec_name_to_curve(const char *name, int len) {
460     if (len > 11 && !memcmp(name, "ecdsa-sha2-", 11)) {
461         name += 11;
462         len -= 11;
463     } else if (len > 10 && !memcmp(name, "ecdh-sha2-", 10)) {
464         name += 10;
465         len -= 10;
466     } else if (len == 11 && !memcmp(name, "ssh-ed25519", 11)) {
467         return ec_ed25519();
468     }
469
470     if (len == 8 && !memcmp(name, "nistp", 5)) {
471         name += 5;
472         if (!memcmp(name, "256", 3)) {
473             return ec_p256();
474         } else if (!memcmp(name, "384", 3)) {
475             return ec_p384();
476         } else if (!memcmp(name, "521", 3)) {
477             return ec_p521();
478         }
479     }
480
481     if (len == 28 && !memcmp(name, "curve25519-sha256@libssh.org", 28)) {
482         return ec_curve25519();
483     }
484
485     return NULL;
486 }
487
488 /* Type enumeration for specifying the curve name */
489 enum ec_name_type { EC_TYPE_DSA, EC_TYPE_DH, EC_TYPE_CURVE };
490
491 static int ec_curve_to_name(enum ec_name_type type, const struct ec_curve *curve,
492                             unsigned char *name, int len) {
493     if (curve->type == EC_WEIERSTRASS) {
494         int length, loc;
495         if (type == EC_TYPE_DSA) {
496             length = 19;
497             loc = 16;
498         } else if (type == EC_TYPE_DH) {
499             length = 18;
500             loc = 15;
501         } else {
502             length = 8;
503             loc = 5;
504         }
505
506         /* Return length of string */
507         if (name == NULL) return length;
508
509         /* Not enough space for the name */
510         if (len < length) return 0;
511
512         /* Put the name in the buffer */
513         switch (curve->fieldBits) {
514           case 256:
515             memcpy(name+loc, "256", 3);
516             break;
517           case 384:
518             memcpy(name+loc, "384", 3);
519             break;
520           case 521:
521             memcpy(name+loc, "521", 3);
522             break;
523           default:
524             return 0;
525         }
526
527         if (type == EC_TYPE_DSA) {
528             memcpy(name, "ecdsa-sha2-nistp", 16);
529         } else if (type == EC_TYPE_DH) {
530             memcpy(name, "ecdh-sha2-nistp", 15);
531         } else {
532             memcpy(name, "nistp", 5);
533         }
534
535         return length;
536     } else if (curve->type == EC_EDWARDS) {
537         /* No DH for ed25519 - use Montgomery instead */
538         if (type == EC_TYPE_DH) return 0;
539
540         if (type == EC_TYPE_CURVE) {
541             /* Return length of string */
542             if (name == NULL) return 7;
543
544             /* Not enough space for the name */
545             if (len < 7) return 0;
546
547             /* Unknown curve field */
548             if (curve->fieldBits != 256) return 0;
549
550             memcpy(name, "ed25519", 7);
551             return 7;
552
553         } else {
554             /* Return length of string */
555             if (name == NULL) return 11;
556
557             /* Not enough space for the name */
558             if (len < 11) return 0;
559
560             /* Unknown curve field */
561             if (curve->fieldBits != 256) return 0;
562
563             memcpy(name, "ssh-ed25519", 11);
564             return 11;
565         }
566     } else {
567         /* No DSA for curve25519 */
568         if (type == EC_TYPE_DSA || type == EC_TYPE_CURVE) return 0;
569
570         /* Return length of string */
571         if (name == NULL) return 28;
572
573         /* Not enough space for the name */
574         if (len < 28) return 0;
575
576         /* Unknown curve field */
577         if (curve->fieldBits != 256) return 0;
578
579         memcpy(name, "curve25519-sha256@libssh.org", 28);
580         return 28;
581     }
582 }
583
584 /* Return 1 if a is -3 % p, otherwise return 0
585  * This is used because there are some maths optimisations */
586 static int ec_aminus3(const struct ec_curve *curve)
587 {
588     int ret;
589     Bignum _p;
590
591     if (curve->type != EC_WEIERSTRASS) {
592         return 0;
593     }
594
595     _p = bignum_add_long(curve->w.a, 3);
596     if (!_p) return 0;
597
598     ret = !bignum_cmp(curve->p, _p);
599     freebn(_p);
600     return ret;
601 }
602
603 /* ----------------------------------------------------------------------
604  * Elliptic curve field maths
605  */
606
607 static Bignum ecf_add(const Bignum a, const Bignum b,
608                       const struct ec_curve *curve)
609 {
610     Bignum a1, b1, ab, ret;
611
612     a1 = bigmod(a, curve->p);
613     if (!a1) return NULL;
614     b1 = bigmod(b, curve->p);
615     if (!b1)
616     {
617         freebn(a1);
618         return NULL;
619     }
620
621     ab = bigadd(a1, b1);
622     freebn(a1);
623     freebn(b1);
624     if (!ab) return NULL;
625
626     ret = bigmod(ab, curve->p);
627     freebn(ab);
628
629     return ret;
630 }
631
632 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
633 {
634     return modmul(a, a, curve->p);
635 }
636
637 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
638 {
639     Bignum ret, tmp;
640
641     /* Double */
642     tmp = bignum_lshift(a, 1);
643     if (!tmp) return NULL;
644
645     /* Add itself (i.e. treble) */
646     ret = bigadd(tmp, a);
647     freebn(tmp);
648
649     /* Normalise */
650     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
651     {
652         tmp = bigsub(ret, curve->p);
653         freebn(ret);
654         ret = tmp;
655     }
656
657     return ret;
658 }
659
660 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
661 {
662     Bignum ret = bignum_lshift(a, 1);
663     if (!ret) return NULL;
664     if (bignum_cmp(ret, curve->p) >= 0)
665     {
666         Bignum tmp = bigsub(ret, curve->p);
667         freebn(ret);
668         return tmp;
669     }
670     else
671     {
672         return ret;
673     }
674 }
675
676 /* ----------------------------------------------------------------------
677  * Memory functions
678  */
679
680 void ec_point_free(struct ec_point *point)
681 {
682     if (point == NULL) return;
683     point->curve = 0;
684     if (point->x) freebn(point->x);
685     if (point->y) freebn(point->y);
686     if (point->z) freebn(point->z);
687     point->infinity = 0;
688     sfree(point);
689 }
690
691 static struct ec_point *ec_point_new(const struct ec_curve *curve,
692                                      const Bignum x, const Bignum y, const Bignum z,
693                                      unsigned char infinity)
694 {
695     struct ec_point *point = snewn(1, struct ec_point);
696     point->curve = curve;
697     point->x = x;
698     point->y = y;
699     point->z = z;
700     point->infinity = infinity ? 1 : 0;
701     return point;
702 }
703
704 static struct ec_point *ec_point_copy(const struct ec_point *a)
705 {
706     if (a == NULL) return NULL;
707     return ec_point_new(a->curve,
708                         a->x ? copybn(a->x) : NULL,
709                         a->y ? copybn(a->y) : NULL,
710                         a->z ? copybn(a->z) : NULL,
711                         a->infinity);
712 }
713
714 static int ec_point_verify(const struct ec_point *a)
715 {
716     if (a->infinity) {
717         return 1;
718     } else if (a->curve->type == EC_EDWARDS) {
719         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
720         Bignum y2, x2, tmp, tmp2, tmp3;
721         int ret;
722
723         y2 = ecf_square(a->y, a->curve);
724         if (!y2) {
725             return 0;
726         }
727         x2 = ecf_square(a->x, a->curve);
728         if (!x2) {
729             freebn(y2);
730             return 0;
731         }
732         tmp = modmul(a->curve->e.d, x2, a->curve->p);
733         if (!tmp) {
734             freebn(x2);
735             freebn(y2);
736             return 0;
737         }
738         tmp2 = modmul(tmp, y2, a->curve->p);
739         freebn(tmp);
740         if (!tmp2) {
741             freebn(x2);
742             freebn(y2);
743             return 0;
744         }
745         tmp = modsub(y2, x2, a->curve->p);
746         freebn(y2);
747         freebn(x2);
748         if (!tmp) {
749             freebn(tmp2);
750             return 0;
751         }
752         tmp3 = modsub(tmp, tmp2, a->curve->p);
753         freebn(tmp);
754         freebn(tmp2);
755         if (!tmp3) {
756             return 0;
757         }
758         ret = !bignum_cmp(tmp3, One);
759         freebn(tmp3);
760         return ret;
761     } else if (a->curve->type == EC_WEIERSTRASS) {
762         /* Verify y^2 = x^3 + ax + b */
763         int ret = 0;
764
765         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
766
767         Bignum Three = bignum_from_long(3);
768         if (!Three) return 0;
769
770         lhs = modmul(a->y, a->y, a->curve->p);
771         if (!lhs) goto error;
772
773         /* This uses montgomery multiplication to optimise */
774         x3 = modpow(a->x, Three, a->curve->p);
775         freebn(Three);
776         if (!x3) goto error;
777         ax = modmul(a->curve->w.a, a->x, a->curve->p);
778         if (!ax) goto error;
779         x3ax = bigadd(x3, ax);
780         if (!x3ax) goto error;
781         freebn(x3); x3 = NULL;
782         freebn(ax); ax = NULL;
783         x3axm = bigmod(x3ax, a->curve->p);
784         if (!x3axm) goto error;
785         freebn(x3ax); x3ax = NULL;
786         x3axb = bigadd(x3axm, a->curve->w.b);
787         if (!x3axb) goto error;
788         freebn(x3axm); x3axm = NULL;
789         rhs = bigmod(x3axb, a->curve->p);
790         if (!rhs) goto error;
791         freebn(x3axb);
792
793         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
794         freebn(lhs);
795         freebn(rhs);
796
797         return ret;
798
799       error:
800         if (x3) freebn(x3);
801         if (ax) freebn(ax);
802         if (x3ax) freebn(x3ax);
803         if (x3axm) freebn(x3axm);
804         if (x3axb) freebn(x3axb);
805         if (lhs) freebn(lhs);
806         return 0;
807     } else {
808         return 0;
809     }
810 }
811
812 /* ----------------------------------------------------------------------
813  * Elliptic curve point maths
814  */
815
816 /* Returns 1 on success and 0 on memory error */
817 static int ecp_normalise(struct ec_point *a)
818 {
819     if (!a) {
820         /* No point */
821         return 0;
822     }
823
824     if (a->infinity) {
825         /* Point is at infinity - i.e. normalised */
826         return 1;
827     }
828
829     if (a->curve->type == EC_WEIERSTRASS) {
830         /* In Jacobian Coordinates the triple (X, Y, Z) represents
831            the affine point (X / Z^2, Y / Z^3) */
832
833         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
834
835         if (!a->x || !a->y) {
836             /* No point defined */
837             return 0;
838         } else if (!a->z) {
839             /* Already normalised */
840             return 1;
841         }
842
843         Z2 = ecf_square(a->z, a->curve);
844         if (!Z2) {
845             return 0;
846         }
847         Z2inv = modinv(Z2, a->curve->p);
848         if (!Z2inv) {
849             freebn(Z2);
850             return 0;
851         }
852         tx = modmul(a->x, Z2inv, a->curve->p);
853         freebn(Z2inv);
854         if (!tx) {
855             freebn(Z2);
856             return 0;
857         }
858
859         Z3 = modmul(Z2, a->z, a->curve->p);
860         freebn(Z2);
861         if (!Z3) {
862             freebn(tx);
863             return 0;
864         }
865         Z3inv = modinv(Z3, a->curve->p);
866         freebn(Z3);
867         if (!Z3inv) {
868             freebn(tx);
869             return 0;
870         }
871         ty = modmul(a->y, Z3inv, a->curve->p);
872         freebn(Z3inv);
873         if (!ty) {
874             freebn(tx);
875             return 0;
876         }
877
878         freebn(a->x);
879         a->x = tx;
880         freebn(a->y);
881         a->y = ty;
882         freebn(a->z);
883         a->z = NULL;
884         return 1;
885     } else if (a->curve->type == EC_MONTGOMERY) {
886         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
887
888         Bignum tmp, tmp2;
889
890         if (!a->x) {
891             /* No point defined */
892             return 0;
893         } else if (!a->z) {
894             /* Already normalised */
895             return 1;
896         }
897
898         tmp = modinv(a->z, a->curve->p);
899         if (!tmp) {
900             return 0;
901         }
902         tmp2 = modmul(a->x, tmp, a->curve->p);
903         freebn(tmp);
904         if (!tmp2) {
905             return 0;
906         }
907
908         freebn(a->z);
909         a->z = NULL;
910         freebn(a->x);
911         a->x = tmp2;
912         return 1;
913     } else if (a->curve->type == EC_EDWARDS) {
914         /* Always normalised */
915         return 1;
916     } else {
917         return 0;
918     }
919 }
920
921 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
922 {
923     Bignum S, M, outx, outy, outz;
924
925     if (bignum_cmp(a->y, Zero) == 0)
926     {
927         /* Identity */
928         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
929     }
930
931     /* S = 4*X*Y^2 */
932     {
933         Bignum Y2, XY2, _2XY2;
934
935         Y2 = ecf_square(a->y, a->curve);
936         if (!Y2) {
937             return NULL;
938         }
939         XY2 = modmul(a->x, Y2, a->curve->p);
940         freebn(Y2);
941         if (!XY2) {
942             return NULL;
943         }
944
945         _2XY2 = ecf_double(XY2, a->curve);
946         freebn(XY2);
947         if (!_2XY2) {
948             return NULL;
949         }
950         S = ecf_double(_2XY2, a->curve);
951         freebn(_2XY2);
952         if (!S) {
953             return NULL;
954         }
955     }
956
957     /* Faster calculation if a = -3 */
958     if (aminus3) {
959         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
960         Bignum Z2, XpZ2, XmZ2, second;
961
962         if (a->z == NULL) {
963             Z2 = copybn(One);
964         } else {
965             Z2 = ecf_square(a->z, a->curve);
966         }
967         if (!Z2) {
968             freebn(S);
969             return NULL;
970         }
971
972         XpZ2 = ecf_add(a->x, Z2, a->curve);
973         if (!XpZ2) {
974             freebn(S);
975             freebn(Z2);
976             return NULL;
977         }
978         XmZ2 = modsub(a->x, Z2, a->curve->p);
979         freebn(Z2);
980         if (!XmZ2) {
981             freebn(S);
982             freebn(XpZ2);
983             return NULL;
984         }
985
986         second = modmul(XpZ2, XmZ2, a->curve->p);
987         freebn(XpZ2);
988         freebn(XmZ2);
989         if (!second) {
990             freebn(S);
991             return NULL;
992         }
993
994         M = ecf_treble(second, a->curve);
995         freebn(second);
996         if (!M) {
997             freebn(S);
998             return NULL;
999         }
1000     } else {
1001         /* M = 3*X^2 + a*Z^4 */
1002         Bignum _3X2, X2, aZ4;
1003
1004         if (a->z == NULL) {
1005             aZ4 = copybn(a->curve->w.a);
1006         } else {
1007             Bignum Z2, Z4;
1008
1009             Z2 = ecf_square(a->z, a->curve);
1010             if (!Z2) {
1011                 freebn(S);
1012                 return NULL;
1013             }
1014             Z4 = ecf_square(Z2, a->curve);
1015             freebn(Z2);
1016             if (!Z4) {
1017                 freebn(S);
1018                 return NULL;
1019             }
1020             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
1021             freebn(Z4);
1022         }
1023         if (!aZ4) {
1024             freebn(S);
1025             return NULL;
1026         }
1027
1028         X2 = modmul(a->x, a->x, a->curve->p);
1029         if (!X2) {
1030             freebn(S);
1031             freebn(aZ4);
1032             return NULL;
1033         }
1034         _3X2 = ecf_treble(X2, a->curve);
1035         freebn(X2);
1036         if (!_3X2) {
1037             freebn(S);
1038             freebn(aZ4);
1039             return NULL;
1040         }
1041         M = ecf_add(_3X2, aZ4, a->curve);
1042         freebn(_3X2);
1043         freebn(aZ4);
1044         if (!M) {
1045             freebn(S);
1046             return NULL;
1047         }
1048     }
1049
1050     /* X' = M^2 - 2*S */
1051     {
1052         Bignum M2, _2S;
1053
1054         M2 = ecf_square(M, a->curve);
1055         if (!M2) {
1056             freebn(S);
1057             freebn(M);
1058             return NULL;
1059         }
1060
1061         _2S = ecf_double(S, a->curve);
1062         if (!_2S) {
1063             freebn(M2);
1064             freebn(S);
1065             freebn(M);
1066             return NULL;
1067         }
1068
1069         outx = modsub(M2, _2S, a->curve->p);
1070         freebn(M2);
1071         freebn(_2S);
1072         if (!outx) {
1073             freebn(S);
1074             freebn(M);
1075             return NULL;
1076         }
1077     }
1078
1079     /* Y' = M*(S - X') - 8*Y^4 */
1080     {
1081         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
1082
1083         SX = modsub(S, outx, a->curve->p);
1084         freebn(S);
1085         if (!SX) {
1086             freebn(M);
1087             freebn(outx);
1088             return NULL;
1089         }
1090         MSX = modmul(M, SX, a->curve->p);
1091         freebn(SX);
1092         freebn(M);
1093         if (!MSX) {
1094             freebn(outx);
1095             return NULL;
1096         }
1097         Y2 = ecf_square(a->y, a->curve);
1098         if (!Y2) {
1099             freebn(outx);
1100             freebn(MSX);
1101             return NULL;
1102         }
1103         Y4 = ecf_square(Y2, a->curve);
1104         freebn(Y2);
1105         if (!Y4) {
1106             freebn(outx);
1107             freebn(MSX);
1108             return NULL;
1109         }
1110         Eight = bignum_from_long(8);
1111         if (!Eight) {
1112             freebn(outx);
1113             freebn(MSX);
1114             freebn(Y4);
1115             return NULL;
1116         }
1117         _8Y4 = modmul(Eight, Y4, a->curve->p);
1118         freebn(Eight);
1119         freebn(Y4);
1120         if (!_8Y4) {
1121             freebn(outx);
1122             freebn(MSX);
1123             return NULL;
1124         }
1125         outy = modsub(MSX, _8Y4, a->curve->p);
1126         freebn(MSX);
1127         freebn(_8Y4);
1128         if (!outy) {
1129             freebn(outx);
1130             return NULL;
1131         }
1132     }
1133
1134     /* Z' = 2*Y*Z */
1135     {
1136         Bignum YZ;
1137
1138         if (a->z == NULL) {
1139             YZ = copybn(a->y);
1140         } else {
1141             YZ = modmul(a->y, a->z, a->curve->p);
1142         }
1143         if (!YZ) {
1144             freebn(outx);
1145             freebn(outy);
1146             return NULL;
1147         }
1148
1149         outz = ecf_double(YZ, a->curve);
1150         freebn(YZ);
1151         if (!outz) {
1152             freebn(outx);
1153             freebn(outy);
1154             return NULL;
1155         }
1156     }
1157
1158     return ec_point_new(a->curve, outx, outy, outz, 0);
1159 }
1160
1161 static struct ec_point *ecp_doublem(const struct ec_point *a)
1162 {
1163     Bignum z, outx, outz, xpz, xmz;
1164
1165     z = a->z;
1166     if (!z) {
1167         z = One;
1168     }
1169
1170     /* 4xz = (x + z)^2 - (x - z)^2 */
1171     {
1172         Bignum tmp;
1173
1174         tmp = ecf_add(a->x, z, a->curve);
1175         if (!tmp) {
1176             return NULL;
1177         }
1178         xpz = ecf_square(tmp, a->curve);
1179         freebn(tmp);
1180         if (!xpz) {
1181             return NULL;
1182         }
1183
1184         tmp = modsub(a->x, z, a->curve->p);
1185         if (!tmp) {
1186             freebn(xpz);
1187             return NULL;
1188         }
1189         xmz = ecf_square(tmp, a->curve);
1190         freebn(tmp);
1191         if (!xmz) {
1192             freebn(xpz);
1193             return NULL;
1194         }
1195     }
1196
1197     /* outx = (x + z)^2 * (x - z)^2 */
1198     outx = modmul(xpz, xmz, a->curve->p);
1199     if (!outx) {
1200         freebn(xpz);
1201         freebn(xmz);
1202         return NULL;
1203     }
1204
1205     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
1206     {
1207         Bignum _4xz, tmp, tmp2, tmp3;
1208
1209         tmp = bignum_from_long(2);
1210         if (!tmp) {
1211             freebn(xpz);
1212             freebn(outx);
1213             freebn(xmz);
1214             return NULL;
1215         }
1216         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
1217         freebn(tmp);
1218         if (!tmp2) {
1219             freebn(xpz);
1220             freebn(outx);
1221             freebn(xmz);
1222             return NULL;
1223         }
1224
1225         _4xz = modsub(xpz, xmz, a->curve->p);
1226         freebn(xpz);
1227         if (!_4xz) {
1228             freebn(tmp2);
1229             freebn(outx);
1230             freebn(xmz);
1231             return NULL;
1232         }
1233         tmp = modmul(tmp2, _4xz, a->curve->p);
1234         freebn(tmp2);
1235         if (!tmp) {
1236             freebn(_4xz);
1237             freebn(outx);
1238             freebn(xmz);
1239             return NULL;
1240         }
1241
1242         tmp2 = bignum_from_long(4);
1243         if (!tmp2) {
1244             freebn(tmp);
1245             freebn(_4xz);
1246             freebn(outx);
1247             freebn(xmz);
1248             return NULL;
1249         }
1250         tmp3 = modinv(tmp2, a->curve->p);
1251         freebn(tmp2);
1252         if (!tmp3) {
1253             freebn(tmp);
1254             freebn(_4xz);
1255             freebn(outx);
1256             freebn(xmz);
1257             return NULL;
1258         }
1259         tmp2 = modmul(tmp, tmp3, a->curve->p);
1260         freebn(tmp);
1261         freebn(tmp3);
1262         if (!tmp2) {
1263             freebn(_4xz);
1264             freebn(outx);
1265             freebn(xmz);
1266             return NULL;
1267         }
1268
1269         tmp = ecf_add(xmz, tmp2, a->curve);
1270         freebn(xmz);
1271         freebn(tmp2);
1272         if (!tmp) {
1273             freebn(_4xz);
1274             freebn(outx);
1275             return NULL;
1276         }
1277         outz = modmul(_4xz, tmp, a->curve->p);
1278         freebn(_4xz);
1279         freebn(tmp);
1280         if (!outz) {
1281             freebn(outx);
1282             return NULL;
1283         }
1284     }
1285
1286     return ec_point_new(a->curve, outx, NULL, outz, 0);
1287 }
1288
1289 /* Forward declaration for Edwards curve doubling */
1290 static struct ec_point *ecp_add(const struct ec_point *a,
1291                                 const struct ec_point *b,
1292                                 const int aminus3);
1293
1294 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
1295 {
1296     if (a->infinity)
1297     {
1298         /* Identity */
1299         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1300     }
1301
1302     if (a->curve->type == EC_EDWARDS)
1303     {
1304         return ecp_add(a, a, aminus3);
1305     }
1306     else if (a->curve->type == EC_WEIERSTRASS)
1307     {
1308         return ecp_doublew(a, aminus3);
1309     }
1310     else
1311     {
1312         return ecp_doublem(a);
1313     }
1314 }
1315
1316 static struct ec_point *ecp_addw(const struct ec_point *a,
1317                                  const struct ec_point *b,
1318                                  const int aminus3)
1319 {
1320     Bignum U1, U2, S1, S2, outx, outy, outz;
1321
1322     /* U1 = X1*Z2^2 */
1323     /* S1 = Y1*Z2^3 */
1324     if (b->z) {
1325         Bignum Z2, Z3;
1326
1327         Z2 = ecf_square(b->z, a->curve);
1328         if (!Z2) {
1329             return NULL;
1330         }
1331         U1 = modmul(a->x, Z2, a->curve->p);
1332         if (!U1) {
1333             freebn(Z2);
1334             return NULL;
1335         }
1336         Z3 = modmul(Z2, b->z, a->curve->p);
1337         freebn(Z2);
1338         if (!Z3) {
1339             freebn(U1);
1340             return NULL;
1341         }
1342         S1 = modmul(a->y, Z3, a->curve->p);
1343         freebn(Z3);
1344         if (!S1) {
1345             freebn(U1);
1346             return NULL;
1347         }
1348     } else {
1349         U1 = copybn(a->x);
1350         if (!U1) {
1351             return NULL;
1352         }
1353         S1 = copybn(a->y);
1354         if (!S1) {
1355             freebn(U1);
1356             return NULL;
1357         }
1358     }
1359
1360     /* U2 = X2*Z1^2 */
1361     /* S2 = Y2*Z1^3 */
1362     if (a->z) {
1363         Bignum Z2, Z3;
1364
1365         Z2 = ecf_square(a->z, b->curve);
1366         if (!Z2) {
1367             freebn(U1);
1368             freebn(S1);
1369             return NULL;
1370         }
1371         U2 = modmul(b->x, Z2, b->curve->p);
1372         if (!U2) {
1373             freebn(U1);
1374             freebn(S1);
1375             freebn(Z2);
1376             return NULL;
1377         }
1378         Z3 = modmul(Z2, a->z, b->curve->p);
1379         freebn(Z2);
1380         if (!Z3) {
1381             freebn(U1);
1382             freebn(S1);
1383             freebn(U2);
1384             return NULL;
1385         }
1386         S2 = modmul(b->y, Z3, b->curve->p);
1387         freebn(Z3);
1388         if (!S2) {
1389             freebn(U1);
1390             freebn(S1);
1391             freebn(U2);
1392             return NULL;
1393         }
1394     } else {
1395         U2 = copybn(b->x);
1396         if (!U2) {
1397             freebn(U1);
1398             freebn(S1);
1399             return NULL;
1400         }
1401         S2 = copybn(b->y);
1402         if (!S2) {
1403             freebn(U1);
1404             freebn(S1);
1405             freebn(U2);
1406             return NULL;
1407         }
1408     }
1409
1410     /* Check if multiplying by self */
1411     if (bignum_cmp(U1, U2) == 0)
1412     {
1413         freebn(U1);
1414         freebn(U2);
1415         if (bignum_cmp(S1, S2) == 0)
1416         {
1417             freebn(S1);
1418             freebn(S2);
1419             return ecp_double(a, aminus3);
1420         }
1421         else
1422         {
1423             freebn(S1);
1424             freebn(S2);
1425             /* Infinity */
1426             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1427         }
1428     }
1429
1430     {
1431         Bignum H, R, UH2, H3;
1432
1433         /* H = U2 - U1 */
1434         H = modsub(U2, U1, a->curve->p);
1435         freebn(U2);
1436         if (!H) {
1437             freebn(U1);
1438             freebn(S1);
1439             freebn(S2);
1440             return NULL;
1441         }
1442
1443         /* R = S2 - S1 */
1444         R = modsub(S2, S1, a->curve->p);
1445         freebn(S2);
1446         if (!R) {
1447             freebn(H);
1448             freebn(S1);
1449             freebn(U1);
1450             return NULL;
1451         }
1452
1453         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1454         {
1455             Bignum R2, H2, _2UH2, first;
1456
1457             H2 = ecf_square(H, a->curve);
1458             if (!H2) {
1459                 freebn(U1);
1460                 freebn(S1);
1461                 freebn(H);
1462                 freebn(R);
1463                 return NULL;
1464             }
1465             UH2 = modmul(U1, H2, a->curve->p);
1466             freebn(U1);
1467             if (!UH2) {
1468                 freebn(H2);
1469                 freebn(S1);
1470                 freebn(H);
1471                 freebn(R);
1472                 return NULL;
1473             }
1474             H3 = modmul(H2, H, a->curve->p);
1475             freebn(H2);
1476             if (!H3) {
1477                 freebn(UH2);
1478                 freebn(S1);
1479                 freebn(H);
1480                 freebn(R);
1481                 return NULL;
1482             }
1483             R2 = ecf_square(R, a->curve);
1484             if (!R2) {
1485                 freebn(H3);
1486                 freebn(UH2);
1487                 freebn(S1);
1488                 freebn(H);
1489                 freebn(R);
1490                 return NULL;
1491             }
1492             _2UH2 = ecf_double(UH2, a->curve);
1493             if (!_2UH2) {
1494                 freebn(R2);
1495                 freebn(H3);
1496                 freebn(UH2);
1497                 freebn(S1);
1498                 freebn(H);
1499                 freebn(R);
1500                 return NULL;
1501             }
1502             first = modsub(R2, H3, a->curve->p);
1503             freebn(R2);
1504             if (!first) {
1505                 freebn(H3);
1506                 freebn(_2UH2);
1507                 freebn(UH2);
1508                 freebn(S1);
1509                 freebn(H);
1510                 freebn(R);
1511                 return NULL;
1512             }
1513             outx = modsub(first, _2UH2, a->curve->p);
1514             freebn(first);
1515             freebn(_2UH2);
1516             if (!outx) {
1517                 freebn(H3);
1518                 freebn(UH2);
1519                 freebn(S1);
1520                 freebn(H);
1521                 freebn(R);
1522                 return NULL;
1523             }
1524         }
1525
1526         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1527         {
1528             Bignum RUH2mX, UH2mX, SH3;
1529
1530             UH2mX = modsub(UH2, outx, a->curve->p);
1531             freebn(UH2);
1532             if (!UH2mX) {
1533                 freebn(outx);
1534                 freebn(H3);
1535                 freebn(S1);
1536                 freebn(H);
1537                 freebn(R);
1538                 return NULL;
1539             }
1540             RUH2mX = modmul(R, UH2mX, a->curve->p);
1541             freebn(UH2mX);
1542             freebn(R);
1543             if (!RUH2mX) {
1544                 freebn(outx);
1545                 freebn(H3);
1546                 freebn(S1);
1547                 freebn(H);
1548                 return NULL;
1549             }
1550             SH3 = modmul(S1, H3, a->curve->p);
1551             freebn(S1);
1552             freebn(H3);
1553             if (!SH3) {
1554                 freebn(RUH2mX);
1555                 freebn(outx);
1556                 freebn(H);
1557                 return NULL;
1558             }
1559
1560             outy = modsub(RUH2mX, SH3, a->curve->p);
1561             freebn(RUH2mX);
1562             freebn(SH3);
1563             if (!outy) {
1564                 freebn(outx);
1565                 freebn(H);
1566                 return NULL;
1567             }
1568         }
1569
1570         /* Z3 = H*Z1*Z2 */
1571         if (a->z && b->z) {
1572             Bignum ZZ;
1573
1574             ZZ = modmul(a->z, b->z, a->curve->p);
1575             if (!ZZ) {
1576                 freebn(outx);
1577                 freebn(outy);
1578                 freebn(H);
1579                 return NULL;
1580             }
1581             outz = modmul(H, ZZ, a->curve->p);
1582             freebn(H);
1583             freebn(ZZ);
1584             if (!outz) {
1585                 freebn(outx);
1586                 freebn(outy);
1587                 return NULL;
1588             }
1589         } else if (a->z) {
1590             outz = modmul(H, a->z, a->curve->p);
1591             freebn(H);
1592             if (!outz) {
1593                 freebn(outx);
1594                 freebn(outy);
1595                 return NULL;
1596             }
1597         } else if (b->z) {
1598             outz = modmul(H, b->z, a->curve->p);
1599             freebn(H);
1600             if (!outz) {
1601                 freebn(outx);
1602                 freebn(outy);
1603                 return NULL;
1604             }
1605         } else {
1606             outz = H;
1607         }
1608     }
1609
1610     return ec_point_new(a->curve, outx, outy, outz, 0);
1611 }
1612
1613 static struct ec_point *ecp_addm(const struct ec_point *a,
1614                                  const struct ec_point *b,
1615                                  const struct ec_point *base)
1616 {
1617     Bignum outx, outz, az, bz;
1618
1619     az = a->z;
1620     if (!az) {
1621         az = One;
1622     }
1623     bz = b->z;
1624     if (!bz) {
1625         bz = One;
1626     }
1627
1628     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1629     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1630     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1631     {
1632         Bignum tmp, tmp2, tmp3, tmp4;
1633
1634         /* (Xa + Za) * (Xb - Zb) */
1635         tmp = ecf_add(a->x, az, a->curve);
1636         if (!tmp) {
1637             return NULL;
1638         }
1639         tmp2 = modsub(b->x, bz, a->curve->p);
1640         if (!tmp2) {
1641             freebn(tmp);
1642             return NULL;
1643         }
1644         tmp3 = modmul(tmp, tmp2, a->curve->p);
1645         freebn(tmp);
1646         freebn(tmp2);
1647         if (!tmp3) {
1648             return NULL;
1649         }
1650
1651         /* (Xa - Za) * (Xb + Zb) */
1652         tmp = modsub(a->x, az, a->curve->p);
1653         if (!tmp) {
1654             freebn(tmp3);
1655             return NULL;
1656         }
1657         tmp2 = ecf_add(b->x, bz, a->curve);
1658         if (!tmp2) {
1659             freebn(tmp);
1660             freebn(tmp3);
1661             return NULL;
1662         }
1663         tmp4 = modmul(tmp, tmp2, a->curve->p);
1664         freebn(tmp);
1665         freebn(tmp2);
1666         if (!tmp4) {
1667             freebn(tmp3);
1668             return NULL;
1669         }
1670
1671         tmp = ecf_add(tmp3, tmp4, a->curve);
1672         if (!tmp) {
1673             freebn(tmp3);
1674             freebn(tmp4);
1675             return NULL;
1676         }
1677         outx = ecf_square(tmp, a->curve);
1678         freebn(tmp);
1679         if (!outx) {
1680             freebn(tmp3);
1681             freebn(tmp4);
1682             return NULL;
1683         }
1684
1685         tmp = modsub(tmp3, tmp4, a->curve->p);
1686         freebn(tmp3);
1687         freebn(tmp4);
1688         if (!tmp) {
1689             freebn(outx);
1690             return NULL;
1691         }
1692         tmp2 = ecf_square(tmp, a->curve);
1693         freebn(tmp);
1694         if (!tmp2) {
1695             freebn(outx);
1696             return NULL;
1697         }
1698         outz = modmul(base->x, tmp2, a->curve->p);
1699         freebn(tmp2);
1700         if (!outz) {
1701             freebn(outx);
1702             return NULL;
1703         }
1704     }
1705
1706     return ec_point_new(a->curve, outx, NULL, outz, 0);
1707 }
1708
1709 static struct ec_point *ecp_adde(const struct ec_point *a,
1710                                  const struct ec_point *b)
1711 {
1712     Bignum outx, outy, dmul;
1713
1714     /* outx = (a->x * b->y + b->x * a->y) /
1715      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1716     {
1717         Bignum tmp, tmp2, tmp3, tmp4;
1718
1719         tmp = modmul(a->x, b->y, a->curve->p);
1720         if (!tmp)
1721         {
1722             return NULL;
1723         }
1724         tmp2 = modmul(b->x, a->y, a->curve->p);
1725         if (!tmp2)
1726         {
1727             freebn(tmp);
1728             return NULL;
1729         }
1730         tmp3 = ecf_add(tmp, tmp2, a->curve);
1731         if (!tmp3)
1732         {
1733             freebn(tmp);
1734             freebn(tmp2);
1735             return NULL;
1736         }
1737
1738         tmp4 = modmul(tmp, tmp2, a->curve->p);
1739         freebn(tmp);
1740         freebn(tmp2);
1741         if (!tmp4)
1742         {
1743             freebn(tmp3);
1744             return NULL;
1745         }
1746         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1747         freebn(tmp4);
1748         if (!dmul) {
1749             freebn(tmp3);
1750             return NULL;
1751         }
1752
1753         tmp = ecf_add(One, dmul, a->curve);
1754         if (!tmp)
1755         {
1756             freebn(tmp3);
1757             freebn(dmul);
1758             return NULL;
1759         }
1760         tmp2 = modinv(tmp, a->curve->p);
1761         freebn(tmp);
1762         if (!tmp2)
1763         {
1764             freebn(tmp3);
1765             freebn(dmul);
1766             return NULL;
1767         }
1768
1769         outx = modmul(tmp3, tmp2, a->curve->p);
1770         freebn(tmp3);
1771         freebn(tmp2);
1772         if (!outx)
1773         {
1774             freebn(dmul);
1775             return NULL;
1776         }
1777     }
1778
1779     /* outy = (a->y * b->y + a->x * b->x) /
1780      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1781     {
1782         Bignum tmp, tmp2, tmp3, tmp4;
1783
1784         tmp = modsub(One, dmul, a->curve->p);
1785         freebn(dmul);
1786         if (!tmp)
1787         {
1788             freebn(outx);
1789             return NULL;
1790         }
1791
1792         tmp2 = modinv(tmp, a->curve->p);
1793         freebn(tmp);
1794         if (!tmp2)
1795         {
1796             freebn(outx);
1797             return NULL;
1798         }
1799
1800         tmp = modmul(a->y, b->y, a->curve->p);
1801         if (!tmp)
1802         {
1803             freebn(tmp2);
1804             freebn(outx);
1805             return NULL;
1806         }
1807         tmp3 = modmul(a->x, b->x, a->curve->p);
1808         if (!tmp3)
1809         {
1810             freebn(tmp);
1811             freebn(tmp2);
1812             freebn(outx);
1813             return NULL;
1814         }
1815         tmp4 = ecf_add(tmp, tmp3, a->curve);
1816         freebn(tmp);
1817         freebn(tmp3);
1818         if (!tmp4)
1819         {
1820             freebn(tmp2);
1821             freebn(outx);
1822             return NULL;
1823         }
1824
1825         outy = modmul(tmp4, tmp2, a->curve->p);
1826         freebn(tmp4);
1827         freebn(tmp2);
1828         if (!outy)
1829         {
1830             freebn(outx);
1831             return NULL;
1832         }
1833     }
1834
1835     return ec_point_new(a->curve, outx, outy, NULL, 0);
1836 }
1837
1838 static struct ec_point *ecp_add(const struct ec_point *a,
1839                                 const struct ec_point *b,
1840                                 const int aminus3)
1841 {
1842     if (a->curve != b->curve) {
1843         return NULL;
1844     }
1845
1846     /* Check if multiplying by infinity */
1847     if (a->infinity) return ec_point_copy(b);
1848     if (b->infinity) return ec_point_copy(a);
1849
1850     if (a->curve->type == EC_EDWARDS)
1851     {
1852         return ecp_adde(a, b);
1853     }
1854
1855     if (a->curve->type == EC_WEIERSTRASS)
1856     {
1857         return ecp_addw(a, b, aminus3);
1858     }
1859
1860     return NULL;
1861 }
1862
1863 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1864 {
1865     struct ec_point *A, *ret;
1866     int bits, i;
1867
1868     A = ec_point_copy(a);
1869     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1870
1871     bits = bignum_bitcount(b);
1872     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1873     {
1874         if (bignum_bit(b, i))
1875         {
1876             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1877             ec_point_free(ret);
1878             ret = tmp;
1879         }
1880         if (i+1 != bits)
1881         {
1882             struct ec_point *tmp = ecp_double(A, aminus3);
1883             ec_point_free(A);
1884             A = tmp;
1885         }
1886     }
1887
1888     if (!A) {
1889         ec_point_free(ret);
1890         ret = NULL;
1891     } else {
1892         ec_point_free(A);
1893     }
1894
1895     return ret;
1896 }
1897
1898 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1899 {
1900     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1901
1902     if (!ecp_normalise(ret)) {
1903         ec_point_free(ret);
1904         return NULL;
1905     }
1906
1907     return ret;
1908 }
1909
1910 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1911 {
1912     int i;
1913     struct ec_point *ret;
1914
1915     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1916
1917     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1918     {
1919         {
1920             struct ec_point *tmp = ecp_double(ret, 0);
1921             ec_point_free(ret);
1922             ret = tmp;
1923         }
1924         if (ret && bignum_bit(b, i))
1925         {
1926             struct ec_point *tmp = ecp_add(ret, a, 0);
1927             ec_point_free(ret);
1928             ret = tmp;
1929         }
1930     }
1931
1932     return ret;
1933 }
1934
1935 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1936 {
1937     struct ec_point *P1, *P2;
1938     int bits, i;
1939
1940     /* P1 <- P and P2 <- [2]P */
1941     P2 = ecp_double(p, 0);
1942     if (!P2) {
1943         return NULL;
1944     }
1945     P1 = ec_point_copy(p);
1946     if (!P1) {
1947         ec_point_free(P2);
1948         return NULL;
1949     }
1950
1951     /* for i = bits âˆ’ 2 down to 0 */
1952     bits = bignum_bitcount(n);
1953     for (i = bits - 2; P1 != NULL && P2 != NULL && i >= 0; --i)
1954     {
1955         if (!bignum_bit(n, i))
1956         {
1957             /* P2 <- P1 + P2 */
1958             struct ec_point *tmp = ecp_addm(P1, P2, p);
1959             ec_point_free(P2);
1960             P2 = tmp;
1961
1962             /* P1 <- [2]P1 */
1963             tmp = ecp_double(P1, 0);
1964             ec_point_free(P1);
1965             P1 = tmp;
1966         }
1967         else
1968         {
1969             /* P1 <- P1 + P2 */
1970             struct ec_point *tmp = ecp_addm(P1, P2, p);
1971             ec_point_free(P1);
1972             P1 = tmp;
1973
1974             /* P2 <- [2]P2 */
1975             tmp = ecp_double(P2, 0);
1976             ec_point_free(P2);
1977             P2 = tmp;
1978         }
1979     }
1980
1981     if (!P2) {
1982         if (P1) ec_point_free(P1);
1983         P1 = NULL;
1984     } else {
1985         ec_point_free(P2);
1986     }
1987
1988     if (!ecp_normalise(P1)) {
1989         ec_point_free(P1);
1990         return NULL;
1991     }
1992
1993     return P1;
1994 }
1995
1996 /* Not static because it is used by sshecdsag.c to generate a new key */
1997 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1998 {
1999     if (a->curve->type == EC_WEIERSTRASS) {
2000         return ecp_mulw(a, b);
2001     } else if (a->curve->type == EC_EDWARDS) {
2002         return ecp_mule(a, b);
2003     } else {
2004         return ecp_mulm(a, b);
2005     }
2006 }
2007
2008 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
2009                                    const struct ec_point *point)
2010 {
2011     struct ec_point *aG, *bP, *ret;
2012     int aminus3;
2013
2014     if (point->curve->type != EC_WEIERSTRASS) {
2015         return NULL;
2016     }
2017
2018     aminus3 = ec_aminus3(point->curve);
2019
2020     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
2021     if (!aG) return NULL;
2022     bP = ecp_mul_(point, b, aminus3);
2023     if (!bP) {
2024         ec_point_free(aG);
2025         return NULL;
2026     }
2027
2028     ret = ecp_add(aG, bP, aminus3);
2029
2030     ec_point_free(aG);
2031     ec_point_free(bP);
2032
2033     if (!ecp_normalise(ret)) {
2034         ec_point_free(ret);
2035         return NULL;
2036     }
2037
2038     return ret;
2039 }
2040 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
2041 {
2042     /* Get the x value on the given Edwards curve for a given y */
2043     Bignum x, xx;
2044
2045     /* xx = (y^2 - 1) / (d * y^2 + 1) */
2046     {
2047         Bignum tmp, tmp2, tmp3;
2048
2049         tmp = ecf_square(y, curve);
2050         if (!tmp) {
2051             return NULL;
2052         }
2053         tmp2 = modmul(curve->e.d, tmp, curve->p);
2054         if (!tmp2) {
2055             freebn(tmp);
2056             return NULL;
2057         }
2058         tmp3 = ecf_add(tmp2, One, curve);
2059         freebn(tmp2);
2060         if (!tmp3) {
2061             freebn(tmp);
2062             return NULL;
2063         }
2064         tmp2 = modinv(tmp3, curve->p);
2065         freebn(tmp3);
2066         if (!tmp2) {
2067             freebn(tmp);
2068             return NULL;
2069         }
2070
2071         tmp3 = modsub(tmp, One, curve->p);
2072         freebn(tmp);
2073         if (!tmp3) {
2074             freebn(tmp2);
2075             return NULL;
2076         }
2077         xx = modmul(tmp3, tmp2, curve->p);
2078         freebn(tmp3);
2079         freebn(tmp2);
2080         if (!xx) {
2081             return NULL;
2082         }
2083     }
2084
2085     /* x = xx^((p + 3) / 8) */
2086     {
2087         Bignum tmp, tmp2;
2088
2089         tmp = bignum_add_long(curve->p, 3);
2090         if (!tmp) {
2091             freebn(xx);
2092             return NULL;
2093         }
2094         tmp2 = bignum_rshift(tmp, 3);
2095         freebn(tmp);
2096         if (!tmp2) {
2097             freebn(xx);
2098             return NULL;
2099         }
2100         x = modpow(xx, tmp2, curve->p);
2101         freebn(tmp2);
2102         if (!x) {
2103             freebn(xx);
2104             return NULL;
2105         }
2106     }
2107
2108     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
2109     {
2110         Bignum tmp, tmp2;
2111
2112         tmp = ecf_square(x, curve);
2113         if (!tmp) {
2114             freebn(x);
2115             freebn(xx);
2116             return NULL;
2117         }
2118         tmp2 = modsub(tmp, xx, curve->p);
2119         freebn(tmp);
2120         freebn(xx);
2121         if (!tmp2) {
2122             freebn(x);
2123             return NULL;
2124         }
2125         if (bignum_cmp(tmp2, Zero)) {
2126             Bignum tmp3;
2127
2128             freebn(tmp2);
2129
2130             tmp = modsub(curve->p, One, curve->p);
2131             if (!tmp) {
2132                 freebn(x);
2133                 return NULL;
2134             }
2135             tmp2 = bignum_rshift(tmp, 2);
2136             freebn(tmp);
2137             if (!tmp2) {
2138                 freebn(x);
2139                 return NULL;
2140             }
2141             tmp = bignum_from_long(2);
2142             if (!tmp) {
2143                 freebn(tmp2);
2144                 freebn(x);
2145                 return NULL;
2146             }
2147             tmp3 = modpow(tmp, tmp2, curve->p);
2148             freebn(tmp);
2149             freebn(tmp2);
2150             if (!tmp3) {
2151                 freebn(x);
2152                 return NULL;
2153             }
2154
2155             tmp = modmul(x, tmp3, curve->p);
2156             freebn(x);
2157             freebn(tmp3);
2158             x = tmp;
2159             if (!tmp) {
2160                 return NULL;
2161             }
2162         } else {
2163             freebn(tmp2);
2164         }
2165     }
2166
2167     /* if x % 2 != 0 then x = p - x */
2168     if (bignum_bit(x, 0)) {
2169         Bignum tmp = modsub(curve->p, x, curve->p);
2170         freebn(x);
2171         x = tmp;
2172         if (!tmp) {
2173             return NULL;
2174         }
2175     }
2176
2177     return x;
2178 }
2179
2180 /* ----------------------------------------------------------------------
2181  * Public point from private
2182  */
2183
2184 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
2185 {
2186     if (curve->type == EC_WEIERSTRASS) {
2187         return ecp_mul(&curve->w.G, privateKey);
2188     } else if (curve->type == EC_EDWARDS) {
2189         /* hash = H(sk) (where hash creates 2 * fieldBits)
2190          * b = fieldBits
2191          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2192          * publicKey = aB */
2193         struct ec_point *ret;
2194         unsigned char hash[512/8];
2195         Bignum a;
2196         int i, keylen;
2197         SHA512_State s;
2198         SHA512_Init(&s);
2199
2200         keylen = curve->fieldBits / 8;
2201         for (i = 0; i < keylen; ++i) {
2202             unsigned char b = bignum_byte(privateKey, i);
2203             SHA512_Bytes(&s, &b, 1);
2204         }
2205         SHA512_Final(&s, hash);
2206
2207         /* The second part is simply turning the hash into a Bignum, however
2208          * the 2^(b-2) bit *must* be set, and the bottom 2 bits *must* not be */
2209         hash[0] &= 0xfc; /* Unset bottom two bits (if set) */
2210         hash[31] &= 0x7f; /* Unset above (b-2) */
2211         hash[31] |= 0x40; /* Set 2^(b-2) */
2212         /* Chop off the top part and convert to int */
2213         a = bignum_from_bytes_le(hash, 32);
2214         if (!a) {
2215             return NULL;
2216         }
2217
2218         ret = ecp_mul(&curve->e.B, a);
2219         freebn(a);
2220         return ret;
2221     } else {
2222         return NULL;
2223     }
2224 }
2225
2226 /* ----------------------------------------------------------------------
2227  * Basic sign and verify routines
2228  */
2229
2230 static int _ecdsa_verify(const struct ec_point *publicKey,
2231                          const unsigned char *data, const int dataLen,
2232                          const Bignum r, const Bignum s)
2233 {
2234     int z_bits, n_bits;
2235     Bignum z;
2236     int valid = 0;
2237
2238     if (publicKey->curve->type != EC_WEIERSTRASS) {
2239         return 0;
2240     }
2241
2242     /* Sanity checks */
2243     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
2244         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
2245     {
2246         return 0;
2247     }
2248
2249     /* z = left most bitlen(curve->n) of data */
2250     z = bignum_from_bytes(data, dataLen);
2251     if (!z) return 0;
2252     n_bits = bignum_bitcount(publicKey->curve->w.n);
2253     z_bits = bignum_bitcount(z);
2254     if (z_bits > n_bits)
2255     {
2256         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
2257         freebn(z);
2258         z = tmp;
2259         if (!z) return 0;
2260     }
2261
2262     /* Ensure z in range of n */
2263     {
2264         Bignum tmp = bigmod(z, publicKey->curve->w.n);
2265         freebn(z);
2266         z = tmp;
2267         if (!z) return 0;
2268     }
2269
2270     /* Calculate signature */
2271     {
2272         Bignum w, x, u1, u2;
2273         struct ec_point *tmp;
2274
2275         w = modinv(s, publicKey->curve->w.n);
2276         if (!w) {
2277             freebn(z);
2278             return 0;
2279         }
2280         u1 = modmul(z, w, publicKey->curve->w.n);
2281         if (!u1) {
2282             freebn(z);
2283             freebn(w);
2284             return 0;
2285         }
2286         u2 = modmul(r, w, publicKey->curve->w.n);
2287         freebn(w);
2288         if (!u2) {
2289             freebn(z);
2290             freebn(u1);
2291             return 0;
2292         }
2293
2294         tmp = ecp_summul(u1, u2, publicKey);
2295         freebn(u1);
2296         freebn(u2);
2297         if (!tmp) {
2298             freebn(z);
2299             return 0;
2300         }
2301
2302         x = bigmod(tmp->x, publicKey->curve->w.n);
2303         ec_point_free(tmp);
2304         if (!x) {
2305             freebn(z);
2306             return 0;
2307         }
2308
2309         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
2310         freebn(x);
2311     }
2312
2313     freebn(z);
2314
2315     return valid;
2316 }
2317
2318 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
2319                         const unsigned char *data, const int dataLen,
2320                         Bignum *r, Bignum *s)
2321 {
2322     unsigned char digest[20];
2323     int z_bits, n_bits;
2324     Bignum z, k;
2325     struct ec_point *kG;
2326
2327     *r = NULL;
2328     *s = NULL;
2329
2330     if (curve->type != EC_WEIERSTRASS) {
2331         return;
2332     }
2333
2334     /* z = left most bitlen(curve->n) of data */
2335     z = bignum_from_bytes(data, dataLen);
2336     if (!z) return;
2337     n_bits = bignum_bitcount(curve->w.n);
2338     z_bits = bignum_bitcount(z);
2339     if (z_bits > n_bits)
2340     {
2341         Bignum tmp;
2342         tmp = bignum_rshift(z, z_bits - n_bits);
2343         freebn(z);
2344         z = tmp;
2345         if (!z) return;
2346     }
2347
2348     /* Generate k between 1 and curve->n, using the same deterministic
2349      * k generation system we use for conventional DSA. */
2350     SHA_Simple(data, dataLen, digest);
2351     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
2352                   digest, sizeof(digest));
2353     if (!k) return;
2354
2355     kG = ecp_mul(&curve->w.G, k);
2356     if (!kG) {
2357         freebn(z);
2358         freebn(k);
2359         return;
2360     }
2361
2362     /* r = kG.x mod n */
2363     *r = bigmod(kG->x, curve->w.n);
2364     ec_point_free(kG);
2365     if (!*r) {
2366         freebn(z);
2367         freebn(k);
2368         return;
2369     }
2370
2371     /* s = (z + r * priv)/k mod n */
2372     {
2373         Bignum rPriv, zMod, first, firstMod, kInv;
2374         rPriv = modmul(*r, privateKey, curve->w.n);
2375         if (!rPriv) {
2376             freebn(*r);
2377             freebn(z);
2378             freebn(k);
2379             return;
2380         }
2381         zMod = bigmod(z, curve->w.n);
2382         freebn(z);
2383         if (!zMod) {
2384             freebn(rPriv);
2385             freebn(*r);
2386             freebn(k);
2387             return;
2388         }
2389         first = bigadd(rPriv, zMod);
2390         freebn(rPriv);
2391         freebn(zMod);
2392         if (!first) {
2393             freebn(*r);
2394             freebn(k);
2395             return;
2396         }
2397         firstMod = bigmod(first, curve->w.n);
2398         freebn(first);
2399         if (!firstMod) {
2400             freebn(*r);
2401             freebn(k);
2402             return;
2403         }
2404         kInv = modinv(k, curve->w.n);
2405         freebn(k);
2406         if (!kInv) {
2407             freebn(firstMod);
2408             freebn(*r);
2409             return;
2410         }
2411         *s = modmul(firstMod, kInv, curve->w.n);
2412         freebn(firstMod);
2413         freebn(kInv);
2414         if (!*s) {
2415             freebn(*r);
2416             return;
2417         }
2418     }
2419 }
2420
2421 /* ----------------------------------------------------------------------
2422  * Misc functions
2423  */
2424
2425 static void getstring(const char **data, int *datalen,
2426                       const char **p, int *length)
2427 {
2428     *p = NULL;
2429     if (*datalen < 4)
2430         return;
2431     *length = toint(GET_32BIT(*data));
2432     if (*length < 0)
2433         return;
2434     *datalen -= 4;
2435     *data += 4;
2436     if (*datalen < *length)
2437         return;
2438     *p = *data;
2439     *data += *length;
2440     *datalen -= *length;
2441 }
2442
2443 static Bignum getmp(const char **data, int *datalen)
2444 {
2445     const char *p;
2446     int length;
2447
2448     getstring(data, datalen, &p, &length);
2449     if (!p)
2450         return NULL;
2451     if (p[0] & 0x80)
2452         return NULL;                   /* negative mp */
2453     return bignum_from_bytes((unsigned char *)p, length);
2454 }
2455
2456 static Bignum getmp_le(const char **data, int *datalen)
2457 {
2458     const char *p;
2459     int length;
2460
2461     getstring(data, datalen, &p, &length);
2462     if (!p)
2463         return NULL;
2464     return bignum_from_bytes_le((const unsigned char *)p, length);
2465 }
2466
2467 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
2468 {
2469     /* Got some conversion to do, first read in the y co-ord */
2470     int negative;
2471
2472     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
2473     if (!point->y) {
2474         return 0;
2475     }
2476     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
2477         freebn(point->y);
2478         point->y = NULL;
2479         return 0;
2480     }
2481     /* Read x bit and then reset it */
2482     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
2483     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
2484
2485     /* Get the x from the y */
2486     point->x = ecp_edx(point->curve, point->y);
2487     if (!point->x) {
2488         freebn(point->y);
2489         point->y = NULL;
2490         return 0;
2491     }
2492     if (negative) {
2493         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
2494         freebn(point->x);
2495         point->x = tmp;
2496         if (!tmp) {
2497             freebn(point->y);
2498             point->y = NULL;
2499             return 0;
2500         }
2501     }
2502
2503     /* Verify the point is on the curve */
2504     if (!ec_point_verify(point)) {
2505         freebn(point->x);
2506         point->x = NULL;
2507         freebn(point->y);
2508         point->y = NULL;
2509         return 0;
2510     }
2511
2512     return 1;
2513 }
2514
2515 static int decodepoint(const char *p, int length, struct ec_point *point)
2516 {
2517     if (point->curve->type == EC_EDWARDS) {
2518         return decodepoint_ed(p, length, point);
2519     }
2520
2521     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
2522         return 0;
2523     /* Skip compression flag */
2524     ++p;
2525     --length;
2526     /* The two values must be equal length */
2527     if (length % 2 != 0) {
2528         point->x = NULL;
2529         point->y = NULL;
2530         point->z = NULL;
2531         return 0;
2532     }
2533     length = length / 2;
2534     point->x = bignum_from_bytes((const unsigned char *)p, length);
2535     if (!point->x) return 0;
2536     p += length;
2537     point->y = bignum_from_bytes((const unsigned char *)p, length);
2538     if (!point->y) {
2539         freebn(point->x);
2540         point->x = NULL;
2541         return 0;
2542     }
2543     point->z = NULL;
2544
2545     /* Verify the point is on the curve */
2546     if (!ec_point_verify(point)) {
2547         freebn(point->x);
2548         point->x = NULL;
2549         freebn(point->y);
2550         point->y = NULL;
2551         return 0;
2552     }
2553
2554     return 1;
2555 }
2556
2557 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
2558 {
2559     const char *p;
2560     int length;
2561
2562     getstring(data, datalen, &p, &length);
2563     if (!p) return 0;
2564     return decodepoint(p, length, point);
2565 }
2566
2567 /* ----------------------------------------------------------------------
2568  * Exposed ECDSA interface
2569  */
2570
2571 static void ecdsa_freekey(void *key)
2572 {
2573     struct ec_key *ec = (struct ec_key *) key;
2574     if (!ec) return;
2575
2576     if (ec->publicKey.x)
2577         freebn(ec->publicKey.x);
2578     if (ec->publicKey.y)
2579         freebn(ec->publicKey.y);
2580     if (ec->publicKey.z)
2581         freebn(ec->publicKey.z);
2582     if (ec->privateKey)
2583         freebn(ec->privateKey);
2584     sfree(ec);
2585 }
2586
2587 static void *ecdsa_newkey(const char *data, int len)
2588 {
2589     const char *p;
2590     int slen;
2591     struct ec_key *ec;
2592     struct ec_curve *curve;
2593
2594     getstring(&data, &len, &p, &slen);
2595
2596     if (!p) {
2597         return NULL;
2598     }
2599     curve = ec_name_to_curve(p, slen);
2600     if (!curve) return NULL;
2601
2602     if (curve->type != EC_WEIERSTRASS && curve->type != EC_EDWARDS) {
2603         return NULL;
2604     }
2605
2606     /* Curve name is duplicated for Weierstrass form */
2607     if (curve->type == EC_WEIERSTRASS) {
2608         getstring(&data, &len, &p, &slen);
2609         if (curve != ec_name_to_curve(p, slen)) return NULL;
2610     }
2611
2612     ec = snew(struct ec_key);
2613
2614     ec->publicKey.curve = curve;
2615     ec->publicKey.infinity = 0;
2616     ec->publicKey.x = NULL;
2617     ec->publicKey.y = NULL;
2618     ec->publicKey.z = NULL;
2619     if (!getmppoint(&data, &len, &ec->publicKey)) {
2620         ecdsa_freekey(ec);
2621         return NULL;
2622     }
2623     ec->privateKey = NULL;
2624
2625     if (!ec->publicKey.x || !ec->publicKey.y ||
2626         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2627         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2628     {
2629         ecdsa_freekey(ec);
2630         ec = NULL;
2631     }
2632
2633     return ec;
2634 }
2635
2636 static char *ecdsa_fmtkey(void *key)
2637 {
2638     struct ec_key *ec = (struct ec_key *) key;
2639     char *p;
2640     int len, i, pos, nibbles;
2641     static const char hex[] = "0123456789abcdef";
2642     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2643         return NULL;
2644
2645     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2646     if (pos == 0) return NULL;
2647
2648     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
2649     len += pos; /* Curve name */
2650     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
2651     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
2652     p = snewn(len, char);
2653
2654     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, (unsigned char*)p, pos);
2655     pos += sprintf(p + pos, ",0x");
2656     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
2657     if (nibbles < 1)
2658         nibbles = 1;
2659     for (i = nibbles; i--;) {
2660         p[pos++] =
2661             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
2662     }
2663     pos += sprintf(p + pos, ",0x");
2664     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
2665     if (nibbles < 1)
2666         nibbles = 1;
2667     for (i = nibbles; i--;) {
2668         p[pos++] =
2669             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
2670     }
2671     p[pos] = '\0';
2672     return p;
2673 }
2674
2675 static unsigned char *ecdsa_public_blob(void *key, int *len)
2676 {
2677     struct ec_key *ec = (struct ec_key *) key;
2678     int pointlen, bloblen, fullnamelen, namelen;
2679     int i;
2680     unsigned char *blob, *p;
2681
2682     if (ec->publicKey.curve->type == EC_EDWARDS) {
2683         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
2684         fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2685         if (fullnamelen == 0) return NULL;
2686
2687         pointlen = ec->publicKey.curve->fieldBits / 8;
2688
2689         /* Can't handle this in our loop */
2690         if (pointlen < 2) return NULL;
2691
2692         bloblen = 4 + fullnamelen + 4 + pointlen;
2693         blob = snewn(bloblen, unsigned char);
2694         if (!blob) return NULL;
2695
2696         p = blob;
2697         PUT_32BIT(p, fullnamelen);
2698         p += 4;
2699         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
2700         PUT_32BIT(p, pointlen);
2701         p += 4;
2702
2703         /* Unset last bit of y and set first bit of x in its place */
2704         for (i = 0; i < pointlen - 1; ++i) {
2705             *p++ = bignum_byte(ec->publicKey.y, i);
2706         }
2707         /* Unset last bit of y and set first bit of x in its place */
2708         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
2709         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2710     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2711         fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2712         if (fullnamelen == 0) return NULL;
2713         namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2714         if (namelen == 0) return NULL;
2715
2716         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2717
2718         /*
2719          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
2720          */
2721         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
2722         blob = snewn(bloblen, unsigned char);
2723
2724         p = blob;
2725         PUT_32BIT(p, fullnamelen);
2726         p += 4;
2727         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
2728         PUT_32BIT(p, namelen);
2729         p += 4;
2730         p += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, p, namelen);
2731         PUT_32BIT(p, (2 * pointlen) + 1);
2732         p += 4;
2733         *p++ = 0x04;
2734         for (i = pointlen; i--;) {
2735             *p++ = bignum_byte(ec->publicKey.x, i);
2736         }
2737         for (i = pointlen; i--;) {
2738             *p++ = bignum_byte(ec->publicKey.y, i);
2739         }
2740     } else {
2741         return NULL;
2742     }
2743
2744     assert(p == blob + bloblen);
2745     *len = bloblen;
2746
2747     return blob;
2748 }
2749
2750 static unsigned char *ecdsa_private_blob(void *key, int *len)
2751 {
2752     struct ec_key *ec = (struct ec_key *) key;
2753     int keylen, bloblen;
2754     int i;
2755     unsigned char *blob, *p;
2756
2757     if (!ec->privateKey) return NULL;
2758
2759     if (ec->publicKey.curve->type == EC_EDWARDS) {
2760         /* Unsigned */
2761         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2762     } else {
2763         /* Signed */
2764         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2765     }
2766
2767     /*
2768      * mpint privateKey. Total 4 + keylen.
2769      */
2770     bloblen = 4 + keylen;
2771     blob = snewn(bloblen, unsigned char);
2772
2773     p = blob;
2774     PUT_32BIT(p, keylen);
2775     p += 4;
2776     if (ec->publicKey.curve->type == EC_EDWARDS) {
2777         /* Little endian */
2778         for (i = 0; i < keylen; ++i)
2779             *p++ = bignum_byte(ec->privateKey, i);
2780     } else {
2781         for (i = keylen; i--;)
2782             *p++ = bignum_byte(ec->privateKey, i);
2783     }
2784
2785     assert(p == blob + bloblen);
2786     *len = bloblen;
2787     return blob;
2788 }
2789
2790 static void *ecdsa_createkey(const unsigned char *pub_blob, int pub_len,
2791                              const unsigned char *priv_blob, int priv_len)
2792 {
2793     struct ec_key *ec;
2794     struct ec_point *publicKey;
2795     const char *pb = (const char *) priv_blob;
2796
2797     ec = (struct ec_key*)ecdsa_newkey((const char *) pub_blob, pub_len);
2798     if (!ec) {
2799         return NULL;
2800     }
2801
2802     if (ec->publicKey.curve->type != EC_WEIERSTRASS
2803         && ec->publicKey.curve->type != EC_EDWARDS) {
2804         ecdsa_freekey(ec);
2805         return NULL;
2806     }
2807
2808     if (ec->publicKey.curve->type == EC_EDWARDS) {
2809         ec->privateKey = getmp_le(&pb, &priv_len);
2810     } else {
2811         ec->privateKey = getmp(&pb, &priv_len);
2812     }
2813     if (!ec->privateKey) {
2814         ecdsa_freekey(ec);
2815         return NULL;
2816     }
2817
2818     /* Check that private key generates public key */
2819     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2820
2821     if (!publicKey ||
2822         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2823         bignum_cmp(publicKey->y, ec->publicKey.y))
2824     {
2825         ecdsa_freekey(ec);
2826         ec = NULL;
2827     }
2828     ec_point_free(publicKey);
2829
2830     return ec;
2831 }
2832
2833 static void *ed25519_openssh_createkey(const unsigned char **blob, int *len)
2834 {
2835     struct ec_key *ec;
2836     struct ec_point *publicKey;
2837     const char *p, *q;
2838     int plen, qlen;
2839
2840     getstring((const char**)blob, len, &p, &plen);
2841     if (!p)
2842     {
2843         return NULL;
2844     }
2845
2846     ec = snew(struct ec_key);
2847     if (!ec)
2848     {
2849         return NULL;
2850     }
2851
2852     ec->publicKey.curve = ec_ed25519();
2853     ec->publicKey.infinity = 0;
2854     ec->privateKey = NULL;
2855     ec->publicKey.x = NULL;
2856     ec->publicKey.z = NULL;
2857     ec->publicKey.y = NULL;
2858
2859     if (!decodepoint_ed(p, plen, &ec->publicKey))
2860     {
2861         ecdsa_freekey(ec);
2862         return NULL;
2863     }
2864
2865     getstring((const char**)blob, len, &q, &qlen);
2866     if (!q)
2867         return NULL;
2868     if (qlen != 64)
2869         return NULL;
2870
2871     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2872     if (!ec->privateKey) {
2873         ecdsa_freekey(ec);
2874         return NULL;
2875     }
2876
2877     /* Check that private key generates public key */
2878     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2879
2880     if (!publicKey ||
2881         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2882         bignum_cmp(publicKey->y, ec->publicKey.y))
2883     {
2884         ecdsa_freekey(ec);
2885         ec = NULL;
2886     }
2887     ec_point_free(publicKey);
2888
2889     /* The OpenSSH format for ed25519 private keys also for some
2890      * reason encodes an extra copy of the public key in the second
2891      * half of the secret-key string. Check that that's present and
2892      * correct as well, otherwise the key we think we've imported
2893      * won't behave identically to the way OpenSSH would have treated
2894      * it. */
2895     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2896         ecdsa_freekey(ec);
2897         return NULL;
2898     }
2899
2900     return ec;
2901 }
2902
2903 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2904 {
2905     struct ec_key *ec = (struct ec_key *) key;
2906
2907     int pointlen;
2908     int keylen;
2909     int bloblen;
2910     int i;
2911
2912     if (ec->publicKey.curve->type != EC_EDWARDS) {
2913         return 0;
2914     }
2915
2916     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2917     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2918     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2919
2920     if (bloblen > len)
2921         return bloblen;
2922
2923     /* Encode the public point */
2924     PUT_32BIT(blob, pointlen);
2925     blob += 4;
2926
2927     for (i = 0; i < pointlen - 1; ++i) {
2928          *blob++ = bignum_byte(ec->publicKey.y, i);
2929     }
2930     /* Unset last bit of y and set first bit of x in its place */
2931     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2932     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2933
2934     PUT_32BIT(blob, keylen + pointlen);
2935     blob += 4;
2936     for (i = 0; i < keylen; ++i) {
2937          *blob++ = bignum_byte(ec->privateKey, i);
2938     }
2939     /* Now encode an extra copy of the public point as the second half
2940      * of the private key string, as the OpenSSH format for some
2941      * reason requires */
2942     for (i = 0; i < pointlen - 1; ++i) {
2943          *blob++ = bignum_byte(ec->publicKey.y, i);
2944     }
2945     /* Unset last bit of y and set first bit of x in its place */
2946     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2947     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2948
2949     return bloblen;
2950 }
2951
2952 static void *ecdsa_openssh_createkey(const unsigned char **blob, int *len)
2953 {
2954     const char **b = (const char **) blob;
2955     const char *p;
2956     int slen;
2957     struct ec_key *ec;
2958     struct ec_curve *curve;
2959     struct ec_point *publicKey;
2960
2961     getstring(b, len, &p, &slen);
2962
2963     if (!p) {
2964         return NULL;
2965     }
2966     curve = ec_name_to_curve(p, slen);
2967     if (!curve) return NULL;
2968
2969     if (curve->type != EC_WEIERSTRASS) {
2970         return NULL;
2971     }
2972
2973     ec = snew(struct ec_key);
2974
2975     ec->publicKey.curve = curve;
2976     ec->publicKey.infinity = 0;
2977     ec->publicKey.x = NULL;
2978     ec->publicKey.y = NULL;
2979     ec->publicKey.z = NULL;
2980     if (!getmppoint(b, len, &ec->publicKey)) {
2981         ecdsa_freekey(ec);
2982         return NULL;
2983     }
2984     ec->privateKey = NULL;
2985
2986     if (!ec->publicKey.x || !ec->publicKey.y ||
2987         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2988         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2989     {
2990         ecdsa_freekey(ec);
2991         return NULL;
2992     }
2993
2994     ec->privateKey = getmp(b, len);
2995     if (ec->privateKey == NULL)
2996     {
2997         ecdsa_freekey(ec);
2998         return NULL;
2999     }
3000
3001     /* Now check that the private key makes the public key */
3002     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
3003     if (!publicKey)
3004     {
3005         ecdsa_freekey(ec);
3006         return NULL;
3007     }
3008
3009     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
3010         bignum_cmp(ec->publicKey.y, publicKey->y))
3011     {
3012         /* Private key doesn't make the public key on the given curve */
3013         ecdsa_freekey(ec);
3014         ec_point_free(publicKey);
3015         return NULL;
3016     }
3017
3018     ec_point_free(publicKey);
3019
3020     return ec;
3021 }
3022
3023 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
3024 {
3025     struct ec_key *ec = (struct ec_key *) key;
3026
3027     int pointlen;
3028     int namelen;
3029     int bloblen;
3030     int i;
3031
3032     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
3033         return 0;
3034     }
3035
3036     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3037     namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
3038     bloblen =
3039         4 + namelen /* <LEN> nistpXXX */
3040         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
3041         + ssh2_bignum_length(ec->privateKey);
3042
3043     if (bloblen > len)
3044         return bloblen;
3045
3046     bloblen = 0;
3047
3048     PUT_32BIT(blob+bloblen, namelen);
3049     bloblen += 4;
3050
3051     bloblen += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, blob+bloblen, namelen);
3052
3053     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
3054     bloblen += 4;
3055     blob[bloblen++] = 0x04;
3056     for (i = pointlen; i--; )
3057         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
3058     for (i = pointlen; i--; )
3059         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
3060
3061     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
3062     PUT_32BIT(blob+bloblen, pointlen);
3063     bloblen += 4;
3064     for (i = pointlen; i--; )
3065         blob[bloblen++] = bignum_byte(ec->privateKey, i);
3066
3067     return bloblen;
3068 }
3069
3070 static int ecdsa_pubkey_bits(const void *blob, int len)
3071 {
3072     struct ec_key *ec;
3073     int ret;
3074
3075     ec = (struct ec_key*)ecdsa_newkey((const char *) blob, len);
3076     if (!ec)
3077         return -1;
3078     ret = ec->publicKey.curve->fieldBits;
3079     ecdsa_freekey(ec);
3080
3081     return ret;
3082 }
3083
3084 static char *ecdsa_fingerprint(void *key)
3085 {
3086     struct ec_key *ec = (struct ec_key *) key;
3087     struct MD5Context md5c;
3088     unsigned char digest[16], lenbuf[4];
3089     char *ret;
3090     unsigned char *name, *fullname;
3091     int pointlen, namelen, fullnamelen, i, j;
3092
3093     MD5Init(&md5c);
3094
3095     namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3096     name = snewn(namelen, unsigned char);
3097     ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, name, namelen);
3098
3099     if (ec->publicKey.curve->type == EC_EDWARDS) {
3100         unsigned char b;
3101
3102         /* Do it with the weird encoding */
3103         PUT_32BIT(lenbuf, namelen);
3104         MD5Update(&md5c, lenbuf, 4);
3105         MD5Update(&md5c, name, namelen);
3106
3107         pointlen = ec->publicKey.curve->fieldBits / 8;
3108         PUT_32BIT(lenbuf, pointlen);
3109         MD5Update(&md5c, lenbuf, 4);
3110         for (i = 0; i < pointlen - 1; ++i) {
3111             b = bignum_byte(ec->publicKey.y, i);
3112             MD5Update(&md5c, &b, 1);
3113         }
3114         /* Unset last bit of y and set first bit of x in its place */
3115         b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3116         b |= bignum_bit(ec->publicKey.x, 0) << 7;
3117         MD5Update(&md5c, &b, 1);
3118     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3119         fullnamelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
3120         fullname = snewn(namelen, unsigned char);
3121         ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, fullname, fullnamelen);
3122
3123         PUT_32BIT(lenbuf, fullnamelen);
3124         MD5Update(&md5c, lenbuf, 4);
3125         MD5Update(&md5c, fullname, fullnamelen);
3126         sfree(fullname);
3127
3128         PUT_32BIT(lenbuf, namelen);
3129         MD5Update(&md5c, lenbuf, 4);
3130         MD5Update(&md5c, name, namelen);
3131
3132         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3133         PUT_32BIT(lenbuf, 1 + (pointlen * 2));
3134         MD5Update(&md5c, lenbuf, 4);
3135         MD5Update(&md5c, (const unsigned char *)"\x04", 1);
3136         for (i = pointlen; i--; ) {
3137             unsigned char c = bignum_byte(ec->publicKey.x, i);
3138             MD5Update(&md5c, &c, 1);
3139         }
3140         for (i = pointlen; i--; ) {
3141             unsigned char c = bignum_byte(ec->publicKey.y, i);
3142             MD5Update(&md5c, &c, 1);
3143         }
3144     } else {
3145         sfree(name);
3146         return NULL;
3147     }
3148
3149     MD5Final(digest, &md5c);
3150
3151     ret = snewn(namelen + 1 + (16 * 3), char);
3152
3153     i = 0;
3154     memcpy(ret, name, namelen);
3155     i += namelen;
3156     sfree(name);
3157     ret[i++] = ' ';
3158     for (j = 0; j < 16; j++) {
3159         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
3160     }
3161
3162     return ret;
3163 }
3164
3165 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
3166                            const char *data, int datalen)
3167 {
3168     struct ec_key *ec = (struct ec_key *) key;
3169     const char *p;
3170     int slen;
3171     int digestLen;
3172     int ret;
3173
3174     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
3175         return 0;
3176
3177     /* Check the signature curve matches the key curve */
3178     getstring(&sig, &siglen, &p, &slen);
3179     if (!p) {
3180         return 0;
3181     }
3182     if (ec->publicKey.curve != ec_name_to_curve(p, slen)) {
3183         return 0;
3184     }
3185
3186     getstring(&sig, &siglen, &p, &slen);
3187     if (ec->publicKey.curve->type == EC_EDWARDS) {
3188         struct ec_point *r;
3189         Bignum s, h;
3190
3191         /* Check that the signature is two times the length of a point */
3192         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
3193             return 0;
3194         }
3195
3196         /* Check it's the 256 bit field so that SHA512 is the correct hash */
3197         if (ec->publicKey.curve->fieldBits != 256) {
3198             return 0;
3199         }
3200
3201         /* Get the signature */
3202         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
3203         if (!r) {
3204             return 0;
3205         }
3206         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
3207             ec_point_free(r);
3208             return 0;
3209         }
3210         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
3211                                  ec->publicKey.curve->fieldBits / 8);
3212         if (!s) {
3213             ec_point_free(r);
3214             return 0;
3215         }
3216
3217         /* Get the hash of the encoded value of R + encoded value of pk + message */
3218         {
3219             int i, pointlen;
3220             unsigned char b;
3221             unsigned char digest[512 / 8];
3222             SHA512_State hs;
3223             SHA512_Init(&hs);
3224
3225             /* Add encoded r (no need to encode it again, it was in the signature) */
3226             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
3227
3228             /* Encode pk and add it */
3229             pointlen = ec->publicKey.curve->fieldBits / 8;
3230             for (i = 0; i < pointlen - 1; ++i) {
3231                 b = bignum_byte(ec->publicKey.y, i);
3232                 SHA512_Bytes(&hs, &b, 1);
3233             }
3234             /* Unset last bit of y and set first bit of x in its place */
3235             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3236             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3237             SHA512_Bytes(&hs, &b, 1);
3238
3239             /* Add the message itself */
3240             SHA512_Bytes(&hs, data, datalen);
3241
3242             /* Get the hash */
3243             SHA512_Final(&hs, digest);
3244
3245             /* Convert to Bignum */
3246             h = bignum_from_bytes_le(digest, sizeof(digest));
3247             if (!h) {
3248                 ec_point_free(r);
3249                 freebn(s);
3250                 return 0;
3251             }
3252         }
3253
3254         /* Verify sB == r + h*publicKey */
3255         {
3256             struct ec_point *lhs, *rhs, *tmp;
3257
3258             /* lhs = sB */
3259             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
3260             freebn(s);
3261             if (!lhs) {
3262                 ec_point_free(r);
3263                 freebn(h);
3264                 return 0;
3265             }
3266
3267             /* rhs = r + h*publicKey */
3268             tmp = ecp_mul(&ec->publicKey, h);
3269             freebn(h);
3270             if (!tmp) {
3271                 ec_point_free(lhs);
3272                 ec_point_free(r);
3273                 return 0;
3274             }
3275             rhs = ecp_add(r, tmp, 0);
3276             ec_point_free(r);
3277             ec_point_free(tmp);
3278             if (!rhs) {
3279                 ec_point_free(lhs);
3280                 return 0;
3281             }
3282
3283             /* Check the point is the same */
3284             ret = !bignum_cmp(lhs->x, rhs->x);
3285             if (ret) {
3286                 ret = !bignum_cmp(lhs->y, rhs->y);
3287                 if (ret) {
3288                     ret = 1;
3289                 }
3290             }
3291             ec_point_free(lhs);
3292             ec_point_free(rhs);
3293         }
3294     } else {
3295         Bignum r, s;
3296         unsigned char digest[512 / 8];
3297
3298         r = getmp(&p, &slen);
3299         if (!r) return 0;
3300         s = getmp(&p, &slen);
3301         if (!s) {
3302             freebn(r);
3303             return 0;
3304         }
3305
3306         /* Perform correct hash function depending on curve size */
3307         if (ec->publicKey.curve->fieldBits <= 256) {
3308             SHA256_Simple(data, datalen, digest);
3309             digestLen = 256 / 8;
3310         } else if (ec->publicKey.curve->fieldBits <= 384) {
3311             SHA384_Simple(data, datalen, digest);
3312             digestLen = 384 / 8;
3313         } else {
3314             SHA512_Simple(data, datalen, digest);
3315             digestLen = 512 / 8;
3316         }
3317
3318         /* Verify the signature */
3319         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
3320
3321         freebn(r);
3322         freebn(s);
3323     }
3324
3325     return ret;
3326 }
3327
3328 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
3329                                  int *siglen)
3330 {
3331     struct ec_key *ec = (struct ec_key *) key;
3332     unsigned char digest[512 / 8];
3333     int digestLen;
3334     Bignum r = NULL, s = NULL;
3335     unsigned char *buf, *p;
3336     int rlen, slen, namelen;
3337     int i;
3338
3339     if (!ec->privateKey || !ec->publicKey.curve) {
3340         return NULL;
3341     }
3342
3343     if (ec->publicKey.curve->type == EC_EDWARDS) {
3344         struct ec_point *rp;
3345         int pointlen = ec->publicKey.curve->fieldBits / 8;
3346
3347         /* hash = H(sk) (where hash creates 2 * fieldBits)
3348          * b = fieldBits
3349          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
3350          * r = H(h[b/8:b/4] + m)
3351          * R = rB
3352          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
3353         {
3354             unsigned char hash[512/8];
3355             unsigned char b;
3356             Bignum a;
3357             SHA512_State hs;
3358             SHA512_Init(&hs);
3359
3360             for (i = 0; i < pointlen; ++i) {
3361                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
3362                 SHA512_Bytes(&hs, &b, 1);
3363             }
3364
3365             SHA512_Final(&hs, hash);
3366
3367             /* The second part is simply turning the hash into a Bignum, however
3368              * the 2^(b-2) bit *must* be set, and the bottom 2 bits *must* not be */
3369             hash[0] &= 0xfc; /* Unset bottom two bits (if set) */
3370             hash[31] &= 0x7f; /* Unset above (b-2) */
3371             hash[31] |= 0x40; /* Set 2^(b-2) */
3372             /* Chop off the top part and convert to int */
3373             a = bignum_from_bytes_le(hash, 32);
3374             if (!a) {
3375                 return NULL;
3376             }
3377
3378             SHA512_Init(&hs);
3379             SHA512_Bytes(&hs,
3380                          hash+(ec->publicKey.curve->fieldBits / 8),
3381                          (ec->publicKey.curve->fieldBits / 4)
3382                          - (ec->publicKey.curve->fieldBits / 8));
3383             SHA512_Bytes(&hs, data, datalen);
3384             SHA512_Final(&hs, hash);
3385
3386             r = bignum_from_bytes_le(hash, 512/8);
3387             if (!r) {
3388                 freebn(a);
3389                 return NULL;
3390             }
3391             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
3392             if (!rp) {
3393                 freebn(r);
3394                 freebn(a);
3395                 return NULL;
3396             }
3397
3398             /* Now calculate s */
3399             SHA512_Init(&hs);
3400             /* Encode the point R */
3401             for (i = 0; i < pointlen - 1; ++i) {
3402                 b = bignum_byte(rp->y, i);
3403                 SHA512_Bytes(&hs, &b, 1);
3404             }
3405             /* Unset last bit of y and set first bit of x in its place */
3406             b = bignum_byte(rp->y, i) & 0x7f;
3407             b |= bignum_bit(rp->x, 0) << 7;
3408             SHA512_Bytes(&hs, &b, 1);
3409
3410             /* Encode the point pk */
3411             for (i = 0; i < pointlen - 1; ++i) {
3412                 b = bignum_byte(ec->publicKey.y, i);
3413                 SHA512_Bytes(&hs, &b, 1);
3414             }
3415             /* Unset last bit of y and set first bit of x in its place */
3416             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
3417             b |= bignum_bit(ec->publicKey.x, 0) << 7;
3418             SHA512_Bytes(&hs, &b, 1);
3419
3420             /* Add the message */
3421             SHA512_Bytes(&hs, data, datalen);
3422             SHA512_Final(&hs, hash);
3423
3424             {
3425                 Bignum tmp, tmp2;
3426
3427                 tmp = bignum_from_bytes_le(hash, 512/8);
3428                 if (!tmp) {
3429                     ec_point_free(rp);
3430                     freebn(r);
3431                     freebn(a);
3432                     return NULL;
3433                 }
3434                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
3435                 freebn(a);
3436                 freebn(tmp);
3437                 if (!tmp2) {
3438                     ec_point_free(rp);
3439                     freebn(r);
3440                     return NULL;
3441                 }
3442                 tmp = bigadd(r, tmp2);
3443                 freebn(r);
3444                 freebn(tmp2);
3445                 if (!tmp) {
3446                     ec_point_free(rp);
3447                     return NULL;
3448                 }
3449                 s = bigmod(tmp, ec->publicKey.curve->e.l);
3450                 freebn(tmp);
3451                 if (!s) {
3452                     ec_point_free(rp);
3453                     return NULL;
3454                 }
3455             }
3456         }
3457
3458         /* Format the output */
3459         namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3460         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
3461         buf = snewn(*siglen, unsigned char);
3462         p = buf;
3463         PUT_32BIT(p, namelen);
3464         p += 4;
3465         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
3466         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
3467         p += 4;
3468
3469         /* Encode the point */
3470         pointlen = ec->publicKey.curve->fieldBits / 8;
3471         for (i = 0; i < pointlen - 1; ++i) {
3472             *p++ = bignum_byte(rp->y, i);
3473         }
3474         /* Unset last bit of y and set first bit of x in its place */
3475         *p = bignum_byte(rp->y, i) & 0x7f;
3476         *p++ |= bignum_bit(rp->x, 0) << 7;
3477         ec_point_free(rp);
3478
3479         /* Encode the int */
3480         for (i = 0; i < pointlen; ++i) {
3481             *p++ = bignum_byte(s, i);
3482         }
3483         freebn(s);
3484     } else {
3485         /* Perform correct hash function depending on curve size */
3486         if (ec->publicKey.curve->fieldBits <= 256) {
3487             SHA256_Simple(data, datalen, digest);
3488             digestLen = 256 / 8;
3489         } else if (ec->publicKey.curve->fieldBits <= 384) {
3490             SHA384_Simple(data, datalen, digest);
3491             digestLen = 384 / 8;
3492         } else {
3493             SHA512_Simple(data, datalen, digest);
3494             digestLen = 512 / 8;
3495         }
3496
3497         /* Do the signature */
3498         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
3499         if (!r || !s) {
3500             if (r) freebn(r);
3501             if (s) freebn(s);
3502             return NULL;
3503         }
3504
3505         rlen = (bignum_bitcount(r) + 8) / 8;
3506         slen = (bignum_bitcount(s) + 8) / 8;
3507
3508         namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
3509
3510         /* Format the output */
3511         *siglen = 8+namelen+rlen+slen+8;
3512         buf = snewn(*siglen, unsigned char);
3513         p = buf;
3514         PUT_32BIT(p, namelen);
3515         p += 4;
3516         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
3517         PUT_32BIT(p, rlen + slen + 8);
3518         p += 4;
3519         PUT_32BIT(p, rlen);
3520         p += 4;
3521         for (i = rlen; i--;)
3522             *p++ = bignum_byte(r, i);
3523         PUT_32BIT(p, slen);
3524         p += 4;
3525         for (i = slen; i--;)
3526             *p++ = bignum_byte(s, i);
3527
3528         freebn(r);
3529         freebn(s);
3530     }
3531
3532     return buf;
3533 }
3534
3535 const struct ssh_signkey ssh_ecdsa_ed25519 = {
3536     ecdsa_newkey,
3537     ecdsa_freekey,
3538     ecdsa_fmtkey,
3539     ecdsa_public_blob,
3540     ecdsa_private_blob,
3541     ecdsa_createkey,
3542     ed25519_openssh_createkey,
3543     ed25519_openssh_fmtkey,
3544     2 /* point, private exponent */,
3545     ecdsa_pubkey_bits,
3546     ecdsa_fingerprint,
3547     ecdsa_verifysig,
3548     ecdsa_sign,
3549     "ssh-ed25519",
3550     "ssh-ed25519",
3551 };
3552
3553 const struct ssh_signkey ssh_ecdsa_nistp256 = {
3554     ecdsa_newkey,
3555     ecdsa_freekey,
3556     ecdsa_fmtkey,
3557     ecdsa_public_blob,
3558     ecdsa_private_blob,
3559     ecdsa_createkey,
3560     ecdsa_openssh_createkey,
3561     ecdsa_openssh_fmtkey,
3562     3 /* curve name, point, private exponent */,
3563     ecdsa_pubkey_bits,
3564     ecdsa_fingerprint,
3565     ecdsa_verifysig,
3566     ecdsa_sign,
3567     "ecdsa-sha2-nistp256",
3568     "ecdsa-sha2-nistp256",
3569 };
3570
3571 const struct ssh_signkey ssh_ecdsa_nistp384 = {
3572     ecdsa_newkey,
3573     ecdsa_freekey,
3574     ecdsa_fmtkey,
3575     ecdsa_public_blob,
3576     ecdsa_private_blob,
3577     ecdsa_createkey,
3578     ecdsa_openssh_createkey,
3579     ecdsa_openssh_fmtkey,
3580     3 /* curve name, point, private exponent */,
3581     ecdsa_pubkey_bits,
3582     ecdsa_fingerprint,
3583     ecdsa_verifysig,
3584     ecdsa_sign,
3585     "ecdsa-sha2-nistp384",
3586     "ecdsa-sha2-nistp384",
3587 };
3588
3589 const struct ssh_signkey ssh_ecdsa_nistp521 = {
3590     ecdsa_newkey,
3591     ecdsa_freekey,
3592     ecdsa_fmtkey,
3593     ecdsa_public_blob,
3594     ecdsa_private_blob,
3595     ecdsa_createkey,
3596     ecdsa_openssh_createkey,
3597     ecdsa_openssh_fmtkey,
3598     3 /* curve name, point, private exponent */,
3599     ecdsa_pubkey_bits,
3600     ecdsa_fingerprint,
3601     ecdsa_verifysig,
3602     ecdsa_sign,
3603     "ecdsa-sha2-nistp521",
3604     "ecdsa-sha2-nistp521",
3605 };
3606
3607 /* ----------------------------------------------------------------------
3608  * Exposed ECDH interface
3609  */
3610
3611 static Bignum ecdh_calculate(const Bignum private,
3612                              const struct ec_point *public)
3613 {
3614     struct ec_point *p;
3615     Bignum ret;
3616     p = ecp_mul(public, private);
3617     if (!p) return NULL;
3618     ret = p->x;
3619     p->x = NULL;
3620
3621     if (p->curve->type == EC_MONTGOMERY) {
3622         /* Do conversion in network byte order */
3623         int i;
3624         int bytes = (bignum_bitcount(ret)+7) / 8;
3625         unsigned char *byteorder = snewn(bytes, unsigned char);
3626         if (!byteorder) {
3627             ec_point_free(p);
3628             freebn(ret);
3629             return NULL;
3630         }
3631         for (i = 0; i < bytes; ++i) {
3632             byteorder[i] = bignum_byte(ret, i);
3633         }
3634         freebn(ret);
3635         ret = bignum_from_bytes(byteorder, bytes);
3636         sfree(byteorder);
3637     }
3638
3639     ec_point_free(p);
3640     return ret;
3641 }
3642
3643 void *ssh_ecdhkex_newkey(const char *name)
3644 {
3645     struct ec_curve *curve;
3646     struct ec_key *key;
3647     struct ec_point *publicKey;
3648
3649     curve = ec_name_to_curve(name, strlen(name));
3650
3651     key = snew(struct ec_key);
3652     if (!key) {
3653         return NULL;
3654     }
3655
3656     key->publicKey.curve = curve;
3657
3658     if (curve->type == EC_MONTGOMERY) {
3659         unsigned char bytes[32] = {0};
3660         int i;
3661
3662         for (i = 0; i < sizeof(bytes); ++i)
3663         {
3664             bytes[i] = (unsigned char)random_byte();
3665         }
3666         bytes[0] &= 248;
3667         bytes[31] &= 127;
3668         bytes[31] |= 64;
3669         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
3670         for (i = 0; i < sizeof(bytes); ++i)
3671         {
3672             ((volatile char*)bytes)[i] = 0;
3673         }
3674         if (!key->privateKey) {
3675             sfree(key);
3676             return NULL;
3677         }
3678         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
3679         if (!publicKey) {
3680             freebn(key->privateKey);
3681             sfree(key);
3682             return NULL;
3683         }
3684         key->publicKey.x = publicKey->x;
3685         key->publicKey.y = publicKey->y;
3686         key->publicKey.z = NULL;
3687         sfree(publicKey);
3688     } else {
3689         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
3690         if (!key->privateKey) {
3691             sfree(key);
3692             return NULL;
3693         }
3694         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
3695         if (!publicKey) {
3696             freebn(key->privateKey);
3697             sfree(key);
3698             return NULL;
3699         }
3700         key->publicKey.x = publicKey->x;
3701         key->publicKey.y = publicKey->y;
3702         key->publicKey.z = NULL;
3703         sfree(publicKey);
3704     }
3705     return key;
3706 }
3707
3708 char *ssh_ecdhkex_getpublic(void *key, int *len)
3709 {
3710     struct ec_key *ec = (struct ec_key*)key;
3711     char *point, *p;
3712     int i;
3713     int pointlen;
3714
3715     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
3716
3717     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3718         *len = 1 + pointlen * 2;
3719     } else {
3720         *len = pointlen;
3721     }
3722     point = (char*)snewn(*len, char);
3723     if (!point) {
3724         return NULL;
3725     }
3726
3727     p = point;
3728     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3729         *p++ = 0x04;
3730         for (i = pointlen; i--;) {
3731             *p++ = bignum_byte(ec->publicKey.x, i);
3732         }
3733         for (i = pointlen; i--;) {
3734             *p++ = bignum_byte(ec->publicKey.y, i);
3735         }
3736     } else {
3737         for (i = 0; i < pointlen; ++i) {
3738             *p++ = bignum_byte(ec->publicKey.x, i);
3739         }
3740     }
3741
3742     return point;
3743 }
3744
3745 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
3746 {
3747     struct ec_key *ec = (struct ec_key*) key;
3748     struct ec_point remote;
3749     Bignum ret;
3750
3751     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
3752         remote.curve = ec->publicKey.curve;
3753         remote.infinity = 0;
3754         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
3755             return NULL;
3756         }
3757     } else {
3758         /* Point length has to be the same length */
3759         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
3760             return NULL;
3761         }
3762
3763         remote.curve = ec->publicKey.curve;
3764         remote.infinity = 0;
3765         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
3766         remote.y = NULL;
3767         remote.z = NULL;
3768     }
3769
3770     ret = ecdh_calculate(ec->privateKey, &remote);
3771     if (remote.x) freebn(remote.x);
3772     if (remote.y) freebn(remote.y);
3773     return ret;
3774 }
3775
3776 void ssh_ecdhkex_freekey(void *key)
3777 {
3778     ecdsa_freekey(key);
3779 }
3780
3781 static const struct ssh_kex ssh_ec_kex_curve25519 = {
3782     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
3783 };
3784
3785 static const struct ssh_kex ssh_ec_kex_nistp256 = {
3786     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
3787 };
3788
3789 static const struct ssh_kex ssh_ec_kex_nistp384 = {
3790     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
3791 };
3792
3793 static const struct ssh_kex ssh_ec_kex_nistp521 = {
3794     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
3795 };
3796
3797 static const struct ssh_kex *const ec_kex_list[] = {
3798     &ssh_ec_kex_curve25519,
3799     &ssh_ec_kex_nistp256,
3800     &ssh_ec_kex_nistp384,
3801     &ssh_ec_kex_nistp521
3802 };
3803
3804 const struct ssh_kexes ssh_ecdh_kex = {
3805     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
3806     ec_kex_list
3807 };