]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Add missing consts in elliptic curve setup code.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static void initialise_wcurve(struct ec_curve *curve, int bits,
41                               const unsigned char *p,
42                               const unsigned char *a, const unsigned char *b,
43                               const unsigned char *n, const unsigned char *Gx,
44                               const unsigned char *Gy)
45 {
46     int length = bits / 8;
47     if (bits % 8) ++length;
48
49     curve->type = EC_WEIERSTRASS;
50
51     curve->fieldBits = bits;
52     curve->p = bignum_from_bytes(p, length);
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     curve->w.b = bignum_from_bytes(b, length);
57
58     /* Group order and generator */
59     curve->w.n = bignum_from_bytes(n, length);
60     curve->w.G.x = bignum_from_bytes(Gx, length);
61     curve->w.G.y = bignum_from_bytes(Gy, length);
62     curve->w.G.curve = curve;
63     curve->w.G.infinity = 0;
64 }
65
66 static void initialise_mcurve(struct ec_curve *curve, int bits,
67                               const unsigned char *p,
68                               const unsigned char *a, const unsigned char *b,
69                               const unsigned char *Gx)
70 {
71     int length = bits / 8;
72     if (bits % 8) ++length;
73
74     curve->type = EC_MONTGOMERY;
75
76     curve->fieldBits = bits;
77     curve->p = bignum_from_bytes(p, length);
78
79     /* Curve co-efficients */
80     curve->m.a = bignum_from_bytes(a, length);
81     curve->m.b = bignum_from_bytes(b, length);
82
83     /* Generator */
84     curve->m.G.x = bignum_from_bytes(Gx, length);
85     curve->m.G.y = NULL;
86     curve->m.G.z = NULL;
87     curve->m.G.curve = curve;
88     curve->m.G.infinity = 0;
89 }
90
91 static void initialise_ecurve(struct ec_curve *curve, int bits,
92                               const unsigned char *p,
93                               const unsigned char *l, const unsigned char *d,
94                               const unsigned char *Bx, const unsigned char *By)
95 {
96     int length = bits / 8;
97     if (bits % 8) ++length;
98
99     curve->type = EC_EDWARDS;
100
101     curve->fieldBits = bits;
102     curve->p = bignum_from_bytes(p, length);
103
104     /* Curve co-efficients */
105     curve->e.l = bignum_from_bytes(l, length);
106     curve->e.d = bignum_from_bytes(d, length);
107
108     /* Group order and generator */
109     curve->e.B.x = bignum_from_bytes(Bx, length);
110     curve->e.B.y = bignum_from_bytes(By, length);
111     curve->e.B.curve = curve;
112     curve->e.B.infinity = 0;
113 }
114
115 static struct ec_curve *ec_p256(void)
116 {
117     static struct ec_curve curve = { 0 };
118     static unsigned char initialised = 0;
119
120     if (!initialised)
121     {
122         static const unsigned char p[] = {
123             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
124             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
125             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
126             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
127         };
128         static const unsigned char a[] = {
129             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
130             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
131             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
132             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
133         };
134         static const unsigned char b[] = {
135             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
136             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
137             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
138             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
139         };
140         static const unsigned char n[] = {
141             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
142             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
143             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
144             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
145         };
146         static const unsigned char Gx[] = {
147             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
148             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
149             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
150             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
151         };
152         static const unsigned char Gy[] = {
153             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
154             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
155             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
156             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
157         };
158
159         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
160         curve.name = "nistp256";
161
162         /* Now initialised, no need to do it again */
163         initialised = 1;
164     }
165
166     return &curve;
167 }
168
169 static struct ec_curve *ec_p384(void)
170 {
171     static struct ec_curve curve = { 0 };
172     static unsigned char initialised = 0;
173
174     if (!initialised)
175     {
176         static const unsigned char p[] = {
177             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
178             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
179             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
180             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
181             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
182             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
183         };
184         static const unsigned char a[] = {
185             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
186             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
187             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
188             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
189             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
190             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
191         };
192         static const unsigned char b[] = {
193             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
194             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
195             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
196             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
197             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
198             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
199         };
200         static const unsigned char n[] = {
201             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
202             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
203             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
204             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
205             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
206             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
207         };
208         static const unsigned char Gx[] = {
209             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
210             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
211             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
212             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
213             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
214             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
215         };
216         static const unsigned char Gy[] = {
217             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
218             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
219             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
220             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
221             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
222             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
223         };
224
225         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
226         curve.name = "nistp384";
227
228         /* Now initialised, no need to do it again */
229         initialised = 1;
230     }
231
232     return &curve;
233 }
234
235 static struct ec_curve *ec_p521(void)
236 {
237     static struct ec_curve curve = { 0 };
238     static unsigned char initialised = 0;
239
240     if (!initialised)
241     {
242         static const unsigned char p[] = {
243             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
244             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
245             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
246             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
247             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff
252         };
253         static const unsigned char a[] = {
254             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
256             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
257             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
258             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xfc
263         };
264         static const unsigned char b[] = {
265             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
266             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
267             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
268             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
269             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
270             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
271             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
272             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
273             0x3f, 0x00
274         };
275         static const unsigned char n[] = {
276             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
277             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
278             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
279             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
280             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
281             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
282             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
283             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
284             0x64, 0x09
285         };
286         static const unsigned char Gx[] = {
287             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
288             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
289             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
290             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
291             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
292             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
293             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
294             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
295             0xbd, 0x66
296         };
297         static const unsigned char Gy[] = {
298             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
299             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
300             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
301             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
302             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
303             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
304             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
305             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
306             0x66, 0x50
307         };
308
309         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
310         curve.name = "nistp521";
311
312         /* Now initialised, no need to do it again */
313         initialised = 1;
314     }
315
316     return &curve;
317 }
318
319 static struct ec_curve *ec_curve25519(void)
320 {
321     static struct ec_curve curve = { 0 };
322     static unsigned char initialised = 0;
323
324     if (!initialised)
325     {
326         static const unsigned char p[] = {
327             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
328             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
329             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
330             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
331         };
332         static const unsigned char a[] = {
333             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
334             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
335             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
336             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
337         };
338         static const unsigned char b[] = {
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
341             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
342             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
343         };
344         static const unsigned char gx[32] = {
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
347             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
348             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
349         };
350
351         initialise_mcurve(&curve, 256, p, a, b, gx);
352         /* This curve doesn't need a name, because it's never used in
353          * any format that embeds the curve name */
354         curve.name = NULL;
355
356         /* Now initialised, no need to do it again */
357         initialised = 1;
358     }
359
360     return &curve;
361 }
362
363 static struct ec_curve *ec_ed25519(void)
364 {
365     static struct ec_curve curve = { 0 };
366     static unsigned char initialised = 0;
367
368     if (!initialised)
369     {
370         static const unsigned char q[] = {
371             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
372             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
373             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
374             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
375         };
376         static const unsigned char l[32] = {
377             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
378             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
379             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
380             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
381         };
382         static const unsigned char d[32] = {
383             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
384             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
385             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
386             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
387         };
388         static const unsigned char Bx[32] = {
389             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
390             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
391             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
392             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
393         };
394         static const unsigned char By[32] = {
395             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
396             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
397             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
398             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
399         };
400
401         /* This curve doesn't need a name, because it's never used in
402          * any format that embeds the curve name */
403         curve.name = NULL;
404
405         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
406
407         /* Now initialised, no need to do it again */
408         initialised = 1;
409     }
410
411     return &curve;
412 }
413
414 /* Return 1 if a is -3 % p, otherwise return 0
415  * This is used because there are some maths optimisations */
416 static int ec_aminus3(const struct ec_curve *curve)
417 {
418     int ret;
419     Bignum _p;
420
421     if (curve->type != EC_WEIERSTRASS) {
422         return 0;
423     }
424
425     _p = bignum_add_long(curve->w.a, 3);
426
427     ret = !bignum_cmp(curve->p, _p);
428     freebn(_p);
429     return ret;
430 }
431
432 /* ----------------------------------------------------------------------
433  * Elliptic curve field maths
434  */
435
436 static Bignum ecf_add(const Bignum a, const Bignum b,
437                       const struct ec_curve *curve)
438 {
439     Bignum a1, b1, ab, ret;
440
441     a1 = bigmod(a, curve->p);
442     b1 = bigmod(b, curve->p);
443
444     ab = bigadd(a1, b1);
445     freebn(a1);
446     freebn(b1);
447
448     ret = bigmod(ab, curve->p);
449     freebn(ab);
450
451     return ret;
452 }
453
454 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
455 {
456     return modmul(a, a, curve->p);
457 }
458
459 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
460 {
461     Bignum ret, tmp;
462
463     /* Double */
464     tmp = bignum_lshift(a, 1);
465
466     /* Add itself (i.e. treble) */
467     ret = bigadd(tmp, a);
468     freebn(tmp);
469
470     /* Normalise */
471     while (bignum_cmp(ret, curve->p) >= 0)
472     {
473         tmp = bigsub(ret, curve->p);
474         assert(tmp);
475         freebn(ret);
476         ret = tmp;
477     }
478
479     return ret;
480 }
481
482 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
483 {
484     Bignum ret = bignum_lshift(a, 1);
485     if (bignum_cmp(ret, curve->p) >= 0)
486     {
487         Bignum tmp = bigsub(ret, curve->p);
488         assert(tmp);
489         freebn(ret);
490         return tmp;
491     }
492     else
493     {
494         return ret;
495     }
496 }
497
498 /* ----------------------------------------------------------------------
499  * Memory functions
500  */
501
502 void ec_point_free(struct ec_point *point)
503 {
504     if (point == NULL) return;
505     point->curve = 0;
506     if (point->x) freebn(point->x);
507     if (point->y) freebn(point->y);
508     if (point->z) freebn(point->z);
509     point->infinity = 0;
510     sfree(point);
511 }
512
513 static struct ec_point *ec_point_new(const struct ec_curve *curve,
514                                      const Bignum x, const Bignum y, const Bignum z,
515                                      unsigned char infinity)
516 {
517     struct ec_point *point = snewn(1, struct ec_point);
518     point->curve = curve;
519     point->x = x;
520     point->y = y;
521     point->z = z;
522     point->infinity = infinity ? 1 : 0;
523     return point;
524 }
525
526 static struct ec_point *ec_point_copy(const struct ec_point *a)
527 {
528     if (a == NULL) return NULL;
529     return ec_point_new(a->curve,
530                         a->x ? copybn(a->x) : NULL,
531                         a->y ? copybn(a->y) : NULL,
532                         a->z ? copybn(a->z) : NULL,
533                         a->infinity);
534 }
535
536 static int ec_point_verify(const struct ec_point *a)
537 {
538     if (a->infinity) {
539         return 1;
540     } else if (a->curve->type == EC_EDWARDS) {
541         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
542         Bignum y2, x2, tmp, tmp2, tmp3;
543         int ret;
544
545         y2 = ecf_square(a->y, a->curve);
546         x2 = ecf_square(a->x, a->curve);
547         tmp = modmul(a->curve->e.d, x2, a->curve->p);
548         tmp2 = modmul(tmp, y2, a->curve->p);
549         freebn(tmp);
550         tmp = modsub(y2, x2, a->curve->p);
551         freebn(y2);
552         freebn(x2);
553         tmp3 = modsub(tmp, tmp2, a->curve->p);
554         freebn(tmp);
555         freebn(tmp2);
556         ret = !bignum_cmp(tmp3, One);
557         freebn(tmp3);
558         return ret;
559     } else if (a->curve->type == EC_WEIERSTRASS) {
560         /* Verify y^2 = x^3 + ax + b */
561         int ret = 0;
562
563         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
564
565         Bignum Three = bignum_from_long(3);
566
567         lhs = modmul(a->y, a->y, a->curve->p);
568
569         /* This uses montgomery multiplication to optimise */
570         x3 = modpow(a->x, Three, a->curve->p);
571         freebn(Three);
572         ax = modmul(a->curve->w.a, a->x, a->curve->p);
573         x3ax = bigadd(x3, ax);
574         freebn(x3); x3 = NULL;
575         freebn(ax); ax = NULL;
576         x3axm = bigmod(x3ax, a->curve->p);
577         freebn(x3ax); x3ax = NULL;
578         x3axb = bigadd(x3axm, a->curve->w.b);
579         freebn(x3axm); x3axm = NULL;
580         rhs = bigmod(x3axb, a->curve->p);
581         freebn(x3axb);
582
583         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
584         freebn(lhs);
585         freebn(rhs);
586
587         return ret;
588     } else {
589         return 0;
590     }
591 }
592
593 /* ----------------------------------------------------------------------
594  * Elliptic curve point maths
595  */
596
597 /* Returns 1 on success and 0 on memory error */
598 static int ecp_normalise(struct ec_point *a)
599 {
600     if (!a) {
601         /* No point */
602         return 0;
603     }
604
605     if (a->infinity) {
606         /* Point is at infinity - i.e. normalised */
607         return 1;
608     }
609
610     if (a->curve->type == EC_WEIERSTRASS) {
611         /* In Jacobian Coordinates the triple (X, Y, Z) represents
612            the affine point (X / Z^2, Y / Z^3) */
613
614         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
615
616         if (!a->x || !a->y) {
617             /* No point defined */
618             return 0;
619         } else if (!a->z) {
620             /* Already normalised */
621             return 1;
622         }
623
624         Z2 = ecf_square(a->z, a->curve);
625         Z2inv = modinv(Z2, a->curve->p);
626         if (!Z2inv) {
627             freebn(Z2);
628             return 0;
629         }
630         tx = modmul(a->x, Z2inv, a->curve->p);
631         freebn(Z2inv);
632
633         Z3 = modmul(Z2, a->z, a->curve->p);
634         freebn(Z2);
635         Z3inv = modinv(Z3, a->curve->p);
636         freebn(Z3);
637         if (!Z3inv) {
638             freebn(tx);
639             return 0;
640         }
641         ty = modmul(a->y, Z3inv, a->curve->p);
642         freebn(Z3inv);
643
644         freebn(a->x);
645         a->x = tx;
646         freebn(a->y);
647         a->y = ty;
648         freebn(a->z);
649         a->z = NULL;
650         return 1;
651     } else if (a->curve->type == EC_MONTGOMERY) {
652         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
653
654         Bignum tmp, tmp2;
655
656         if (!a->x) {
657             /* No point defined */
658             return 0;
659         } else if (!a->z) {
660             /* Already normalised */
661             return 1;
662         }
663
664         tmp = modinv(a->z, a->curve->p);
665         if (!tmp) {
666             return 0;
667         }
668         tmp2 = modmul(a->x, tmp, a->curve->p);
669         freebn(tmp);
670
671         freebn(a->z);
672         a->z = NULL;
673         freebn(a->x);
674         a->x = tmp2;
675         return 1;
676     } else if (a->curve->type == EC_EDWARDS) {
677         /* Always normalised */
678         return 1;
679     } else {
680         return 0;
681     }
682 }
683
684 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
685 {
686     Bignum S, M, outx, outy, outz;
687
688     if (bignum_cmp(a->y, Zero) == 0)
689     {
690         /* Identity */
691         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
692     }
693
694     /* S = 4*X*Y^2 */
695     {
696         Bignum Y2, XY2, _2XY2;
697
698         Y2 = ecf_square(a->y, a->curve);
699         XY2 = modmul(a->x, Y2, a->curve->p);
700         freebn(Y2);
701
702         _2XY2 = ecf_double(XY2, a->curve);
703         freebn(XY2);
704         S = ecf_double(_2XY2, a->curve);
705         freebn(_2XY2);
706     }
707
708     /* Faster calculation if a = -3 */
709     if (aminus3) {
710         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
711         Bignum Z2, XpZ2, XmZ2, second;
712
713         if (a->z == NULL) {
714             Z2 = copybn(One);
715         } else {
716             Z2 = ecf_square(a->z, a->curve);
717         }
718
719         XpZ2 = ecf_add(a->x, Z2, a->curve);
720         XmZ2 = modsub(a->x, Z2, a->curve->p);
721         freebn(Z2);
722
723         second = modmul(XpZ2, XmZ2, a->curve->p);
724         freebn(XpZ2);
725         freebn(XmZ2);
726
727         M = ecf_treble(second, a->curve);
728         freebn(second);
729     } else {
730         /* M = 3*X^2 + a*Z^4 */
731         Bignum _3X2, X2, aZ4;
732
733         if (a->z == NULL) {
734             aZ4 = copybn(a->curve->w.a);
735         } else {
736             Bignum Z2, Z4;
737
738             Z2 = ecf_square(a->z, a->curve);
739             Z4 = ecf_square(Z2, a->curve);
740             freebn(Z2);
741             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
742             freebn(Z4);
743         }
744
745         X2 = modmul(a->x, a->x, a->curve->p);
746         _3X2 = ecf_treble(X2, a->curve);
747         freebn(X2);
748         M = ecf_add(_3X2, aZ4, a->curve);
749         freebn(_3X2);
750         freebn(aZ4);
751     }
752
753     /* X' = M^2 - 2*S */
754     {
755         Bignum M2, _2S;
756
757         M2 = ecf_square(M, a->curve);
758         _2S = ecf_double(S, a->curve);
759         outx = modsub(M2, _2S, a->curve->p);
760         freebn(M2);
761         freebn(_2S);
762     }
763
764     /* Y' = M*(S - X') - 8*Y^4 */
765     {
766         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
767
768         SX = modsub(S, outx, a->curve->p);
769         freebn(S);
770         MSX = modmul(M, SX, a->curve->p);
771         freebn(SX);
772         freebn(M);
773         Y2 = ecf_square(a->y, a->curve);
774         Y4 = ecf_square(Y2, a->curve);
775         freebn(Y2);
776         Eight = bignum_from_long(8);
777         _8Y4 = modmul(Eight, Y4, a->curve->p);
778         freebn(Eight);
779         freebn(Y4);
780         outy = modsub(MSX, _8Y4, a->curve->p);
781         freebn(MSX);
782         freebn(_8Y4);
783     }
784
785     /* Z' = 2*Y*Z */
786     {
787         Bignum YZ;
788
789         if (a->z == NULL) {
790             YZ = copybn(a->y);
791         } else {
792             YZ = modmul(a->y, a->z, a->curve->p);
793         }
794
795         outz = ecf_double(YZ, a->curve);
796         freebn(YZ);
797     }
798
799     return ec_point_new(a->curve, outx, outy, outz, 0);
800 }
801
802 static struct ec_point *ecp_doublem(const struct ec_point *a)
803 {
804     Bignum z, outx, outz, xpz, xmz;
805
806     z = a->z;
807     if (!z) {
808         z = One;
809     }
810
811     /* 4xz = (x + z)^2 - (x - z)^2 */
812     {
813         Bignum tmp;
814
815         tmp = ecf_add(a->x, z, a->curve);
816         xpz = ecf_square(tmp, a->curve);
817         freebn(tmp);
818
819         tmp = modsub(a->x, z, a->curve->p);
820         xmz = ecf_square(tmp, a->curve);
821         freebn(tmp);
822     }
823
824     /* outx = (x + z)^2 * (x - z)^2 */
825     outx = modmul(xpz, xmz, a->curve->p);
826
827     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
828     {
829         Bignum _4xz, tmp, tmp2, tmp3;
830
831         tmp = bignum_from_long(2);
832         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
833         freebn(tmp);
834
835         _4xz = modsub(xpz, xmz, a->curve->p);
836         freebn(xpz);
837         tmp = modmul(tmp2, _4xz, a->curve->p);
838         freebn(tmp2);
839
840         tmp2 = bignum_from_long(4);
841         tmp3 = modinv(tmp2, a->curve->p);
842         freebn(tmp2);
843         if (!tmp3) {
844             freebn(tmp);
845             freebn(_4xz);
846             freebn(outx);
847             freebn(xmz);
848             return NULL;
849         }
850         tmp2 = modmul(tmp, tmp3, a->curve->p);
851         freebn(tmp);
852         freebn(tmp3);
853
854         tmp = ecf_add(xmz, tmp2, a->curve);
855         freebn(xmz);
856         freebn(tmp2);
857         outz = modmul(_4xz, tmp, a->curve->p);
858         freebn(_4xz);
859         freebn(tmp);
860     }
861
862     return ec_point_new(a->curve, outx, NULL, outz, 0);
863 }
864
865 /* Forward declaration for Edwards curve doubling */
866 static struct ec_point *ecp_add(const struct ec_point *a,
867                                 const struct ec_point *b,
868                                 const int aminus3);
869
870 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
871 {
872     if (a->infinity)
873     {
874         /* Identity */
875         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
876     }
877
878     if (a->curve->type == EC_EDWARDS)
879     {
880         return ecp_add(a, a, aminus3);
881     }
882     else if (a->curve->type == EC_WEIERSTRASS)
883     {
884         return ecp_doublew(a, aminus3);
885     }
886     else
887     {
888         return ecp_doublem(a);
889     }
890 }
891
892 static struct ec_point *ecp_addw(const struct ec_point *a,
893                                  const struct ec_point *b,
894                                  const int aminus3)
895 {
896     Bignum U1, U2, S1, S2, outx, outy, outz;
897
898     /* U1 = X1*Z2^2 */
899     /* S1 = Y1*Z2^3 */
900     if (b->z) {
901         Bignum Z2, Z3;
902
903         Z2 = ecf_square(b->z, a->curve);
904         U1 = modmul(a->x, Z2, a->curve->p);
905         Z3 = modmul(Z2, b->z, a->curve->p);
906         freebn(Z2);
907         S1 = modmul(a->y, Z3, a->curve->p);
908         freebn(Z3);
909     } else {
910         U1 = copybn(a->x);
911         S1 = copybn(a->y);
912     }
913
914     /* U2 = X2*Z1^2 */
915     /* S2 = Y2*Z1^3 */
916     if (a->z) {
917         Bignum Z2, Z3;
918
919         Z2 = ecf_square(a->z, b->curve);
920         U2 = modmul(b->x, Z2, b->curve->p);
921         Z3 = modmul(Z2, a->z, b->curve->p);
922         freebn(Z2);
923         S2 = modmul(b->y, Z3, b->curve->p);
924         freebn(Z3);
925     } else {
926         U2 = copybn(b->x);
927         S2 = copybn(b->y);
928     }
929
930     /* Check if multiplying by self */
931     if (bignum_cmp(U1, U2) == 0)
932     {
933         freebn(U1);
934         freebn(U2);
935         if (bignum_cmp(S1, S2) == 0)
936         {
937             freebn(S1);
938             freebn(S2);
939             return ecp_double(a, aminus3);
940         }
941         else
942         {
943             freebn(S1);
944             freebn(S2);
945             /* Infinity */
946             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
947         }
948     }
949
950     {
951         Bignum H, R, UH2, H3;
952
953         /* H = U2 - U1 */
954         H = modsub(U2, U1, a->curve->p);
955         freebn(U2);
956
957         /* R = S2 - S1 */
958         R = modsub(S2, S1, a->curve->p);
959         freebn(S2);
960
961         /* X3 = R^2 - H^3 - 2*U1*H^2 */
962         {
963             Bignum R2, H2, _2UH2, first;
964
965             H2 = ecf_square(H, a->curve);
966             UH2 = modmul(U1, H2, a->curve->p);
967             freebn(U1);
968             H3 = modmul(H2, H, a->curve->p);
969             freebn(H2);
970             R2 = ecf_square(R, a->curve);
971             _2UH2 = ecf_double(UH2, a->curve);
972             first = modsub(R2, H3, a->curve->p);
973             freebn(R2);
974             outx = modsub(first, _2UH2, a->curve->p);
975             freebn(first);
976             freebn(_2UH2);
977         }
978
979         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
980         {
981             Bignum RUH2mX, UH2mX, SH3;
982
983             UH2mX = modsub(UH2, outx, a->curve->p);
984             freebn(UH2);
985             RUH2mX = modmul(R, UH2mX, a->curve->p);
986             freebn(UH2mX);
987             freebn(R);
988             SH3 = modmul(S1, H3, a->curve->p);
989             freebn(S1);
990             freebn(H3);
991
992             outy = modsub(RUH2mX, SH3, a->curve->p);
993             freebn(RUH2mX);
994             freebn(SH3);
995         }
996
997         /* Z3 = H*Z1*Z2 */
998         if (a->z && b->z) {
999             Bignum ZZ;
1000
1001             ZZ = modmul(a->z, b->z, a->curve->p);
1002             outz = modmul(H, ZZ, a->curve->p);
1003             freebn(H);
1004             freebn(ZZ);
1005         } else if (a->z) {
1006             outz = modmul(H, a->z, a->curve->p);
1007             freebn(H);
1008         } else if (b->z) {
1009             outz = modmul(H, b->z, a->curve->p);
1010             freebn(H);
1011         } else {
1012             outz = H;
1013         }
1014     }
1015
1016     return ec_point_new(a->curve, outx, outy, outz, 0);
1017 }
1018
1019 static struct ec_point *ecp_addm(const struct ec_point *a,
1020                                  const struct ec_point *b,
1021                                  const struct ec_point *base)
1022 {
1023     Bignum outx, outz, az, bz;
1024
1025     az = a->z;
1026     if (!az) {
1027         az = One;
1028     }
1029     bz = b->z;
1030     if (!bz) {
1031         bz = One;
1032     }
1033
1034     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1035     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1036     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1037     {
1038         Bignum tmp, tmp2, tmp3, tmp4;
1039
1040         /* (Xa + Za) * (Xb - Zb) */
1041         tmp = ecf_add(a->x, az, a->curve);
1042         tmp2 = modsub(b->x, bz, a->curve->p);
1043         tmp3 = modmul(tmp, tmp2, a->curve->p);
1044         freebn(tmp);
1045         freebn(tmp2);
1046
1047         /* (Xa - Za) * (Xb + Zb) */
1048         tmp = modsub(a->x, az, a->curve->p);
1049         tmp2 = ecf_add(b->x, bz, a->curve);
1050         tmp4 = modmul(tmp, tmp2, a->curve->p);
1051         freebn(tmp);
1052         freebn(tmp2);
1053
1054         tmp = ecf_add(tmp3, tmp4, a->curve);
1055         outx = ecf_square(tmp, a->curve);
1056         freebn(tmp);
1057
1058         tmp = modsub(tmp3, tmp4, a->curve->p);
1059         freebn(tmp3);
1060         freebn(tmp4);
1061         tmp2 = ecf_square(tmp, a->curve);
1062         freebn(tmp);
1063         outz = modmul(base->x, tmp2, a->curve->p);
1064         freebn(tmp2);
1065     }
1066
1067     return ec_point_new(a->curve, outx, NULL, outz, 0);
1068 }
1069
1070 static struct ec_point *ecp_adde(const struct ec_point *a,
1071                                  const struct ec_point *b)
1072 {
1073     Bignum outx, outy, dmul;
1074
1075     /* outx = (a->x * b->y + b->x * a->y) /
1076      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1077     {
1078         Bignum tmp, tmp2, tmp3, tmp4;
1079
1080         tmp = modmul(a->x, b->y, a->curve->p);
1081         tmp2 = modmul(b->x, a->y, a->curve->p);
1082         tmp3 = ecf_add(tmp, tmp2, a->curve);
1083
1084         tmp4 = modmul(tmp, tmp2, a->curve->p);
1085         freebn(tmp);
1086         freebn(tmp2);
1087         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1088         freebn(tmp4);
1089
1090         tmp = ecf_add(One, dmul, a->curve);
1091         tmp2 = modinv(tmp, a->curve->p);
1092         freebn(tmp);
1093         if (!tmp2)
1094         {
1095             freebn(tmp3);
1096             freebn(dmul);
1097             return NULL;
1098         }
1099
1100         outx = modmul(tmp3, tmp2, a->curve->p);
1101         freebn(tmp3);
1102         freebn(tmp2);
1103     }
1104
1105     /* outy = (a->y * b->y + a->x * b->x) /
1106      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1107     {
1108         Bignum tmp, tmp2, tmp3, tmp4;
1109
1110         tmp = modsub(One, dmul, a->curve->p);
1111         freebn(dmul);
1112
1113         tmp2 = modinv(tmp, a->curve->p);
1114         freebn(tmp);
1115         if (!tmp2)
1116         {
1117             freebn(outx);
1118             return NULL;
1119         }
1120
1121         tmp = modmul(a->y, b->y, a->curve->p);
1122         tmp3 = modmul(a->x, b->x, a->curve->p);
1123         tmp4 = ecf_add(tmp, tmp3, a->curve);
1124         freebn(tmp);
1125         freebn(tmp3);
1126
1127         outy = modmul(tmp4, tmp2, a->curve->p);
1128         freebn(tmp4);
1129         freebn(tmp2);
1130     }
1131
1132     return ec_point_new(a->curve, outx, outy, NULL, 0);
1133 }
1134
1135 static struct ec_point *ecp_add(const struct ec_point *a,
1136                                 const struct ec_point *b,
1137                                 const int aminus3)
1138 {
1139     if (a->curve != b->curve) {
1140         return NULL;
1141     }
1142
1143     /* Check if multiplying by infinity */
1144     if (a->infinity) return ec_point_copy(b);
1145     if (b->infinity) return ec_point_copy(a);
1146
1147     if (a->curve->type == EC_EDWARDS)
1148     {
1149         return ecp_adde(a, b);
1150     }
1151
1152     if (a->curve->type == EC_WEIERSTRASS)
1153     {
1154         return ecp_addw(a, b, aminus3);
1155     }
1156
1157     return NULL;
1158 }
1159
1160 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1161 {
1162     struct ec_point *A, *ret;
1163     int bits, i;
1164
1165     A = ec_point_copy(a);
1166     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1167
1168     bits = bignum_bitcount(b);
1169     for (i = 0; i < bits; ++i)
1170     {
1171         if (bignum_bit(b, i))
1172         {
1173             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1174             ec_point_free(ret);
1175             ret = tmp;
1176         }
1177         if (i+1 != bits)
1178         {
1179             struct ec_point *tmp = ecp_double(A, aminus3);
1180             ec_point_free(A);
1181             A = tmp;
1182         }
1183     }
1184
1185     ec_point_free(A);
1186     return ret;
1187 }
1188
1189 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1190 {
1191     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1192
1193     if (!ecp_normalise(ret)) {
1194         ec_point_free(ret);
1195         return NULL;
1196     }
1197
1198     return ret;
1199 }
1200
1201 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1202 {
1203     int i;
1204     struct ec_point *ret;
1205
1206     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1207
1208     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1209     {
1210         {
1211             struct ec_point *tmp = ecp_double(ret, 0);
1212             ec_point_free(ret);
1213             ret = tmp;
1214         }
1215         if (ret && bignum_bit(b, i))
1216         {
1217             struct ec_point *tmp = ecp_add(ret, a, 0);
1218             ec_point_free(ret);
1219             ret = tmp;
1220         }
1221     }
1222
1223     return ret;
1224 }
1225
1226 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1227 {
1228     struct ec_point *P1, *P2;
1229     int bits, i;
1230
1231     /* P1 <- P and P2 <- [2]P */
1232     P2 = ecp_double(p, 0);
1233     P1 = ec_point_copy(p);
1234
1235     /* for i = bits âˆ’ 2 down to 0 */
1236     bits = bignum_bitcount(n);
1237     for (i = bits - 2; i >= 0; --i)
1238     {
1239         if (!bignum_bit(n, i))
1240         {
1241             /* P2 <- P1 + P2 */
1242             struct ec_point *tmp = ecp_addm(P1, P2, p);
1243             ec_point_free(P2);
1244             P2 = tmp;
1245
1246             /* P1 <- [2]P1 */
1247             tmp = ecp_double(P1, 0);
1248             ec_point_free(P1);
1249             P1 = tmp;
1250         }
1251         else
1252         {
1253             /* P1 <- P1 + P2 */
1254             struct ec_point *tmp = ecp_addm(P1, P2, p);
1255             ec_point_free(P1);
1256             P1 = tmp;
1257
1258             /* P2 <- [2]P2 */
1259             tmp = ecp_double(P2, 0);
1260             ec_point_free(P2);
1261             P2 = tmp;
1262         }
1263     }
1264
1265     ec_point_free(P2);
1266
1267     if (!ecp_normalise(P1)) {
1268         ec_point_free(P1);
1269         return NULL;
1270     }
1271
1272     return P1;
1273 }
1274
1275 /* Not static because it is used by sshecdsag.c to generate a new key */
1276 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1277 {
1278     if (a->curve->type == EC_WEIERSTRASS) {
1279         return ecp_mulw(a, b);
1280     } else if (a->curve->type == EC_EDWARDS) {
1281         return ecp_mule(a, b);
1282     } else {
1283         return ecp_mulm(a, b);
1284     }
1285 }
1286
1287 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1288                                    const struct ec_point *point)
1289 {
1290     struct ec_point *aG, *bP, *ret;
1291     int aminus3;
1292
1293     if (point->curve->type != EC_WEIERSTRASS) {
1294         return NULL;
1295     }
1296
1297     aminus3 = ec_aminus3(point->curve);
1298
1299     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1300     if (!aG) return NULL;
1301     bP = ecp_mul_(point, b, aminus3);
1302     if (!bP) {
1303         ec_point_free(aG);
1304         return NULL;
1305     }
1306
1307     ret = ecp_add(aG, bP, aminus3);
1308
1309     ec_point_free(aG);
1310     ec_point_free(bP);
1311
1312     if (!ecp_normalise(ret)) {
1313         ec_point_free(ret);
1314         return NULL;
1315     }
1316
1317     return ret;
1318 }
1319 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1320 {
1321     /* Get the x value on the given Edwards curve for a given y */
1322     Bignum x, xx;
1323
1324     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1325     {
1326         Bignum tmp, tmp2, tmp3;
1327
1328         tmp = ecf_square(y, curve);
1329         tmp2 = modmul(curve->e.d, tmp, curve->p);
1330         tmp3 = ecf_add(tmp2, One, curve);
1331         freebn(tmp2);
1332         tmp2 = modinv(tmp3, curve->p);
1333         freebn(tmp3);
1334         if (!tmp2) {
1335             freebn(tmp);
1336             return NULL;
1337         }
1338
1339         tmp3 = modsub(tmp, One, curve->p);
1340         freebn(tmp);
1341         xx = modmul(tmp3, tmp2, curve->p);
1342         freebn(tmp3);
1343         freebn(tmp2);
1344     }
1345
1346     /* x = xx^((p + 3) / 8) */
1347     {
1348         Bignum tmp, tmp2;
1349
1350         tmp = bignum_add_long(curve->p, 3);
1351         tmp2 = bignum_rshift(tmp, 3);
1352         freebn(tmp);
1353         x = modpow(xx, tmp2, curve->p);
1354         freebn(tmp2);
1355     }
1356
1357     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1358     {
1359         Bignum tmp, tmp2;
1360
1361         tmp = ecf_square(x, curve);
1362         tmp2 = modsub(tmp, xx, curve->p);
1363         freebn(tmp);
1364         freebn(xx);
1365         if (bignum_cmp(tmp2, Zero)) {
1366             Bignum tmp3;
1367
1368             freebn(tmp2);
1369
1370             tmp = modsub(curve->p, One, curve->p);
1371             tmp2 = bignum_rshift(tmp, 2);
1372             freebn(tmp);
1373             tmp = bignum_from_long(2);
1374             tmp3 = modpow(tmp, tmp2, curve->p);
1375             freebn(tmp);
1376             freebn(tmp2);
1377
1378             tmp = modmul(x, tmp3, curve->p);
1379             freebn(x);
1380             freebn(tmp3);
1381             x = tmp;
1382         } else {
1383             freebn(tmp2);
1384         }
1385     }
1386
1387     /* if x % 2 != 0 then x = p - x */
1388     if (bignum_bit(x, 0)) {
1389         Bignum tmp = modsub(curve->p, x, curve->p);
1390         freebn(x);
1391         x = tmp;
1392     }
1393
1394     return x;
1395 }
1396
1397 /* ----------------------------------------------------------------------
1398  * Public point from private
1399  */
1400
1401 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1402 {
1403     if (curve->type == EC_WEIERSTRASS) {
1404         return ecp_mul(&curve->w.G, privateKey);
1405     } else if (curve->type == EC_EDWARDS) {
1406         /* hash = H(sk) (where hash creates 2 * fieldBits)
1407          * b = fieldBits
1408          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1409          * publicKey = aB */
1410         struct ec_point *ret;
1411         unsigned char hash[512/8];
1412         Bignum a;
1413         int i, keylen;
1414         SHA512_State s;
1415         SHA512_Init(&s);
1416
1417         keylen = curve->fieldBits / 8;
1418         for (i = 0; i < keylen; ++i) {
1419             unsigned char b = bignum_byte(privateKey, i);
1420             SHA512_Bytes(&s, &b, 1);
1421         }
1422         SHA512_Final(&s, hash);
1423
1424         /* The second part is simply turning the hash into a Bignum,
1425          * however the 2^(b-2) bit *must* be set, and the bottom 3
1426          * bits *must* not be */
1427         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1428         hash[31] &= 0x7f; /* Unset above (b-2) */
1429         hash[31] |= 0x40; /* Set 2^(b-2) */
1430         /* Chop off the top part and convert to int */
1431         a = bignum_from_bytes_le(hash, 32);
1432
1433         ret = ecp_mul(&curve->e.B, a);
1434         freebn(a);
1435         return ret;
1436     } else {
1437         return NULL;
1438     }
1439 }
1440
1441 /* ----------------------------------------------------------------------
1442  * Basic sign and verify routines
1443  */
1444
1445 static int _ecdsa_verify(const struct ec_point *publicKey,
1446                          const unsigned char *data, const int dataLen,
1447                          const Bignum r, const Bignum s)
1448 {
1449     int z_bits, n_bits;
1450     Bignum z;
1451     int valid = 0;
1452
1453     if (publicKey->curve->type != EC_WEIERSTRASS) {
1454         return 0;
1455     }
1456
1457     /* Sanity checks */
1458     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1459         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1460     {
1461         return 0;
1462     }
1463
1464     /* z = left most bitlen(curve->n) of data */
1465     z = bignum_from_bytes(data, dataLen);
1466     n_bits = bignum_bitcount(publicKey->curve->w.n);
1467     z_bits = bignum_bitcount(z);
1468     if (z_bits > n_bits)
1469     {
1470         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1471         freebn(z);
1472         z = tmp;
1473     }
1474
1475     /* Ensure z in range of n */
1476     {
1477         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1478         freebn(z);
1479         z = tmp;
1480     }
1481
1482     /* Calculate signature */
1483     {
1484         Bignum w, x, u1, u2;
1485         struct ec_point *tmp;
1486
1487         w = modinv(s, publicKey->curve->w.n);
1488         if (!w) {
1489             freebn(z);
1490             return 0;
1491         }
1492         u1 = modmul(z, w, publicKey->curve->w.n);
1493         u2 = modmul(r, w, publicKey->curve->w.n);
1494         freebn(w);
1495
1496         tmp = ecp_summul(u1, u2, publicKey);
1497         freebn(u1);
1498         freebn(u2);
1499         if (!tmp) {
1500             freebn(z);
1501             return 0;
1502         }
1503
1504         x = bigmod(tmp->x, publicKey->curve->w.n);
1505         ec_point_free(tmp);
1506
1507         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1508         freebn(x);
1509     }
1510
1511     freebn(z);
1512
1513     return valid;
1514 }
1515
1516 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1517                         const unsigned char *data, const int dataLen,
1518                         Bignum *r, Bignum *s)
1519 {
1520     unsigned char digest[20];
1521     int z_bits, n_bits;
1522     Bignum z, k;
1523     struct ec_point *kG;
1524
1525     *r = NULL;
1526     *s = NULL;
1527
1528     if (curve->type != EC_WEIERSTRASS) {
1529         return;
1530     }
1531
1532     /* z = left most bitlen(curve->n) of data */
1533     z = bignum_from_bytes(data, dataLen);
1534     n_bits = bignum_bitcount(curve->w.n);
1535     z_bits = bignum_bitcount(z);
1536     if (z_bits > n_bits)
1537     {
1538         Bignum tmp;
1539         tmp = bignum_rshift(z, z_bits - n_bits);
1540         freebn(z);
1541         z = tmp;
1542     }
1543
1544     /* Generate k between 1 and curve->n, using the same deterministic
1545      * k generation system we use for conventional DSA. */
1546     SHA_Simple(data, dataLen, digest);
1547     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1548                   digest, sizeof(digest));
1549
1550     kG = ecp_mul(&curve->w.G, k);
1551     if (!kG) {
1552         freebn(z);
1553         freebn(k);
1554         return;
1555     }
1556
1557     /* r = kG.x mod n */
1558     *r = bigmod(kG->x, curve->w.n);
1559     ec_point_free(kG);
1560
1561     /* s = (z + r * priv)/k mod n */
1562     {
1563         Bignum rPriv, zMod, first, firstMod, kInv;
1564         rPriv = modmul(*r, privateKey, curve->w.n);
1565         zMod = bigmod(z, curve->w.n);
1566         freebn(z);
1567         first = bigadd(rPriv, zMod);
1568         freebn(rPriv);
1569         freebn(zMod);
1570         firstMod = bigmod(first, curve->w.n);
1571         freebn(first);
1572         kInv = modinv(k, curve->w.n);
1573         freebn(k);
1574         if (!kInv) {
1575             freebn(firstMod);
1576             freebn(*r);
1577             return;
1578         }
1579         *s = modmul(firstMod, kInv, curve->w.n);
1580         freebn(firstMod);
1581         freebn(kInv);
1582     }
1583 }
1584
1585 /* ----------------------------------------------------------------------
1586  * Misc functions
1587  */
1588
1589 static void getstring(const char **data, int *datalen,
1590                       const char **p, int *length)
1591 {
1592     *p = NULL;
1593     if (*datalen < 4)
1594         return;
1595     *length = toint(GET_32BIT(*data));
1596     if (*length < 0)
1597         return;
1598     *datalen -= 4;
1599     *data += 4;
1600     if (*datalen < *length)
1601         return;
1602     *p = *data;
1603     *data += *length;
1604     *datalen -= *length;
1605 }
1606
1607 static Bignum getmp(const char **data, int *datalen)
1608 {
1609     const char *p;
1610     int length;
1611
1612     getstring(data, datalen, &p, &length);
1613     if (!p)
1614         return NULL;
1615     if (p[0] & 0x80)
1616         return NULL;                   /* negative mp */
1617     return bignum_from_bytes((unsigned char *)p, length);
1618 }
1619
1620 static Bignum getmp_le(const char **data, int *datalen)
1621 {
1622     const char *p;
1623     int length;
1624
1625     getstring(data, datalen, &p, &length);
1626     if (!p)
1627         return NULL;
1628     return bignum_from_bytes_le((const unsigned char *)p, length);
1629 }
1630
1631 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1632 {
1633     /* Got some conversion to do, first read in the y co-ord */
1634     int negative;
1635
1636     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1637     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1638         freebn(point->y);
1639         point->y = NULL;
1640         return 0;
1641     }
1642     /* Read x bit and then reset it */
1643     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1644     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1645
1646     /* Get the x from the y */
1647     point->x = ecp_edx(point->curve, point->y);
1648     if (!point->x) {
1649         freebn(point->y);
1650         point->y = NULL;
1651         return 0;
1652     }
1653     if (negative) {
1654         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1655         freebn(point->x);
1656         point->x = tmp;
1657     }
1658
1659     /* Verify the point is on the curve */
1660     if (!ec_point_verify(point)) {
1661         freebn(point->x);
1662         point->x = NULL;
1663         freebn(point->y);
1664         point->y = NULL;
1665         return 0;
1666     }
1667
1668     return 1;
1669 }
1670
1671 static int decodepoint(const char *p, int length, struct ec_point *point)
1672 {
1673     if (point->curve->type == EC_EDWARDS) {
1674         return decodepoint_ed(p, length, point);
1675     }
1676
1677     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1678         return 0;
1679     /* Skip compression flag */
1680     ++p;
1681     --length;
1682     /* The two values must be equal length */
1683     if (length % 2 != 0) {
1684         point->x = NULL;
1685         point->y = NULL;
1686         point->z = NULL;
1687         return 0;
1688     }
1689     length = length / 2;
1690     point->x = bignum_from_bytes((const unsigned char *)p, length);
1691     p += length;
1692     point->y = bignum_from_bytes((const unsigned char *)p, length);
1693     point->z = NULL;
1694
1695     /* Verify the point is on the curve */
1696     if (!ec_point_verify(point)) {
1697         freebn(point->x);
1698         point->x = NULL;
1699         freebn(point->y);
1700         point->y = NULL;
1701         return 0;
1702     }
1703
1704     return 1;
1705 }
1706
1707 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1708 {
1709     const char *p;
1710     int length;
1711
1712     getstring(data, datalen, &p, &length);
1713     if (!p) return 0;
1714     return decodepoint(p, length, point);
1715 }
1716
1717 /* ----------------------------------------------------------------------
1718  * Exposed ECDSA interface
1719  */
1720
1721 struct ecsign_extra {
1722     struct ec_curve *(*curve)(void);
1723     const struct ssh_hash *hash;
1724
1725     /* These fields are used by the OpenSSH PEM format importer/exporter */
1726     const unsigned char *oid;
1727     int oidlen;
1728 };
1729
1730 static void ecdsa_freekey(void *key)
1731 {
1732     struct ec_key *ec = (struct ec_key *) key;
1733     if (!ec) return;
1734
1735     if (ec->publicKey.x)
1736         freebn(ec->publicKey.x);
1737     if (ec->publicKey.y)
1738         freebn(ec->publicKey.y);
1739     if (ec->publicKey.z)
1740         freebn(ec->publicKey.z);
1741     if (ec->privateKey)
1742         freebn(ec->privateKey);
1743     sfree(ec);
1744 }
1745
1746 static void *ecdsa_newkey(const struct ssh_signkey *self,
1747                           const char *data, int len)
1748 {
1749     const struct ecsign_extra *extra =
1750         (const struct ecsign_extra *)self->extra;
1751     const char *p;
1752     int slen;
1753     struct ec_key *ec;
1754     struct ec_curve *curve;
1755
1756     getstring(&data, &len, &p, &slen);
1757
1758     if (!p) {
1759         return NULL;
1760     }
1761     curve = extra->curve();
1762     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1763
1764     /* Curve name is duplicated for Weierstrass form */
1765     if (curve->type == EC_WEIERSTRASS) {
1766         getstring(&data, &len, &p, &slen);
1767         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1768     }
1769
1770     ec = snew(struct ec_key);
1771
1772     ec->signalg = self;
1773     ec->publicKey.curve = curve;
1774     ec->publicKey.infinity = 0;
1775     ec->publicKey.x = NULL;
1776     ec->publicKey.y = NULL;
1777     ec->publicKey.z = NULL;
1778     if (!getmppoint(&data, &len, &ec->publicKey)) {
1779         ecdsa_freekey(ec);
1780         return NULL;
1781     }
1782     ec->privateKey = NULL;
1783
1784     if (!ec->publicKey.x || !ec->publicKey.y ||
1785         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1786         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1787     {
1788         ecdsa_freekey(ec);
1789         ec = NULL;
1790     }
1791
1792     return ec;
1793 }
1794
1795 static char *ecdsa_fmtkey(void *key)
1796 {
1797     struct ec_key *ec = (struct ec_key *) key;
1798     char *p;
1799     int len, i, pos, nibbles;
1800     static const char hex[] = "0123456789abcdef";
1801     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1802         return NULL;
1803
1804     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1805     if (ec->publicKey.curve->name)
1806         len += strlen(ec->publicKey.curve->name); /* Curve name */
1807     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1808     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1809     p = snewn(len, char);
1810
1811     pos = 0;
1812     if (ec->publicKey.curve->name)
1813         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1814     pos += sprintf(p + pos, "0x");
1815     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1816     if (nibbles < 1)
1817         nibbles = 1;
1818     for (i = nibbles; i--;) {
1819         p[pos++] =
1820             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1821     }
1822     pos += sprintf(p + pos, ",0x");
1823     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1824     if (nibbles < 1)
1825         nibbles = 1;
1826     for (i = nibbles; i--;) {
1827         p[pos++] =
1828             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1829     }
1830     p[pos] = '\0';
1831     return p;
1832 }
1833
1834 static unsigned char *ecdsa_public_blob(void *key, int *len)
1835 {
1836     struct ec_key *ec = (struct ec_key *) key;
1837     int pointlen, bloblen, fullnamelen, namelen;
1838     int i;
1839     unsigned char *blob, *p;
1840
1841     fullnamelen = strlen(ec->signalg->name);
1842
1843     if (ec->publicKey.curve->type == EC_EDWARDS) {
1844         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1845
1846         pointlen = ec->publicKey.curve->fieldBits / 8;
1847
1848         /* Can't handle this in our loop */
1849         if (pointlen < 2) return NULL;
1850
1851         bloblen = 4 + fullnamelen + 4 + pointlen;
1852         blob = snewn(bloblen, unsigned char);
1853
1854         p = blob;
1855         PUT_32BIT(p, fullnamelen);
1856         p += 4;
1857         memcpy(p, ec->signalg->name, fullnamelen);
1858         p += fullnamelen;
1859         PUT_32BIT(p, pointlen);
1860         p += 4;
1861
1862         /* Unset last bit of y and set first bit of x in its place */
1863         for (i = 0; i < pointlen - 1; ++i) {
1864             *p++ = bignum_byte(ec->publicKey.y, i);
1865         }
1866         /* Unset last bit of y and set first bit of x in its place */
1867         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1868         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1869     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1870         assert(ec->publicKey.curve->name);
1871         namelen = strlen(ec->publicKey.curve->name);
1872
1873         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1874
1875         /*
1876          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1877          */
1878         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1879         blob = snewn(bloblen, unsigned char);
1880
1881         p = blob;
1882         PUT_32BIT(p, fullnamelen);
1883         p += 4;
1884         memcpy(p, ec->signalg->name, fullnamelen);
1885         p += fullnamelen;
1886         PUT_32BIT(p, namelen);
1887         p += 4;
1888         memcpy(p, ec->publicKey.curve->name, namelen);
1889         p += namelen;
1890         PUT_32BIT(p, (2 * pointlen) + 1);
1891         p += 4;
1892         *p++ = 0x04;
1893         for (i = pointlen; i--;) {
1894             *p++ = bignum_byte(ec->publicKey.x, i);
1895         }
1896         for (i = pointlen; i--;) {
1897             *p++ = bignum_byte(ec->publicKey.y, i);
1898         }
1899     } else {
1900         return NULL;
1901     }
1902
1903     assert(p == blob + bloblen);
1904     *len = bloblen;
1905
1906     return blob;
1907 }
1908
1909 static unsigned char *ecdsa_private_blob(void *key, int *len)
1910 {
1911     struct ec_key *ec = (struct ec_key *) key;
1912     int keylen, bloblen;
1913     int i;
1914     unsigned char *blob, *p;
1915
1916     if (!ec->privateKey) return NULL;
1917
1918     if (ec->publicKey.curve->type == EC_EDWARDS) {
1919         /* Unsigned */
1920         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1921     } else {
1922         /* Signed */
1923         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1924     }
1925
1926     /*
1927      * mpint privateKey. Total 4 + keylen.
1928      */
1929     bloblen = 4 + keylen;
1930     blob = snewn(bloblen, unsigned char);
1931
1932     p = blob;
1933     PUT_32BIT(p, keylen);
1934     p += 4;
1935     if (ec->publicKey.curve->type == EC_EDWARDS) {
1936         /* Little endian */
1937         for (i = 0; i < keylen; ++i)
1938             *p++ = bignum_byte(ec->privateKey, i);
1939     } else {
1940         for (i = keylen; i--;)
1941             *p++ = bignum_byte(ec->privateKey, i);
1942     }
1943
1944     assert(p == blob + bloblen);
1945     *len = bloblen;
1946     return blob;
1947 }
1948
1949 static void *ecdsa_createkey(const struct ssh_signkey *self,
1950                              const unsigned char *pub_blob, int pub_len,
1951                              const unsigned char *priv_blob, int priv_len)
1952 {
1953     struct ec_key *ec;
1954     struct ec_point *publicKey;
1955     const char *pb = (const char *) priv_blob;
1956
1957     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1958     if (!ec) {
1959         return NULL;
1960     }
1961
1962     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1963         && ec->publicKey.curve->type != EC_EDWARDS) {
1964         ecdsa_freekey(ec);
1965         return NULL;
1966     }
1967
1968     if (ec->publicKey.curve->type == EC_EDWARDS) {
1969         ec->privateKey = getmp_le(&pb, &priv_len);
1970     } else {
1971         ec->privateKey = getmp(&pb, &priv_len);
1972     }
1973     if (!ec->privateKey) {
1974         ecdsa_freekey(ec);
1975         return NULL;
1976     }
1977
1978     /* Check that private key generates public key */
1979     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1980
1981     if (!publicKey ||
1982         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1983         bignum_cmp(publicKey->y, ec->publicKey.y))
1984     {
1985         ecdsa_freekey(ec);
1986         ec = NULL;
1987     }
1988     ec_point_free(publicKey);
1989
1990     return ec;
1991 }
1992
1993 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
1994                                        const unsigned char **blob, int *len)
1995 {
1996     struct ec_key *ec;
1997     struct ec_point *publicKey;
1998     const char *p, *q;
1999     int plen, qlen;
2000
2001     getstring((const char**)blob, len, &p, &plen);
2002     if (!p)
2003     {
2004         return NULL;
2005     }
2006
2007     ec = snew(struct ec_key);
2008
2009     ec->signalg = self;
2010     ec->publicKey.curve = ec_ed25519();
2011     ec->publicKey.infinity = 0;
2012     ec->privateKey = NULL;
2013     ec->publicKey.x = NULL;
2014     ec->publicKey.z = NULL;
2015     ec->publicKey.y = NULL;
2016
2017     if (!decodepoint_ed(p, plen, &ec->publicKey))
2018     {
2019         ecdsa_freekey(ec);
2020         return NULL;
2021     }
2022
2023     getstring((const char**)blob, len, &q, &qlen);
2024     if (!q)
2025         return NULL;
2026     if (qlen != 64)
2027         return NULL;
2028
2029     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2030
2031     /* Check that private key generates public key */
2032     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2033
2034     if (!publicKey ||
2035         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2036         bignum_cmp(publicKey->y, ec->publicKey.y))
2037     {
2038         ecdsa_freekey(ec);
2039         ec = NULL;
2040     }
2041     ec_point_free(publicKey);
2042
2043     /* The OpenSSH format for ed25519 private keys also for some
2044      * reason encodes an extra copy of the public key in the second
2045      * half of the secret-key string. Check that that's present and
2046      * correct as well, otherwise the key we think we've imported
2047      * won't behave identically to the way OpenSSH would have treated
2048      * it. */
2049     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2050         ecdsa_freekey(ec);
2051         return NULL;
2052     }
2053
2054     return ec;
2055 }
2056
2057 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2058 {
2059     struct ec_key *ec = (struct ec_key *) key;
2060
2061     int pointlen;
2062     int keylen;
2063     int bloblen;
2064     int i;
2065
2066     if (ec->publicKey.curve->type != EC_EDWARDS) {
2067         return 0;
2068     }
2069
2070     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2071     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2072     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2073
2074     if (bloblen > len)
2075         return bloblen;
2076
2077     /* Encode the public point */
2078     PUT_32BIT(blob, pointlen);
2079     blob += 4;
2080
2081     for (i = 0; i < pointlen - 1; ++i) {
2082          *blob++ = bignum_byte(ec->publicKey.y, i);
2083     }
2084     /* Unset last bit of y and set first bit of x in its place */
2085     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2086     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2087
2088     PUT_32BIT(blob, keylen + pointlen);
2089     blob += 4;
2090     for (i = 0; i < keylen; ++i) {
2091          *blob++ = bignum_byte(ec->privateKey, i);
2092     }
2093     /* Now encode an extra copy of the public point as the second half
2094      * of the private key string, as the OpenSSH format for some
2095      * reason requires */
2096     for (i = 0; i < pointlen - 1; ++i) {
2097          *blob++ = bignum_byte(ec->publicKey.y, i);
2098     }
2099     /* Unset last bit of y and set first bit of x in its place */
2100     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2101     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2102
2103     return bloblen;
2104 }
2105
2106 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2107                                      const unsigned char **blob, int *len)
2108 {
2109     const struct ecsign_extra *extra =
2110         (const struct ecsign_extra *)self->extra;
2111     const char **b = (const char **) blob;
2112     const char *p;
2113     int slen;
2114     struct ec_key *ec;
2115     struct ec_curve *curve;
2116     struct ec_point *publicKey;
2117
2118     getstring(b, len, &p, &slen);
2119
2120     if (!p) {
2121         return NULL;
2122     }
2123     curve = extra->curve();
2124     assert(curve->type == EC_WEIERSTRASS);
2125
2126     ec = snew(struct ec_key);
2127
2128     ec->signalg = self;
2129     ec->publicKey.curve = curve;
2130     ec->publicKey.infinity = 0;
2131     ec->publicKey.x = NULL;
2132     ec->publicKey.y = NULL;
2133     ec->publicKey.z = NULL;
2134     if (!getmppoint(b, len, &ec->publicKey)) {
2135         ecdsa_freekey(ec);
2136         return NULL;
2137     }
2138     ec->privateKey = NULL;
2139
2140     if (!ec->publicKey.x || !ec->publicKey.y ||
2141         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2142         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2143     {
2144         ecdsa_freekey(ec);
2145         return NULL;
2146     }
2147
2148     ec->privateKey = getmp(b, len);
2149     if (ec->privateKey == NULL)
2150     {
2151         ecdsa_freekey(ec);
2152         return NULL;
2153     }
2154
2155     /* Now check that the private key makes the public key */
2156     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2157     if (!publicKey)
2158     {
2159         ecdsa_freekey(ec);
2160         return NULL;
2161     }
2162
2163     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2164         bignum_cmp(ec->publicKey.y, publicKey->y))
2165     {
2166         /* Private key doesn't make the public key on the given curve */
2167         ecdsa_freekey(ec);
2168         ec_point_free(publicKey);
2169         return NULL;
2170     }
2171
2172     ec_point_free(publicKey);
2173
2174     return ec;
2175 }
2176
2177 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2178 {
2179     struct ec_key *ec = (struct ec_key *) key;
2180
2181     int pointlen;
2182     int namelen;
2183     int bloblen;
2184     int i;
2185
2186     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2187         return 0;
2188     }
2189
2190     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2191     namelen = strlen(ec->publicKey.curve->name);
2192     bloblen =
2193         4 + namelen /* <LEN> nistpXXX */
2194         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2195         + ssh2_bignum_length(ec->privateKey);
2196
2197     if (bloblen > len)
2198         return bloblen;
2199
2200     bloblen = 0;
2201
2202     PUT_32BIT(blob+bloblen, namelen);
2203     bloblen += 4;
2204     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2205     bloblen += namelen;
2206
2207     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2208     bloblen += 4;
2209     blob[bloblen++] = 0x04;
2210     for (i = pointlen; i--; )
2211         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2212     for (i = pointlen; i--; )
2213         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2214
2215     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2216     PUT_32BIT(blob+bloblen, pointlen);
2217     bloblen += 4;
2218     for (i = pointlen; i--; )
2219         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2220
2221     return bloblen;
2222 }
2223
2224 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2225                              const void *blob, int len)
2226 {
2227     struct ec_key *ec;
2228     int ret;
2229
2230     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2231     if (!ec)
2232         return -1;
2233     ret = ec->publicKey.curve->fieldBits;
2234     ecdsa_freekey(ec);
2235
2236     return ret;
2237 }
2238
2239 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2240                            const char *data, int datalen)
2241 {
2242     struct ec_key *ec = (struct ec_key *) key;
2243     const struct ecsign_extra *extra =
2244         (const struct ecsign_extra *)ec->signalg->extra;
2245     const char *p;
2246     int slen;
2247     int digestLen;
2248     int ret;
2249
2250     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2251         return 0;
2252
2253     /* Check the signature starts with the algorithm name */
2254     getstring(&sig, &siglen, &p, &slen);
2255     if (!p) {
2256         return 0;
2257     }
2258     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2259         return 0;
2260     }
2261
2262     getstring(&sig, &siglen, &p, &slen);
2263     if (ec->publicKey.curve->type == EC_EDWARDS) {
2264         struct ec_point *r;
2265         Bignum s, h;
2266
2267         /* Check that the signature is two times the length of a point */
2268         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2269             return 0;
2270         }
2271
2272         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2273         if (ec->publicKey.curve->fieldBits != 256) {
2274             return 0;
2275         }
2276
2277         /* Get the signature */
2278         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2279         if (!r) {
2280             return 0;
2281         }
2282         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2283             ec_point_free(r);
2284             return 0;
2285         }
2286         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2287                                  ec->publicKey.curve->fieldBits / 8);
2288
2289         /* Get the hash of the encoded value of R + encoded value of pk + message */
2290         {
2291             int i, pointlen;
2292             unsigned char b;
2293             unsigned char digest[512 / 8];
2294             SHA512_State hs;
2295             SHA512_Init(&hs);
2296
2297             /* Add encoded r (no need to encode it again, it was in the signature) */
2298             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2299
2300             /* Encode pk and add it */
2301             pointlen = ec->publicKey.curve->fieldBits / 8;
2302             for (i = 0; i < pointlen - 1; ++i) {
2303                 b = bignum_byte(ec->publicKey.y, i);
2304                 SHA512_Bytes(&hs, &b, 1);
2305             }
2306             /* Unset last bit of y and set first bit of x in its place */
2307             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2308             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2309             SHA512_Bytes(&hs, &b, 1);
2310
2311             /* Add the message itself */
2312             SHA512_Bytes(&hs, data, datalen);
2313
2314             /* Get the hash */
2315             SHA512_Final(&hs, digest);
2316
2317             /* Convert to Bignum */
2318             h = bignum_from_bytes_le(digest, sizeof(digest));
2319         }
2320
2321         /* Verify sB == r + h*publicKey */
2322         {
2323             struct ec_point *lhs, *rhs, *tmp;
2324
2325             /* lhs = sB */
2326             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2327             freebn(s);
2328             if (!lhs) {
2329                 ec_point_free(r);
2330                 freebn(h);
2331                 return 0;
2332             }
2333
2334             /* rhs = r + h*publicKey */
2335             tmp = ecp_mul(&ec->publicKey, h);
2336             freebn(h);
2337             if (!tmp) {
2338                 ec_point_free(lhs);
2339                 ec_point_free(r);
2340                 return 0;
2341             }
2342             rhs = ecp_add(r, tmp, 0);
2343             ec_point_free(r);
2344             ec_point_free(tmp);
2345             if (!rhs) {
2346                 ec_point_free(lhs);
2347                 return 0;
2348             }
2349
2350             /* Check the point is the same */
2351             ret = !bignum_cmp(lhs->x, rhs->x);
2352             if (ret) {
2353                 ret = !bignum_cmp(lhs->y, rhs->y);
2354                 if (ret) {
2355                     ret = 1;
2356                 }
2357             }
2358             ec_point_free(lhs);
2359             ec_point_free(rhs);
2360         }
2361     } else {
2362         Bignum r, s;
2363         unsigned char digest[512 / 8];
2364         void *hashctx;
2365
2366         r = getmp(&p, &slen);
2367         if (!r) return 0;
2368         s = getmp(&p, &slen);
2369         if (!s) {
2370             freebn(r);
2371             return 0;
2372         }
2373
2374         digestLen = extra->hash->hlen;
2375         assert(digestLen <= sizeof(digest));
2376         hashctx = extra->hash->init();
2377         extra->hash->bytes(hashctx, data, datalen);
2378         extra->hash->final(hashctx, digest);
2379
2380         /* Verify the signature */
2381         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2382
2383         freebn(r);
2384         freebn(s);
2385     }
2386
2387     return ret;
2388 }
2389
2390 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2391                                  int *siglen)
2392 {
2393     struct ec_key *ec = (struct ec_key *) key;
2394     const struct ecsign_extra *extra =
2395         (const struct ecsign_extra *)ec->signalg->extra;
2396     unsigned char digest[512 / 8];
2397     int digestLen;
2398     Bignum r = NULL, s = NULL;
2399     unsigned char *buf, *p;
2400     int rlen, slen, namelen;
2401     int i;
2402
2403     if (!ec->privateKey || !ec->publicKey.curve) {
2404         return NULL;
2405     }
2406
2407     if (ec->publicKey.curve->type == EC_EDWARDS) {
2408         struct ec_point *rp;
2409         int pointlen = ec->publicKey.curve->fieldBits / 8;
2410
2411         /* hash = H(sk) (where hash creates 2 * fieldBits)
2412          * b = fieldBits
2413          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2414          * r = H(h[b/8:b/4] + m)
2415          * R = rB
2416          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2417         {
2418             unsigned char hash[512/8];
2419             unsigned char b;
2420             Bignum a;
2421             SHA512_State hs;
2422             SHA512_Init(&hs);
2423
2424             for (i = 0; i < pointlen; ++i) {
2425                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2426                 SHA512_Bytes(&hs, &b, 1);
2427             }
2428
2429             SHA512_Final(&hs, hash);
2430
2431             /* The second part is simply turning the hash into a
2432              * Bignum, however the 2^(b-2) bit *must* be set, and the
2433              * bottom 3 bits *must* not be */
2434             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2435             hash[31] &= 0x7f; /* Unset above (b-2) */
2436             hash[31] |= 0x40; /* Set 2^(b-2) */
2437             /* Chop off the top part and convert to int */
2438             a = bignum_from_bytes_le(hash, 32);
2439
2440             SHA512_Init(&hs);
2441             SHA512_Bytes(&hs,
2442                          hash+(ec->publicKey.curve->fieldBits / 8),
2443                          (ec->publicKey.curve->fieldBits / 4)
2444                          - (ec->publicKey.curve->fieldBits / 8));
2445             SHA512_Bytes(&hs, data, datalen);
2446             SHA512_Final(&hs, hash);
2447
2448             r = bignum_from_bytes_le(hash, 512/8);
2449             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2450             if (!rp) {
2451                 freebn(r);
2452                 freebn(a);
2453                 return NULL;
2454             }
2455
2456             /* Now calculate s */
2457             SHA512_Init(&hs);
2458             /* Encode the point R */
2459             for (i = 0; i < pointlen - 1; ++i) {
2460                 b = bignum_byte(rp->y, i);
2461                 SHA512_Bytes(&hs, &b, 1);
2462             }
2463             /* Unset last bit of y and set first bit of x in its place */
2464             b = bignum_byte(rp->y, i) & 0x7f;
2465             b |= bignum_bit(rp->x, 0) << 7;
2466             SHA512_Bytes(&hs, &b, 1);
2467
2468             /* Encode the point pk */
2469             for (i = 0; i < pointlen - 1; ++i) {
2470                 b = bignum_byte(ec->publicKey.y, i);
2471                 SHA512_Bytes(&hs, &b, 1);
2472             }
2473             /* Unset last bit of y and set first bit of x in its place */
2474             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2475             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2476             SHA512_Bytes(&hs, &b, 1);
2477
2478             /* Add the message */
2479             SHA512_Bytes(&hs, data, datalen);
2480             SHA512_Final(&hs, hash);
2481
2482             {
2483                 Bignum tmp, tmp2;
2484
2485                 tmp = bignum_from_bytes_le(hash, 512/8);
2486                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2487                 freebn(a);
2488                 freebn(tmp);
2489                 tmp = bigadd(r, tmp2);
2490                 freebn(r);
2491                 freebn(tmp2);
2492                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2493                 freebn(tmp);
2494             }
2495         }
2496
2497         /* Format the output */
2498         namelen = strlen(ec->signalg->name);
2499         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2500         buf = snewn(*siglen, unsigned char);
2501         p = buf;
2502         PUT_32BIT(p, namelen);
2503         p += 4;
2504         memcpy(p, ec->signalg->name, namelen);
2505         p += namelen;
2506         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2507         p += 4;
2508
2509         /* Encode the point */
2510         pointlen = ec->publicKey.curve->fieldBits / 8;
2511         for (i = 0; i < pointlen - 1; ++i) {
2512             *p++ = bignum_byte(rp->y, i);
2513         }
2514         /* Unset last bit of y and set first bit of x in its place */
2515         *p = bignum_byte(rp->y, i) & 0x7f;
2516         *p++ |= bignum_bit(rp->x, 0) << 7;
2517         ec_point_free(rp);
2518
2519         /* Encode the int */
2520         for (i = 0; i < pointlen; ++i) {
2521             *p++ = bignum_byte(s, i);
2522         }
2523         freebn(s);
2524     } else {
2525         void *hashctx;
2526
2527         digestLen = extra->hash->hlen;
2528         assert(digestLen <= sizeof(digest));
2529         hashctx = extra->hash->init();
2530         extra->hash->bytes(hashctx, data, datalen);
2531         extra->hash->final(hashctx, digest);
2532
2533         /* Do the signature */
2534         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2535         if (!r || !s) {
2536             if (r) freebn(r);
2537             if (s) freebn(s);
2538             return NULL;
2539         }
2540
2541         rlen = (bignum_bitcount(r) + 8) / 8;
2542         slen = (bignum_bitcount(s) + 8) / 8;
2543
2544         namelen = strlen(ec->signalg->name);
2545
2546         /* Format the output */
2547         *siglen = 8+namelen+rlen+slen+8;
2548         buf = snewn(*siglen, unsigned char);
2549         p = buf;
2550         PUT_32BIT(p, namelen);
2551         p += 4;
2552         memcpy(p, ec->signalg->name, namelen);
2553         p += namelen;
2554         PUT_32BIT(p, rlen + slen + 8);
2555         p += 4;
2556         PUT_32BIT(p, rlen);
2557         p += 4;
2558         for (i = rlen; i--;)
2559             *p++ = bignum_byte(r, i);
2560         PUT_32BIT(p, slen);
2561         p += 4;
2562         for (i = slen; i--;)
2563             *p++ = bignum_byte(s, i);
2564
2565         freebn(r);
2566         freebn(s);
2567     }
2568
2569     return buf;
2570 }
2571
2572 const struct ecsign_extra sign_extra_ed25519 = {
2573     ec_ed25519, NULL,
2574     NULL, 0,
2575 };
2576 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2577     ecdsa_newkey,
2578     ecdsa_freekey,
2579     ecdsa_fmtkey,
2580     ecdsa_public_blob,
2581     ecdsa_private_blob,
2582     ecdsa_createkey,
2583     ed25519_openssh_createkey,
2584     ed25519_openssh_fmtkey,
2585     2 /* point, private exponent */,
2586     ecdsa_pubkey_bits,
2587     ecdsa_verifysig,
2588     ecdsa_sign,
2589     "ssh-ed25519",
2590     "ssh-ed25519",
2591     &sign_extra_ed25519,
2592 };
2593
2594 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2595 static const unsigned char nistp256_oid[] = {
2596     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2597 };
2598 const struct ecsign_extra sign_extra_nistp256 = {
2599     ec_p256, &ssh_sha256,
2600     nistp256_oid, lenof(nistp256_oid),
2601 };
2602 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2603     ecdsa_newkey,
2604     ecdsa_freekey,
2605     ecdsa_fmtkey,
2606     ecdsa_public_blob,
2607     ecdsa_private_blob,
2608     ecdsa_createkey,
2609     ecdsa_openssh_createkey,
2610     ecdsa_openssh_fmtkey,
2611     3 /* curve name, point, private exponent */,
2612     ecdsa_pubkey_bits,
2613     ecdsa_verifysig,
2614     ecdsa_sign,
2615     "ecdsa-sha2-nistp256",
2616     "ecdsa-sha2-nistp256",
2617     &sign_extra_nistp256,
2618 };
2619
2620 /* OID: 1.3.132.0.34 (secp384r1) */
2621 static const unsigned char nistp384_oid[] = {
2622     0x2b, 0x81, 0x04, 0x00, 0x22
2623 };
2624 const struct ecsign_extra sign_extra_nistp384 = {
2625     ec_p384, &ssh_sha384,
2626     nistp384_oid, lenof(nistp384_oid),
2627 };
2628 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2629     ecdsa_newkey,
2630     ecdsa_freekey,
2631     ecdsa_fmtkey,
2632     ecdsa_public_blob,
2633     ecdsa_private_blob,
2634     ecdsa_createkey,
2635     ecdsa_openssh_createkey,
2636     ecdsa_openssh_fmtkey,
2637     3 /* curve name, point, private exponent */,
2638     ecdsa_pubkey_bits,
2639     ecdsa_verifysig,
2640     ecdsa_sign,
2641     "ecdsa-sha2-nistp384",
2642     "ecdsa-sha2-nistp384",
2643     &sign_extra_nistp384,
2644 };
2645
2646 /* OID: 1.3.132.0.35 (secp521r1) */
2647 static const unsigned char nistp521_oid[] = {
2648     0x2b, 0x81, 0x04, 0x00, 0x23
2649 };
2650 const struct ecsign_extra sign_extra_nistp521 = {
2651     ec_p521, &ssh_sha512,
2652     nistp521_oid, lenof(nistp521_oid),
2653 };
2654 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2655     ecdsa_newkey,
2656     ecdsa_freekey,
2657     ecdsa_fmtkey,
2658     ecdsa_public_blob,
2659     ecdsa_private_blob,
2660     ecdsa_createkey,
2661     ecdsa_openssh_createkey,
2662     ecdsa_openssh_fmtkey,
2663     3 /* curve name, point, private exponent */,
2664     ecdsa_pubkey_bits,
2665     ecdsa_verifysig,
2666     ecdsa_sign,
2667     "ecdsa-sha2-nistp521",
2668     "ecdsa-sha2-nistp521",
2669     &sign_extra_nistp521,
2670 };
2671
2672 /* ----------------------------------------------------------------------
2673  * Exposed ECDH interface
2674  */
2675
2676 struct eckex_extra {
2677     struct ec_curve *(*curve)(void);
2678 };
2679
2680 static Bignum ecdh_calculate(const Bignum private,
2681                              const struct ec_point *public)
2682 {
2683     struct ec_point *p;
2684     Bignum ret;
2685     p = ecp_mul(public, private);
2686     if (!p) return NULL;
2687     ret = p->x;
2688     p->x = NULL;
2689
2690     if (p->curve->type == EC_MONTGOMERY) {
2691         /* Do conversion in network byte order */
2692         int i;
2693         int bytes = (bignum_bitcount(ret)+7) / 8;
2694         unsigned char *byteorder = snewn(bytes, unsigned char);
2695         for (i = 0; i < bytes; ++i) {
2696             byteorder[i] = bignum_byte(ret, i);
2697         }
2698         freebn(ret);
2699         ret = bignum_from_bytes(byteorder, bytes);
2700         sfree(byteorder);
2701     }
2702
2703     ec_point_free(p);
2704     return ret;
2705 }
2706
2707 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2708 {
2709     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2710     struct ec_curve *curve;
2711     struct ec_key *key;
2712     struct ec_point *publicKey;
2713
2714     curve = extra->curve();
2715
2716     key = snew(struct ec_key);
2717
2718     key->signalg = NULL;
2719     key->publicKey.curve = curve;
2720
2721     if (curve->type == EC_MONTGOMERY) {
2722         unsigned char bytes[32] = {0};
2723         int i;
2724
2725         for (i = 0; i < sizeof(bytes); ++i)
2726         {
2727             bytes[i] = (unsigned char)random_byte();
2728         }
2729         bytes[0] &= 248;
2730         bytes[31] &= 127;
2731         bytes[31] |= 64;
2732         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2733         for (i = 0; i < sizeof(bytes); ++i)
2734         {
2735             ((volatile char*)bytes)[i] = 0;
2736         }
2737         if (!key->privateKey) {
2738             sfree(key);
2739             return NULL;
2740         }
2741         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2742         if (!publicKey) {
2743             freebn(key->privateKey);
2744             sfree(key);
2745             return NULL;
2746         }
2747         key->publicKey.x = publicKey->x;
2748         key->publicKey.y = publicKey->y;
2749         key->publicKey.z = NULL;
2750         sfree(publicKey);
2751     } else {
2752         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2753         if (!key->privateKey) {
2754             sfree(key);
2755             return NULL;
2756         }
2757         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2758         if (!publicKey) {
2759             freebn(key->privateKey);
2760             sfree(key);
2761             return NULL;
2762         }
2763         key->publicKey.x = publicKey->x;
2764         key->publicKey.y = publicKey->y;
2765         key->publicKey.z = NULL;
2766         sfree(publicKey);
2767     }
2768     return key;
2769 }
2770
2771 char *ssh_ecdhkex_getpublic(void *key, int *len)
2772 {
2773     struct ec_key *ec = (struct ec_key*)key;
2774     char *point, *p;
2775     int i;
2776     int pointlen;
2777
2778     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2779
2780     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2781         *len = 1 + pointlen * 2;
2782     } else {
2783         *len = pointlen;
2784     }
2785     point = (char*)snewn(*len, char);
2786
2787     p = point;
2788     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2789         *p++ = 0x04;
2790         for (i = pointlen; i--;) {
2791             *p++ = bignum_byte(ec->publicKey.x, i);
2792         }
2793         for (i = pointlen; i--;) {
2794             *p++ = bignum_byte(ec->publicKey.y, i);
2795         }
2796     } else {
2797         for (i = 0; i < pointlen; ++i) {
2798             *p++ = bignum_byte(ec->publicKey.x, i);
2799         }
2800     }
2801
2802     return point;
2803 }
2804
2805 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2806 {
2807     struct ec_key *ec = (struct ec_key*) key;
2808     struct ec_point remote;
2809     Bignum ret;
2810
2811     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2812         remote.curve = ec->publicKey.curve;
2813         remote.infinity = 0;
2814         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2815             return NULL;
2816         }
2817     } else {
2818         /* Point length has to be the same length */
2819         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2820             return NULL;
2821         }
2822
2823         remote.curve = ec->publicKey.curve;
2824         remote.infinity = 0;
2825         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2826         remote.y = NULL;
2827         remote.z = NULL;
2828     }
2829
2830     ret = ecdh_calculate(ec->privateKey, &remote);
2831     if (remote.x) freebn(remote.x);
2832     if (remote.y) freebn(remote.y);
2833     return ret;
2834 }
2835
2836 void ssh_ecdhkex_freekey(void *key)
2837 {
2838     ecdsa_freekey(key);
2839 }
2840
2841 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2842 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2843     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2844     &ssh_sha256, &kex_extra_curve25519,
2845 };
2846
2847 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2848 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2849     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2850     &ssh_sha256, &kex_extra_nistp256,
2851 };
2852
2853 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2854 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2855     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2856     &ssh_sha384, &kex_extra_nistp384,
2857 };
2858
2859 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2860 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2861     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2862     &ssh_sha512, &kex_extra_nistp521,
2863 };
2864
2865 static const struct ssh_kex *const ec_kex_list[] = {
2866     &ssh_ec_kex_curve25519,
2867     &ssh_ec_kex_nistp256,
2868     &ssh_ec_kex_nistp384,
2869     &ssh_ec_kex_nistp521,
2870 };
2871
2872 const struct ssh_kexes ssh_ecdh_kex = {
2873     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2874     ec_kex_list
2875 };
2876
2877 /* ----------------------------------------------------------------------
2878  * Helper functions for finding key algorithms and returning auxiliary
2879  * data.
2880  */
2881
2882 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2883                                         const struct ec_curve **curve)
2884 {
2885     static const struct ssh_signkey *algs_with_oid[] = {
2886         &ssh_ecdsa_nistp256,
2887         &ssh_ecdsa_nistp384,
2888         &ssh_ecdsa_nistp521,
2889     };
2890     int i;
2891
2892     for (i = 0; i < lenof(algs_with_oid); i++) {
2893         const struct ssh_signkey *alg = algs_with_oid[i];
2894         const struct ecsign_extra *extra =
2895             (const struct ecsign_extra *)alg->extra;
2896         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2897             *curve = extra->curve();
2898             return alg;
2899         }
2900     }
2901     return NULL;
2902 }
2903
2904 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2905                                 int *oidlen)
2906 {
2907     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2908     *oidlen = extra->oidlen;
2909     return extra->oid;
2910 }
2911
2912 const int ec_nist_alg_and_curve_by_bits(int bits,
2913                                         const struct ec_curve **curve,
2914                                         const struct ssh_signkey **alg)
2915 {
2916     switch (bits) {
2917       case 256: *alg = &ssh_ecdsa_nistp256; break;
2918       case 384: *alg = &ssh_ecdsa_nistp384; break;
2919       case 521: *alg = &ssh_ecdsa_nistp521; break;
2920       default: return FALSE;
2921     }
2922     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2923     return TRUE;
2924 }
2925
2926 const int ec_ed_alg_and_curve_by_bits(int bits,
2927                                       const struct ec_curve **curve,
2928                                       const struct ssh_signkey **alg)
2929 {
2930     switch (bits) {
2931       case 256: *alg = &ssh_ecdsa_ed25519; break;
2932       default: return FALSE;
2933     }
2934     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2935     return TRUE;
2936 }