]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Fix an assertion failure when loading Ed25519 keys.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Curve25519 spec from libssh (with reference to other things in the
28  * libssh code):
29  *   https://git.libssh.org/users/aris/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
30  *
31  * Edwards DSA:
32  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
33  */
34
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "ssh.h"
39
40 /* ----------------------------------------------------------------------
41  * Elliptic curve definitions
42  */
43
44 static void initialise_wcurve(struct ec_curve *curve, int bits,
45                               const unsigned char *p,
46                               const unsigned char *a, const unsigned char *b,
47                               const unsigned char *n, const unsigned char *Gx,
48                               const unsigned char *Gy)
49 {
50     int length = bits / 8;
51     if (bits % 8) ++length;
52
53     curve->type = EC_WEIERSTRASS;
54
55     curve->fieldBits = bits;
56     curve->p = bignum_from_bytes(p, length);
57
58     /* Curve co-efficients */
59     curve->w.a = bignum_from_bytes(a, length);
60     curve->w.b = bignum_from_bytes(b, length);
61
62     /* Group order and generator */
63     curve->w.n = bignum_from_bytes(n, length);
64     curve->w.G.x = bignum_from_bytes(Gx, length);
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     curve->w.G.curve = curve;
67     curve->w.G.infinity = 0;
68 }
69
70 static void initialise_mcurve(struct ec_curve *curve, int bits,
71                               const unsigned char *p,
72                               const unsigned char *a, const unsigned char *b,
73                               const unsigned char *Gx)
74 {
75     int length = bits / 8;
76     if (bits % 8) ++length;
77
78     curve->type = EC_MONTGOMERY;
79
80     curve->fieldBits = bits;
81     curve->p = bignum_from_bytes(p, length);
82
83     /* Curve co-efficients */
84     curve->m.a = bignum_from_bytes(a, length);
85     curve->m.b = bignum_from_bytes(b, length);
86
87     /* Generator */
88     curve->m.G.x = bignum_from_bytes(Gx, length);
89     curve->m.G.y = NULL;
90     curve->m.G.z = NULL;
91     curve->m.G.curve = curve;
92     curve->m.G.infinity = 0;
93 }
94
95 static void initialise_ecurve(struct ec_curve *curve, int bits,
96                               const unsigned char *p,
97                               const unsigned char *l, const unsigned char *d,
98                               const unsigned char *Bx, const unsigned char *By)
99 {
100     int length = bits / 8;
101     if (bits % 8) ++length;
102
103     curve->type = EC_EDWARDS;
104
105     curve->fieldBits = bits;
106     curve->p = bignum_from_bytes(p, length);
107
108     /* Curve co-efficients */
109     curve->e.l = bignum_from_bytes(l, length);
110     curve->e.d = bignum_from_bytes(d, length);
111
112     /* Group order and generator */
113     curve->e.B.x = bignum_from_bytes(Bx, length);
114     curve->e.B.y = bignum_from_bytes(By, length);
115     curve->e.B.curve = curve;
116     curve->e.B.infinity = 0;
117 }
118
119 static struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         static const unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         static const unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         static const unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         static const unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         static const unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         static const unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
164         curve.textname = curve.name = "nistp256";
165
166         /* Now initialised, no need to do it again */
167         initialised = 1;
168     }
169
170     return &curve;
171 }
172
173 static struct ec_curve *ec_p384(void)
174 {
175     static struct ec_curve curve = { 0 };
176     static unsigned char initialised = 0;
177
178     if (!initialised)
179     {
180         static const unsigned char p[] = {
181             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
185             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
186             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
187         };
188         static const unsigned char a[] = {
189             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
193             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
194             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
195         };
196         static const unsigned char b[] = {
197             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
198             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
199             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
200             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
201             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
202             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
203         };
204         static const unsigned char n[] = {
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
209             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
210             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
211         };
212         static const unsigned char Gx[] = {
213             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
214             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
215             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
216             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
217             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
218             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
219         };
220         static const unsigned char Gy[] = {
221             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
222             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
223             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
224             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
225             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
226             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
227         };
228
229         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
230         curve.textname = curve.name = "nistp384";
231
232         /* Now initialised, no need to do it again */
233         initialised = 1;
234     }
235
236     return &curve;
237 }
238
239 static struct ec_curve *ec_p521(void)
240 {
241     static struct ec_curve curve = { 0 };
242     static unsigned char initialised = 0;
243
244     if (!initialised)
245     {
246         static const unsigned char p[] = {
247             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff
256         };
257         static const unsigned char a[] = {
258             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xfc
267         };
268         static const unsigned char b[] = {
269             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
270             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
271             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
272             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
273             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
274             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
275             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
276             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
277             0x3f, 0x00
278         };
279         static const unsigned char n[] = {
280             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
281             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
282             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
285             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
286             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
287             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
288             0x64, 0x09
289         };
290         static const unsigned char Gx[] = {
291             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
292             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
293             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
294             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
295             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
296             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
297             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
298             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
299             0xbd, 0x66
300         };
301         static const unsigned char Gy[] = {
302             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
303             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
304             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
305             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
306             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
307             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
308             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
309             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
310             0x66, 0x50
311         };
312
313         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
314         curve.textname = curve.name = "nistp521";
315
316         /* Now initialised, no need to do it again */
317         initialised = 1;
318     }
319
320     return &curve;
321 }
322
323 static struct ec_curve *ec_curve25519(void)
324 {
325     static struct ec_curve curve = { 0 };
326     static unsigned char initialised = 0;
327
328     if (!initialised)
329     {
330         static const unsigned char p[] = {
331             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
332             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
333             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
334             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
335         };
336         static const unsigned char a[] = {
337             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
338             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
341         };
342         static const unsigned char b[] = {
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
344             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
347         };
348         static const unsigned char gx[32] = {
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
350             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
351             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
353         };
354
355         initialise_mcurve(&curve, 256, p, a, b, gx);
356         /* This curve doesn't need a name, because it's never used in
357          * any format that embeds the curve name */
358         curve.name = NULL;
359         curve.textname = "Curve25519";
360
361         /* Now initialised, no need to do it again */
362         initialised = 1;
363     }
364
365     return &curve;
366 }
367
368 static struct ec_curve *ec_ed25519(void)
369 {
370     static struct ec_curve curve = { 0 };
371     static unsigned char initialised = 0;
372
373     if (!initialised)
374     {
375         static const unsigned char q[] = {
376             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
379             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
380         };
381         static const unsigned char l[32] = {
382             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
385             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
386         };
387         static const unsigned char d[32] = {
388             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
389             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
390             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
391             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
392         };
393         static const unsigned char Bx[32] = {
394             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
395             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
396             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
397             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
398         };
399         static const unsigned char By[32] = {
400             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
401             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
402             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
403             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
404         };
405
406         /* This curve doesn't need a name, because it's never used in
407          * any format that embeds the curve name */
408         curve.name = NULL;
409
410         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
411         curve.textname = "Ed25519";
412
413         /* Now initialised, no need to do it again */
414         initialised = 1;
415     }
416
417     return &curve;
418 }
419
420 /* Return 1 if a is -3 % p, otherwise return 0
421  * This is used because there are some maths optimisations */
422 static int ec_aminus3(const struct ec_curve *curve)
423 {
424     int ret;
425     Bignum _p;
426
427     if (curve->type != EC_WEIERSTRASS) {
428         return 0;
429     }
430
431     _p = bignum_add_long(curve->w.a, 3);
432
433     ret = !bignum_cmp(curve->p, _p);
434     freebn(_p);
435     return ret;
436 }
437
438 /* ----------------------------------------------------------------------
439  * Elliptic curve field maths
440  */
441
442 static Bignum ecf_add(const Bignum a, const Bignum b,
443                       const struct ec_curve *curve)
444 {
445     Bignum a1, b1, ab, ret;
446
447     a1 = bigmod(a, curve->p);
448     b1 = bigmod(b, curve->p);
449
450     ab = bigadd(a1, b1);
451     freebn(a1);
452     freebn(b1);
453
454     ret = bigmod(ab, curve->p);
455     freebn(ab);
456
457     return ret;
458 }
459
460 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
461 {
462     return modmul(a, a, curve->p);
463 }
464
465 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
466 {
467     Bignum ret, tmp;
468
469     /* Double */
470     tmp = bignum_lshift(a, 1);
471
472     /* Add itself (i.e. treble) */
473     ret = bigadd(tmp, a);
474     freebn(tmp);
475
476     /* Normalise */
477     while (bignum_cmp(ret, curve->p) >= 0)
478     {
479         tmp = bigsub(ret, curve->p);
480         assert(tmp);
481         freebn(ret);
482         ret = tmp;
483     }
484
485     return ret;
486 }
487
488 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
489 {
490     Bignum ret = bignum_lshift(a, 1);
491     if (bignum_cmp(ret, curve->p) >= 0)
492     {
493         Bignum tmp = bigsub(ret, curve->p);
494         assert(tmp);
495         freebn(ret);
496         return tmp;
497     }
498     else
499     {
500         return ret;
501     }
502 }
503
504 /* ----------------------------------------------------------------------
505  * Memory functions
506  */
507
508 void ec_point_free(struct ec_point *point)
509 {
510     if (point == NULL) return;
511     point->curve = 0;
512     if (point->x) freebn(point->x);
513     if (point->y) freebn(point->y);
514     if (point->z) freebn(point->z);
515     point->infinity = 0;
516     sfree(point);
517 }
518
519 static struct ec_point *ec_point_new(const struct ec_curve *curve,
520                                      const Bignum x, const Bignum y, const Bignum z,
521                                      unsigned char infinity)
522 {
523     struct ec_point *point = snewn(1, struct ec_point);
524     point->curve = curve;
525     point->x = x;
526     point->y = y;
527     point->z = z;
528     point->infinity = infinity ? 1 : 0;
529     return point;
530 }
531
532 static struct ec_point *ec_point_copy(const struct ec_point *a)
533 {
534     if (a == NULL) return NULL;
535     return ec_point_new(a->curve,
536                         a->x ? copybn(a->x) : NULL,
537                         a->y ? copybn(a->y) : NULL,
538                         a->z ? copybn(a->z) : NULL,
539                         a->infinity);
540 }
541
542 static int ec_point_verify(const struct ec_point *a)
543 {
544     if (a->infinity) {
545         return 1;
546     } else if (a->curve->type == EC_EDWARDS) {
547         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
548         Bignum y2, x2, tmp, tmp2, tmp3;
549         int ret;
550
551         y2 = ecf_square(a->y, a->curve);
552         x2 = ecf_square(a->x, a->curve);
553         tmp = modmul(a->curve->e.d, x2, a->curve->p);
554         tmp2 = modmul(tmp, y2, a->curve->p);
555         freebn(tmp);
556         tmp = modsub(y2, x2, a->curve->p);
557         freebn(y2);
558         freebn(x2);
559         tmp3 = modsub(tmp, tmp2, a->curve->p);
560         freebn(tmp);
561         freebn(tmp2);
562         ret = !bignum_cmp(tmp3, One);
563         freebn(tmp3);
564         return ret;
565     } else if (a->curve->type == EC_WEIERSTRASS) {
566         /* Verify y^2 = x^3 + ax + b */
567         int ret = 0;
568
569         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
570
571         Bignum Three = bignum_from_long(3);
572
573         lhs = modmul(a->y, a->y, a->curve->p);
574
575         /* This uses montgomery multiplication to optimise */
576         x3 = modpow(a->x, Three, a->curve->p);
577         freebn(Three);
578         ax = modmul(a->curve->w.a, a->x, a->curve->p);
579         x3ax = bigadd(x3, ax);
580         freebn(x3); x3 = NULL;
581         freebn(ax); ax = NULL;
582         x3axm = bigmod(x3ax, a->curve->p);
583         freebn(x3ax); x3ax = NULL;
584         x3axb = bigadd(x3axm, a->curve->w.b);
585         freebn(x3axm); x3axm = NULL;
586         rhs = bigmod(x3axb, a->curve->p);
587         freebn(x3axb);
588
589         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
590         freebn(lhs);
591         freebn(rhs);
592
593         return ret;
594     } else {
595         return 0;
596     }
597 }
598
599 /* ----------------------------------------------------------------------
600  * Elliptic curve point maths
601  */
602
603 /* Returns 1 on success and 0 on memory error */
604 static int ecp_normalise(struct ec_point *a)
605 {
606     if (!a) {
607         /* No point */
608         return 0;
609     }
610
611     if (a->infinity) {
612         /* Point is at infinity - i.e. normalised */
613         return 1;
614     }
615
616     if (a->curve->type == EC_WEIERSTRASS) {
617         /* In Jacobian Coordinates the triple (X, Y, Z) represents
618            the affine point (X / Z^2, Y / Z^3) */
619
620         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
621
622         if (!a->x || !a->y) {
623             /* No point defined */
624             return 0;
625         } else if (!a->z) {
626             /* Already normalised */
627             return 1;
628         }
629
630         Z2 = ecf_square(a->z, a->curve);
631         Z2inv = modinv(Z2, a->curve->p);
632         if (!Z2inv) {
633             freebn(Z2);
634             return 0;
635         }
636         tx = modmul(a->x, Z2inv, a->curve->p);
637         freebn(Z2inv);
638
639         Z3 = modmul(Z2, a->z, a->curve->p);
640         freebn(Z2);
641         Z3inv = modinv(Z3, a->curve->p);
642         freebn(Z3);
643         if (!Z3inv) {
644             freebn(tx);
645             return 0;
646         }
647         ty = modmul(a->y, Z3inv, a->curve->p);
648         freebn(Z3inv);
649
650         freebn(a->x);
651         a->x = tx;
652         freebn(a->y);
653         a->y = ty;
654         freebn(a->z);
655         a->z = NULL;
656         return 1;
657     } else if (a->curve->type == EC_MONTGOMERY) {
658         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
659
660         Bignum tmp, tmp2;
661
662         if (!a->x) {
663             /* No point defined */
664             return 0;
665         } else if (!a->z) {
666             /* Already normalised */
667             return 1;
668         }
669
670         tmp = modinv(a->z, a->curve->p);
671         if (!tmp) {
672             return 0;
673         }
674         tmp2 = modmul(a->x, tmp, a->curve->p);
675         freebn(tmp);
676
677         freebn(a->z);
678         a->z = NULL;
679         freebn(a->x);
680         a->x = tmp2;
681         return 1;
682     } else if (a->curve->type == EC_EDWARDS) {
683         /* Always normalised */
684         return 1;
685     } else {
686         return 0;
687     }
688 }
689
690 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
691 {
692     Bignum S, M, outx, outy, outz;
693
694     if (bignum_cmp(a->y, Zero) == 0)
695     {
696         /* Identity */
697         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
698     }
699
700     /* S = 4*X*Y^2 */
701     {
702         Bignum Y2, XY2, _2XY2;
703
704         Y2 = ecf_square(a->y, a->curve);
705         XY2 = modmul(a->x, Y2, a->curve->p);
706         freebn(Y2);
707
708         _2XY2 = ecf_double(XY2, a->curve);
709         freebn(XY2);
710         S = ecf_double(_2XY2, a->curve);
711         freebn(_2XY2);
712     }
713
714     /* Faster calculation if a = -3 */
715     if (aminus3) {
716         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
717         Bignum Z2, XpZ2, XmZ2, second;
718
719         if (a->z == NULL) {
720             Z2 = copybn(One);
721         } else {
722             Z2 = ecf_square(a->z, a->curve);
723         }
724
725         XpZ2 = ecf_add(a->x, Z2, a->curve);
726         XmZ2 = modsub(a->x, Z2, a->curve->p);
727         freebn(Z2);
728
729         second = modmul(XpZ2, XmZ2, a->curve->p);
730         freebn(XpZ2);
731         freebn(XmZ2);
732
733         M = ecf_treble(second, a->curve);
734         freebn(second);
735     } else {
736         /* M = 3*X^2 + a*Z^4 */
737         Bignum _3X2, X2, aZ4;
738
739         if (a->z == NULL) {
740             aZ4 = copybn(a->curve->w.a);
741         } else {
742             Bignum Z2, Z4;
743
744             Z2 = ecf_square(a->z, a->curve);
745             Z4 = ecf_square(Z2, a->curve);
746             freebn(Z2);
747             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
748             freebn(Z4);
749         }
750
751         X2 = modmul(a->x, a->x, a->curve->p);
752         _3X2 = ecf_treble(X2, a->curve);
753         freebn(X2);
754         M = ecf_add(_3X2, aZ4, a->curve);
755         freebn(_3X2);
756         freebn(aZ4);
757     }
758
759     /* X' = M^2 - 2*S */
760     {
761         Bignum M2, _2S;
762
763         M2 = ecf_square(M, a->curve);
764         _2S = ecf_double(S, a->curve);
765         outx = modsub(M2, _2S, a->curve->p);
766         freebn(M2);
767         freebn(_2S);
768     }
769
770     /* Y' = M*(S - X') - 8*Y^4 */
771     {
772         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
773
774         SX = modsub(S, outx, a->curve->p);
775         freebn(S);
776         MSX = modmul(M, SX, a->curve->p);
777         freebn(SX);
778         freebn(M);
779         Y2 = ecf_square(a->y, a->curve);
780         Y4 = ecf_square(Y2, a->curve);
781         freebn(Y2);
782         Eight = bignum_from_long(8);
783         _8Y4 = modmul(Eight, Y4, a->curve->p);
784         freebn(Eight);
785         freebn(Y4);
786         outy = modsub(MSX, _8Y4, a->curve->p);
787         freebn(MSX);
788         freebn(_8Y4);
789     }
790
791     /* Z' = 2*Y*Z */
792     {
793         Bignum YZ;
794
795         if (a->z == NULL) {
796             YZ = copybn(a->y);
797         } else {
798             YZ = modmul(a->y, a->z, a->curve->p);
799         }
800
801         outz = ecf_double(YZ, a->curve);
802         freebn(YZ);
803     }
804
805     return ec_point_new(a->curve, outx, outy, outz, 0);
806 }
807
808 static struct ec_point *ecp_doublem(const struct ec_point *a)
809 {
810     Bignum z, outx, outz, xpz, xmz;
811
812     z = a->z;
813     if (!z) {
814         z = One;
815     }
816
817     /* 4xz = (x + z)^2 - (x - z)^2 */
818     {
819         Bignum tmp;
820
821         tmp = ecf_add(a->x, z, a->curve);
822         xpz = ecf_square(tmp, a->curve);
823         freebn(tmp);
824
825         tmp = modsub(a->x, z, a->curve->p);
826         xmz = ecf_square(tmp, a->curve);
827         freebn(tmp);
828     }
829
830     /* outx = (x + z)^2 * (x - z)^2 */
831     outx = modmul(xpz, xmz, a->curve->p);
832
833     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
834     {
835         Bignum _4xz, tmp, tmp2, tmp3;
836
837         tmp = bignum_from_long(2);
838         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
839         freebn(tmp);
840
841         _4xz = modsub(xpz, xmz, a->curve->p);
842         freebn(xpz);
843         tmp = modmul(tmp2, _4xz, a->curve->p);
844         freebn(tmp2);
845
846         tmp2 = bignum_from_long(4);
847         tmp3 = modinv(tmp2, a->curve->p);
848         freebn(tmp2);
849         if (!tmp3) {
850             freebn(tmp);
851             freebn(_4xz);
852             freebn(outx);
853             freebn(xmz);
854             return NULL;
855         }
856         tmp2 = modmul(tmp, tmp3, a->curve->p);
857         freebn(tmp);
858         freebn(tmp3);
859
860         tmp = ecf_add(xmz, tmp2, a->curve);
861         freebn(xmz);
862         freebn(tmp2);
863         outz = modmul(_4xz, tmp, a->curve->p);
864         freebn(_4xz);
865         freebn(tmp);
866     }
867
868     return ec_point_new(a->curve, outx, NULL, outz, 0);
869 }
870
871 /* Forward declaration for Edwards curve doubling */
872 static struct ec_point *ecp_add(const struct ec_point *a,
873                                 const struct ec_point *b,
874                                 const int aminus3);
875
876 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
877 {
878     if (a->infinity)
879     {
880         /* Identity */
881         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
882     }
883
884     if (a->curve->type == EC_EDWARDS)
885     {
886         return ecp_add(a, a, aminus3);
887     }
888     else if (a->curve->type == EC_WEIERSTRASS)
889     {
890         return ecp_doublew(a, aminus3);
891     }
892     else
893     {
894         return ecp_doublem(a);
895     }
896 }
897
898 static struct ec_point *ecp_addw(const struct ec_point *a,
899                                  const struct ec_point *b,
900                                  const int aminus3)
901 {
902     Bignum U1, U2, S1, S2, outx, outy, outz;
903
904     /* U1 = X1*Z2^2 */
905     /* S1 = Y1*Z2^3 */
906     if (b->z) {
907         Bignum Z2, Z3;
908
909         Z2 = ecf_square(b->z, a->curve);
910         U1 = modmul(a->x, Z2, a->curve->p);
911         Z3 = modmul(Z2, b->z, a->curve->p);
912         freebn(Z2);
913         S1 = modmul(a->y, Z3, a->curve->p);
914         freebn(Z3);
915     } else {
916         U1 = copybn(a->x);
917         S1 = copybn(a->y);
918     }
919
920     /* U2 = X2*Z1^2 */
921     /* S2 = Y2*Z1^3 */
922     if (a->z) {
923         Bignum Z2, Z3;
924
925         Z2 = ecf_square(a->z, b->curve);
926         U2 = modmul(b->x, Z2, b->curve->p);
927         Z3 = modmul(Z2, a->z, b->curve->p);
928         freebn(Z2);
929         S2 = modmul(b->y, Z3, b->curve->p);
930         freebn(Z3);
931     } else {
932         U2 = copybn(b->x);
933         S2 = copybn(b->y);
934     }
935
936     /* Check if multiplying by self */
937     if (bignum_cmp(U1, U2) == 0)
938     {
939         freebn(U1);
940         freebn(U2);
941         if (bignum_cmp(S1, S2) == 0)
942         {
943             freebn(S1);
944             freebn(S2);
945             return ecp_double(a, aminus3);
946         }
947         else
948         {
949             freebn(S1);
950             freebn(S2);
951             /* Infinity */
952             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
953         }
954     }
955
956     {
957         Bignum H, R, UH2, H3;
958
959         /* H = U2 - U1 */
960         H = modsub(U2, U1, a->curve->p);
961         freebn(U2);
962
963         /* R = S2 - S1 */
964         R = modsub(S2, S1, a->curve->p);
965         freebn(S2);
966
967         /* X3 = R^2 - H^3 - 2*U1*H^2 */
968         {
969             Bignum R2, H2, _2UH2, first;
970
971             H2 = ecf_square(H, a->curve);
972             UH2 = modmul(U1, H2, a->curve->p);
973             freebn(U1);
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             R2 = ecf_square(R, a->curve);
977             _2UH2 = ecf_double(UH2, a->curve);
978             first = modsub(R2, H3, a->curve->p);
979             freebn(R2);
980             outx = modsub(first, _2UH2, a->curve->p);
981             freebn(first);
982             freebn(_2UH2);
983         }
984
985         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
986         {
987             Bignum RUH2mX, UH2mX, SH3;
988
989             UH2mX = modsub(UH2, outx, a->curve->p);
990             freebn(UH2);
991             RUH2mX = modmul(R, UH2mX, a->curve->p);
992             freebn(UH2mX);
993             freebn(R);
994             SH3 = modmul(S1, H3, a->curve->p);
995             freebn(S1);
996             freebn(H3);
997
998             outy = modsub(RUH2mX, SH3, a->curve->p);
999             freebn(RUH2mX);
1000             freebn(SH3);
1001         }
1002
1003         /* Z3 = H*Z1*Z2 */
1004         if (a->z && b->z) {
1005             Bignum ZZ;
1006
1007             ZZ = modmul(a->z, b->z, a->curve->p);
1008             outz = modmul(H, ZZ, a->curve->p);
1009             freebn(H);
1010             freebn(ZZ);
1011         } else if (a->z) {
1012             outz = modmul(H, a->z, a->curve->p);
1013             freebn(H);
1014         } else if (b->z) {
1015             outz = modmul(H, b->z, a->curve->p);
1016             freebn(H);
1017         } else {
1018             outz = H;
1019         }
1020     }
1021
1022     return ec_point_new(a->curve, outx, outy, outz, 0);
1023 }
1024
1025 static struct ec_point *ecp_addm(const struct ec_point *a,
1026                                  const struct ec_point *b,
1027                                  const struct ec_point *base)
1028 {
1029     Bignum outx, outz, az, bz;
1030
1031     az = a->z;
1032     if (!az) {
1033         az = One;
1034     }
1035     bz = b->z;
1036     if (!bz) {
1037         bz = One;
1038     }
1039
1040     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1041     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1042     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1043     {
1044         Bignum tmp, tmp2, tmp3, tmp4;
1045
1046         /* (Xa + Za) * (Xb - Zb) */
1047         tmp = ecf_add(a->x, az, a->curve);
1048         tmp2 = modsub(b->x, bz, a->curve->p);
1049         tmp3 = modmul(tmp, tmp2, a->curve->p);
1050         freebn(tmp);
1051         freebn(tmp2);
1052
1053         /* (Xa - Za) * (Xb + Zb) */
1054         tmp = modsub(a->x, az, a->curve->p);
1055         tmp2 = ecf_add(b->x, bz, a->curve);
1056         tmp4 = modmul(tmp, tmp2, a->curve->p);
1057         freebn(tmp);
1058         freebn(tmp2);
1059
1060         tmp = ecf_add(tmp3, tmp4, a->curve);
1061         outx = ecf_square(tmp, a->curve);
1062         freebn(tmp);
1063
1064         tmp = modsub(tmp3, tmp4, a->curve->p);
1065         freebn(tmp3);
1066         freebn(tmp4);
1067         tmp2 = ecf_square(tmp, a->curve);
1068         freebn(tmp);
1069         outz = modmul(base->x, tmp2, a->curve->p);
1070         freebn(tmp2);
1071     }
1072
1073     return ec_point_new(a->curve, outx, NULL, outz, 0);
1074 }
1075
1076 static struct ec_point *ecp_adde(const struct ec_point *a,
1077                                  const struct ec_point *b)
1078 {
1079     Bignum outx, outy, dmul;
1080
1081     /* outx = (a->x * b->y + b->x * a->y) /
1082      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1083     {
1084         Bignum tmp, tmp2, tmp3, tmp4;
1085
1086         tmp = modmul(a->x, b->y, a->curve->p);
1087         tmp2 = modmul(b->x, a->y, a->curve->p);
1088         tmp3 = ecf_add(tmp, tmp2, a->curve);
1089
1090         tmp4 = modmul(tmp, tmp2, a->curve->p);
1091         freebn(tmp);
1092         freebn(tmp2);
1093         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1094         freebn(tmp4);
1095
1096         tmp = ecf_add(One, dmul, a->curve);
1097         tmp2 = modinv(tmp, a->curve->p);
1098         freebn(tmp);
1099         if (!tmp2)
1100         {
1101             freebn(tmp3);
1102             freebn(dmul);
1103             return NULL;
1104         }
1105
1106         outx = modmul(tmp3, tmp2, a->curve->p);
1107         freebn(tmp3);
1108         freebn(tmp2);
1109     }
1110
1111     /* outy = (a->y * b->y + a->x * b->x) /
1112      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1113     {
1114         Bignum tmp, tmp2, tmp3, tmp4;
1115
1116         tmp = modsub(One, dmul, a->curve->p);
1117         freebn(dmul);
1118
1119         tmp2 = modinv(tmp, a->curve->p);
1120         freebn(tmp);
1121         if (!tmp2)
1122         {
1123             freebn(outx);
1124             return NULL;
1125         }
1126
1127         tmp = modmul(a->y, b->y, a->curve->p);
1128         tmp3 = modmul(a->x, b->x, a->curve->p);
1129         tmp4 = ecf_add(tmp, tmp3, a->curve);
1130         freebn(tmp);
1131         freebn(tmp3);
1132
1133         outy = modmul(tmp4, tmp2, a->curve->p);
1134         freebn(tmp4);
1135         freebn(tmp2);
1136     }
1137
1138     return ec_point_new(a->curve, outx, outy, NULL, 0);
1139 }
1140
1141 static struct ec_point *ecp_add(const struct ec_point *a,
1142                                 const struct ec_point *b,
1143                                 const int aminus3)
1144 {
1145     if (a->curve != b->curve) {
1146         return NULL;
1147     }
1148
1149     /* Check if multiplying by infinity */
1150     if (a->infinity) return ec_point_copy(b);
1151     if (b->infinity) return ec_point_copy(a);
1152
1153     if (a->curve->type == EC_EDWARDS)
1154     {
1155         return ecp_adde(a, b);
1156     }
1157
1158     if (a->curve->type == EC_WEIERSTRASS)
1159     {
1160         return ecp_addw(a, b, aminus3);
1161     }
1162
1163     return NULL;
1164 }
1165
1166 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1167 {
1168     struct ec_point *A, *ret;
1169     int bits, i;
1170
1171     A = ec_point_copy(a);
1172     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1173
1174     bits = bignum_bitcount(b);
1175     for (i = 0; i < bits; ++i)
1176     {
1177         if (bignum_bit(b, i))
1178         {
1179             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1180             ec_point_free(ret);
1181             ret = tmp;
1182         }
1183         if (i+1 != bits)
1184         {
1185             struct ec_point *tmp = ecp_double(A, aminus3);
1186             ec_point_free(A);
1187             A = tmp;
1188         }
1189     }
1190
1191     ec_point_free(A);
1192     return ret;
1193 }
1194
1195 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1196 {
1197     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1198
1199     if (!ecp_normalise(ret)) {
1200         ec_point_free(ret);
1201         return NULL;
1202     }
1203
1204     return ret;
1205 }
1206
1207 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1208 {
1209     int i;
1210     struct ec_point *ret;
1211
1212     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1213
1214     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1215     {
1216         {
1217             struct ec_point *tmp = ecp_double(ret, 0);
1218             ec_point_free(ret);
1219             ret = tmp;
1220         }
1221         if (ret && bignum_bit(b, i))
1222         {
1223             struct ec_point *tmp = ecp_add(ret, a, 0);
1224             ec_point_free(ret);
1225             ret = tmp;
1226         }
1227     }
1228
1229     return ret;
1230 }
1231
1232 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1233 {
1234     struct ec_point *P1, *P2;
1235     int bits, i;
1236
1237     /* P1 <- P and P2 <- [2]P */
1238     P2 = ecp_double(p, 0);
1239     P1 = ec_point_copy(p);
1240
1241     /* for i = bits âˆ’ 2 down to 0 */
1242     bits = bignum_bitcount(n);
1243     for (i = bits - 2; i >= 0; --i)
1244     {
1245         if (!bignum_bit(n, i))
1246         {
1247             /* P2 <- P1 + P2 */
1248             struct ec_point *tmp = ecp_addm(P1, P2, p);
1249             ec_point_free(P2);
1250             P2 = tmp;
1251
1252             /* P1 <- [2]P1 */
1253             tmp = ecp_double(P1, 0);
1254             ec_point_free(P1);
1255             P1 = tmp;
1256         }
1257         else
1258         {
1259             /* P1 <- P1 + P2 */
1260             struct ec_point *tmp = ecp_addm(P1, P2, p);
1261             ec_point_free(P1);
1262             P1 = tmp;
1263
1264             /* P2 <- [2]P2 */
1265             tmp = ecp_double(P2, 0);
1266             ec_point_free(P2);
1267             P2 = tmp;
1268         }
1269     }
1270
1271     ec_point_free(P2);
1272
1273     if (!ecp_normalise(P1)) {
1274         ec_point_free(P1);
1275         return NULL;
1276     }
1277
1278     return P1;
1279 }
1280
1281 /* Not static because it is used by sshecdsag.c to generate a new key */
1282 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1283 {
1284     if (a->curve->type == EC_WEIERSTRASS) {
1285         return ecp_mulw(a, b);
1286     } else if (a->curve->type == EC_EDWARDS) {
1287         return ecp_mule(a, b);
1288     } else {
1289         return ecp_mulm(a, b);
1290     }
1291 }
1292
1293 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1294                                    const struct ec_point *point)
1295 {
1296     struct ec_point *aG, *bP, *ret;
1297     int aminus3;
1298
1299     if (point->curve->type != EC_WEIERSTRASS) {
1300         return NULL;
1301     }
1302
1303     aminus3 = ec_aminus3(point->curve);
1304
1305     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1306     if (!aG) return NULL;
1307     bP = ecp_mul_(point, b, aminus3);
1308     if (!bP) {
1309         ec_point_free(aG);
1310         return NULL;
1311     }
1312
1313     ret = ecp_add(aG, bP, aminus3);
1314
1315     ec_point_free(aG);
1316     ec_point_free(bP);
1317
1318     if (!ecp_normalise(ret)) {
1319         ec_point_free(ret);
1320         return NULL;
1321     }
1322
1323     return ret;
1324 }
1325 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1326 {
1327     /* Get the x value on the given Edwards curve for a given y */
1328     Bignum x, xx;
1329
1330     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1331     {
1332         Bignum tmp, tmp2, tmp3;
1333
1334         tmp = ecf_square(y, curve);
1335         tmp2 = modmul(curve->e.d, tmp, curve->p);
1336         tmp3 = ecf_add(tmp2, One, curve);
1337         freebn(tmp2);
1338         tmp2 = modinv(tmp3, curve->p);
1339         freebn(tmp3);
1340         if (!tmp2) {
1341             freebn(tmp);
1342             return NULL;
1343         }
1344
1345         tmp3 = modsub(tmp, One, curve->p);
1346         freebn(tmp);
1347         xx = modmul(tmp3, tmp2, curve->p);
1348         freebn(tmp3);
1349         freebn(tmp2);
1350     }
1351
1352     /* x = xx^((p + 3) / 8) */
1353     {
1354         Bignum tmp, tmp2;
1355
1356         tmp = bignum_add_long(curve->p, 3);
1357         tmp2 = bignum_rshift(tmp, 3);
1358         freebn(tmp);
1359         x = modpow(xx, tmp2, curve->p);
1360         freebn(tmp2);
1361     }
1362
1363     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1364     {
1365         Bignum tmp, tmp2;
1366
1367         tmp = ecf_square(x, curve);
1368         tmp2 = modsub(tmp, xx, curve->p);
1369         freebn(tmp);
1370         freebn(xx);
1371         if (bignum_cmp(tmp2, Zero)) {
1372             Bignum tmp3;
1373
1374             freebn(tmp2);
1375
1376             tmp = modsub(curve->p, One, curve->p);
1377             tmp2 = bignum_rshift(tmp, 2);
1378             freebn(tmp);
1379             tmp = bignum_from_long(2);
1380             tmp3 = modpow(tmp, tmp2, curve->p);
1381             freebn(tmp);
1382             freebn(tmp2);
1383
1384             tmp = modmul(x, tmp3, curve->p);
1385             freebn(x);
1386             freebn(tmp3);
1387             x = tmp;
1388         } else {
1389             freebn(tmp2);
1390         }
1391     }
1392
1393     /* if x % 2 != 0 then x = p - x */
1394     if (bignum_bit(x, 0)) {
1395         Bignum tmp = modsub(curve->p, x, curve->p);
1396         freebn(x);
1397         x = tmp;
1398     }
1399
1400     return x;
1401 }
1402
1403 /* ----------------------------------------------------------------------
1404  * Public point from private
1405  */
1406
1407 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1408 {
1409     if (curve->type == EC_WEIERSTRASS) {
1410         return ecp_mul(&curve->w.G, privateKey);
1411     } else if (curve->type == EC_EDWARDS) {
1412         /* hash = H(sk) (where hash creates 2 * fieldBits)
1413          * b = fieldBits
1414          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1415          * publicKey = aB */
1416         struct ec_point *ret;
1417         unsigned char hash[512/8];
1418         Bignum a;
1419         int i, keylen;
1420         SHA512_State s;
1421         SHA512_Init(&s);
1422
1423         keylen = curve->fieldBits / 8;
1424         for (i = 0; i < keylen; ++i) {
1425             unsigned char b = bignum_byte(privateKey, i);
1426             SHA512_Bytes(&s, &b, 1);
1427         }
1428         SHA512_Final(&s, hash);
1429
1430         /* The second part is simply turning the hash into a Bignum,
1431          * however the 2^(b-2) bit *must* be set, and the bottom 3
1432          * bits *must* not be */
1433         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1434         hash[31] &= 0x7f; /* Unset above (b-2) */
1435         hash[31] |= 0x40; /* Set 2^(b-2) */
1436         /* Chop off the top part and convert to int */
1437         a = bignum_from_bytes_le(hash, 32);
1438
1439         ret = ecp_mul(&curve->e.B, a);
1440         freebn(a);
1441         return ret;
1442     } else {
1443         return NULL;
1444     }
1445 }
1446
1447 /* ----------------------------------------------------------------------
1448  * Basic sign and verify routines
1449  */
1450
1451 static int _ecdsa_verify(const struct ec_point *publicKey,
1452                          const unsigned char *data, const int dataLen,
1453                          const Bignum r, const Bignum s)
1454 {
1455     int z_bits, n_bits;
1456     Bignum z;
1457     int valid = 0;
1458
1459     if (publicKey->curve->type != EC_WEIERSTRASS) {
1460         return 0;
1461     }
1462
1463     /* Sanity checks */
1464     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1465         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1466     {
1467         return 0;
1468     }
1469
1470     /* z = left most bitlen(curve->n) of data */
1471     z = bignum_from_bytes(data, dataLen);
1472     n_bits = bignum_bitcount(publicKey->curve->w.n);
1473     z_bits = bignum_bitcount(z);
1474     if (z_bits > n_bits)
1475     {
1476         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1477         freebn(z);
1478         z = tmp;
1479     }
1480
1481     /* Ensure z in range of n */
1482     {
1483         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1484         freebn(z);
1485         z = tmp;
1486     }
1487
1488     /* Calculate signature */
1489     {
1490         Bignum w, x, u1, u2;
1491         struct ec_point *tmp;
1492
1493         w = modinv(s, publicKey->curve->w.n);
1494         if (!w) {
1495             freebn(z);
1496             return 0;
1497         }
1498         u1 = modmul(z, w, publicKey->curve->w.n);
1499         u2 = modmul(r, w, publicKey->curve->w.n);
1500         freebn(w);
1501
1502         tmp = ecp_summul(u1, u2, publicKey);
1503         freebn(u1);
1504         freebn(u2);
1505         if (!tmp) {
1506             freebn(z);
1507             return 0;
1508         }
1509
1510         x = bigmod(tmp->x, publicKey->curve->w.n);
1511         ec_point_free(tmp);
1512
1513         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1514         freebn(x);
1515     }
1516
1517     freebn(z);
1518
1519     return valid;
1520 }
1521
1522 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1523                         const unsigned char *data, const int dataLen,
1524                         Bignum *r, Bignum *s)
1525 {
1526     unsigned char digest[20];
1527     int z_bits, n_bits;
1528     Bignum z, k;
1529     struct ec_point *kG;
1530
1531     *r = NULL;
1532     *s = NULL;
1533
1534     if (curve->type != EC_WEIERSTRASS) {
1535         return;
1536     }
1537
1538     /* z = left most bitlen(curve->n) of data */
1539     z = bignum_from_bytes(data, dataLen);
1540     n_bits = bignum_bitcount(curve->w.n);
1541     z_bits = bignum_bitcount(z);
1542     if (z_bits > n_bits)
1543     {
1544         Bignum tmp;
1545         tmp = bignum_rshift(z, z_bits - n_bits);
1546         freebn(z);
1547         z = tmp;
1548     }
1549
1550     /* Generate k between 1 and curve->n, using the same deterministic
1551      * k generation system we use for conventional DSA. */
1552     SHA_Simple(data, dataLen, digest);
1553     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1554                   digest, sizeof(digest));
1555
1556     kG = ecp_mul(&curve->w.G, k);
1557     if (!kG) {
1558         freebn(z);
1559         freebn(k);
1560         return;
1561     }
1562
1563     /* r = kG.x mod n */
1564     *r = bigmod(kG->x, curve->w.n);
1565     ec_point_free(kG);
1566
1567     /* s = (z + r * priv)/k mod n */
1568     {
1569         Bignum rPriv, zMod, first, firstMod, kInv;
1570         rPriv = modmul(*r, privateKey, curve->w.n);
1571         zMod = bigmod(z, curve->w.n);
1572         freebn(z);
1573         first = bigadd(rPriv, zMod);
1574         freebn(rPriv);
1575         freebn(zMod);
1576         firstMod = bigmod(first, curve->w.n);
1577         freebn(first);
1578         kInv = modinv(k, curve->w.n);
1579         freebn(k);
1580         if (!kInv) {
1581             freebn(firstMod);
1582             freebn(*r);
1583             return;
1584         }
1585         *s = modmul(firstMod, kInv, curve->w.n);
1586         freebn(firstMod);
1587         freebn(kInv);
1588     }
1589 }
1590
1591 /* ----------------------------------------------------------------------
1592  * Misc functions
1593  */
1594
1595 static void getstring(const char **data, int *datalen,
1596                       const char **p, int *length)
1597 {
1598     *p = NULL;
1599     if (*datalen < 4)
1600         return;
1601     *length = toint(GET_32BIT(*data));
1602     if (*length < 0)
1603         return;
1604     *datalen -= 4;
1605     *data += 4;
1606     if (*datalen < *length)
1607         return;
1608     *p = *data;
1609     *data += *length;
1610     *datalen -= *length;
1611 }
1612
1613 static Bignum getmp(const char **data, int *datalen)
1614 {
1615     const char *p;
1616     int length;
1617
1618     getstring(data, datalen, &p, &length);
1619     if (!p)
1620         return NULL;
1621     if (p[0] & 0x80)
1622         return NULL;                   /* negative mp */
1623     return bignum_from_bytes((unsigned char *)p, length);
1624 }
1625
1626 static Bignum getmp_le(const char **data, int *datalen)
1627 {
1628     const char *p;
1629     int length;
1630
1631     getstring(data, datalen, &p, &length);
1632     if (!p)
1633         return NULL;
1634     return bignum_from_bytes_le((const unsigned char *)p, length);
1635 }
1636
1637 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1638 {
1639     /* Got some conversion to do, first read in the y co-ord */
1640     int negative;
1641
1642     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1643     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1644         freebn(point->y);
1645         point->y = NULL;
1646         return 0;
1647     }
1648     /* Read x bit and then reset it */
1649     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1650     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1651     bn_restore_invariant(point->y);
1652
1653     /* Get the x from the y */
1654     point->x = ecp_edx(point->curve, point->y);
1655     if (!point->x) {
1656         freebn(point->y);
1657         point->y = NULL;
1658         return 0;
1659     }
1660     if (negative) {
1661         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1662         freebn(point->x);
1663         point->x = tmp;
1664     }
1665
1666     /* Verify the point is on the curve */
1667     if (!ec_point_verify(point)) {
1668         freebn(point->x);
1669         point->x = NULL;
1670         freebn(point->y);
1671         point->y = NULL;
1672         return 0;
1673     }
1674
1675     return 1;
1676 }
1677
1678 static int decodepoint(const char *p, int length, struct ec_point *point)
1679 {
1680     if (point->curve->type == EC_EDWARDS) {
1681         return decodepoint_ed(p, length, point);
1682     }
1683
1684     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1685         return 0;
1686     /* Skip compression flag */
1687     ++p;
1688     --length;
1689     /* The two values must be equal length */
1690     if (length % 2 != 0) {
1691         point->x = NULL;
1692         point->y = NULL;
1693         point->z = NULL;
1694         return 0;
1695     }
1696     length = length / 2;
1697     point->x = bignum_from_bytes((const unsigned char *)p, length);
1698     p += length;
1699     point->y = bignum_from_bytes((const unsigned char *)p, length);
1700     point->z = NULL;
1701
1702     /* Verify the point is on the curve */
1703     if (!ec_point_verify(point)) {
1704         freebn(point->x);
1705         point->x = NULL;
1706         freebn(point->y);
1707         point->y = NULL;
1708         return 0;
1709     }
1710
1711     return 1;
1712 }
1713
1714 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1715 {
1716     const char *p;
1717     int length;
1718
1719     getstring(data, datalen, &p, &length);
1720     if (!p) return 0;
1721     return decodepoint(p, length, point);
1722 }
1723
1724 /* ----------------------------------------------------------------------
1725  * Exposed ECDSA interface
1726  */
1727
1728 struct ecsign_extra {
1729     struct ec_curve *(*curve)(void);
1730     const struct ssh_hash *hash;
1731
1732     /* These fields are used by the OpenSSH PEM format importer/exporter */
1733     const unsigned char *oid;
1734     int oidlen;
1735 };
1736
1737 static void ecdsa_freekey(void *key)
1738 {
1739     struct ec_key *ec = (struct ec_key *) key;
1740     if (!ec) return;
1741
1742     if (ec->publicKey.x)
1743         freebn(ec->publicKey.x);
1744     if (ec->publicKey.y)
1745         freebn(ec->publicKey.y);
1746     if (ec->publicKey.z)
1747         freebn(ec->publicKey.z);
1748     if (ec->privateKey)
1749         freebn(ec->privateKey);
1750     sfree(ec);
1751 }
1752
1753 static void *ecdsa_newkey(const struct ssh_signkey *self,
1754                           const char *data, int len)
1755 {
1756     const struct ecsign_extra *extra =
1757         (const struct ecsign_extra *)self->extra;
1758     const char *p;
1759     int slen;
1760     struct ec_key *ec;
1761     struct ec_curve *curve;
1762
1763     getstring(&data, &len, &p, &slen);
1764
1765     if (!p) {
1766         return NULL;
1767     }
1768     curve = extra->curve();
1769     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1770
1771     /* Curve name is duplicated for Weierstrass form */
1772     if (curve->type == EC_WEIERSTRASS) {
1773         getstring(&data, &len, &p, &slen);
1774         if (!p) return NULL;
1775         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1776     }
1777
1778     ec = snew(struct ec_key);
1779
1780     ec->signalg = self;
1781     ec->publicKey.curve = curve;
1782     ec->publicKey.infinity = 0;
1783     ec->publicKey.x = NULL;
1784     ec->publicKey.y = NULL;
1785     ec->publicKey.z = NULL;
1786     ec->privateKey = NULL;
1787     if (!getmppoint(&data, &len, &ec->publicKey)) {
1788         ecdsa_freekey(ec);
1789         return NULL;
1790     }
1791
1792     if (!ec->publicKey.x || !ec->publicKey.y ||
1793         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1794         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1795     {
1796         ecdsa_freekey(ec);
1797         ec = NULL;
1798     }
1799
1800     return ec;
1801 }
1802
1803 static char *ecdsa_fmtkey(void *key)
1804 {
1805     struct ec_key *ec = (struct ec_key *) key;
1806     char *p;
1807     int len, i, pos, nibbles;
1808     static const char hex[] = "0123456789abcdef";
1809     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1810         return NULL;
1811
1812     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1813     if (ec->publicKey.curve->name)
1814         len += strlen(ec->publicKey.curve->name); /* Curve name */
1815     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1816     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1817     p = snewn(len, char);
1818
1819     pos = 0;
1820     if (ec->publicKey.curve->name)
1821         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1822     pos += sprintf(p + pos, "0x");
1823     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1824     if (nibbles < 1)
1825         nibbles = 1;
1826     for (i = nibbles; i--;) {
1827         p[pos++] =
1828             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1829     }
1830     pos += sprintf(p + pos, ",0x");
1831     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1832     if (nibbles < 1)
1833         nibbles = 1;
1834     for (i = nibbles; i--;) {
1835         p[pos++] =
1836             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1837     }
1838     p[pos] = '\0';
1839     return p;
1840 }
1841
1842 static unsigned char *ecdsa_public_blob(void *key, int *len)
1843 {
1844     struct ec_key *ec = (struct ec_key *) key;
1845     int pointlen, bloblen, fullnamelen, namelen;
1846     int i;
1847     unsigned char *blob, *p;
1848
1849     fullnamelen = strlen(ec->signalg->name);
1850
1851     if (ec->publicKey.curve->type == EC_EDWARDS) {
1852         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1853
1854         pointlen = ec->publicKey.curve->fieldBits / 8;
1855
1856         /* Can't handle this in our loop */
1857         if (pointlen < 2) return NULL;
1858
1859         bloblen = 4 + fullnamelen + 4 + pointlen;
1860         blob = snewn(bloblen, unsigned char);
1861
1862         p = blob;
1863         PUT_32BIT(p, fullnamelen);
1864         p += 4;
1865         memcpy(p, ec->signalg->name, fullnamelen);
1866         p += fullnamelen;
1867         PUT_32BIT(p, pointlen);
1868         p += 4;
1869
1870         /* Unset last bit of y and set first bit of x in its place */
1871         for (i = 0; i < pointlen - 1; ++i) {
1872             *p++ = bignum_byte(ec->publicKey.y, i);
1873         }
1874         /* Unset last bit of y and set first bit of x in its place */
1875         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1876         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1877     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1878         assert(ec->publicKey.curve->name);
1879         namelen = strlen(ec->publicKey.curve->name);
1880
1881         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1882
1883         /*
1884          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1885          */
1886         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1887         blob = snewn(bloblen, unsigned char);
1888
1889         p = blob;
1890         PUT_32BIT(p, fullnamelen);
1891         p += 4;
1892         memcpy(p, ec->signalg->name, fullnamelen);
1893         p += fullnamelen;
1894         PUT_32BIT(p, namelen);
1895         p += 4;
1896         memcpy(p, ec->publicKey.curve->name, namelen);
1897         p += namelen;
1898         PUT_32BIT(p, (2 * pointlen) + 1);
1899         p += 4;
1900         *p++ = 0x04;
1901         for (i = pointlen; i--;) {
1902             *p++ = bignum_byte(ec->publicKey.x, i);
1903         }
1904         for (i = pointlen; i--;) {
1905             *p++ = bignum_byte(ec->publicKey.y, i);
1906         }
1907     } else {
1908         return NULL;
1909     }
1910
1911     assert(p == blob + bloblen);
1912     *len = bloblen;
1913
1914     return blob;
1915 }
1916
1917 static unsigned char *ecdsa_private_blob(void *key, int *len)
1918 {
1919     struct ec_key *ec = (struct ec_key *) key;
1920     int keylen, bloblen;
1921     int i;
1922     unsigned char *blob, *p;
1923
1924     if (!ec->privateKey) return NULL;
1925
1926     if (ec->publicKey.curve->type == EC_EDWARDS) {
1927         /* Unsigned */
1928         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1929     } else {
1930         /* Signed */
1931         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1932     }
1933
1934     /*
1935      * mpint privateKey. Total 4 + keylen.
1936      */
1937     bloblen = 4 + keylen;
1938     blob = snewn(bloblen, unsigned char);
1939
1940     p = blob;
1941     PUT_32BIT(p, keylen);
1942     p += 4;
1943     if (ec->publicKey.curve->type == EC_EDWARDS) {
1944         /* Little endian */
1945         for (i = 0; i < keylen; ++i)
1946             *p++ = bignum_byte(ec->privateKey, i);
1947     } else {
1948         for (i = keylen; i--;)
1949             *p++ = bignum_byte(ec->privateKey, i);
1950     }
1951
1952     assert(p == blob + bloblen);
1953     *len = bloblen;
1954     return blob;
1955 }
1956
1957 static void *ecdsa_createkey(const struct ssh_signkey *self,
1958                              const unsigned char *pub_blob, int pub_len,
1959                              const unsigned char *priv_blob, int priv_len)
1960 {
1961     struct ec_key *ec;
1962     struct ec_point *publicKey;
1963     const char *pb = (const char *) priv_blob;
1964
1965     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1966     if (!ec) {
1967         return NULL;
1968     }
1969
1970     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1971         && ec->publicKey.curve->type != EC_EDWARDS) {
1972         ecdsa_freekey(ec);
1973         return NULL;
1974     }
1975
1976     if (ec->publicKey.curve->type == EC_EDWARDS) {
1977         ec->privateKey = getmp_le(&pb, &priv_len);
1978     } else {
1979         ec->privateKey = getmp(&pb, &priv_len);
1980     }
1981     if (!ec->privateKey) {
1982         ecdsa_freekey(ec);
1983         return NULL;
1984     }
1985
1986     /* Check that private key generates public key */
1987     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1988
1989     if (!publicKey ||
1990         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1991         bignum_cmp(publicKey->y, ec->publicKey.y))
1992     {
1993         ecdsa_freekey(ec);
1994         ec = NULL;
1995     }
1996     ec_point_free(publicKey);
1997
1998     return ec;
1999 }
2000
2001 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2002                                        const unsigned char **blob, int *len)
2003 {
2004     struct ec_key *ec;
2005     struct ec_point *publicKey;
2006     const char *p, *q;
2007     int plen, qlen;
2008
2009     getstring((const char**)blob, len, &p, &plen);
2010     if (!p)
2011     {
2012         return NULL;
2013     }
2014
2015     ec = snew(struct ec_key);
2016
2017     ec->signalg = self;
2018     ec->publicKey.curve = ec_ed25519();
2019     ec->publicKey.infinity = 0;
2020     ec->privateKey = NULL;
2021     ec->publicKey.x = NULL;
2022     ec->publicKey.z = NULL;
2023     ec->publicKey.y = NULL;
2024
2025     if (!decodepoint_ed(p, plen, &ec->publicKey))
2026     {
2027         ecdsa_freekey(ec);
2028         return NULL;
2029     }
2030
2031     getstring((const char**)blob, len, &q, &qlen);
2032     if (!q)
2033         return NULL;
2034     if (qlen != 64)
2035         return NULL;
2036
2037     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2038
2039     /* Check that private key generates public key */
2040     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2041
2042     if (!publicKey ||
2043         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2044         bignum_cmp(publicKey->y, ec->publicKey.y))
2045     {
2046         ecdsa_freekey(ec);
2047         ec = NULL;
2048     }
2049     ec_point_free(publicKey);
2050
2051     /* The OpenSSH format for ed25519 private keys also for some
2052      * reason encodes an extra copy of the public key in the second
2053      * half of the secret-key string. Check that that's present and
2054      * correct as well, otherwise the key we think we've imported
2055      * won't behave identically to the way OpenSSH would have treated
2056      * it. */
2057     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2058         ecdsa_freekey(ec);
2059         return NULL;
2060     }
2061
2062     return ec;
2063 }
2064
2065 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2066 {
2067     struct ec_key *ec = (struct ec_key *) key;
2068
2069     int pointlen;
2070     int keylen;
2071     int bloblen;
2072     int i;
2073
2074     if (ec->publicKey.curve->type != EC_EDWARDS) {
2075         return 0;
2076     }
2077
2078     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2079     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2080     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2081
2082     if (bloblen > len)
2083         return bloblen;
2084
2085     /* Encode the public point */
2086     PUT_32BIT(blob, pointlen);
2087     blob += 4;
2088
2089     for (i = 0; i < pointlen - 1; ++i) {
2090          *blob++ = bignum_byte(ec->publicKey.y, i);
2091     }
2092     /* Unset last bit of y and set first bit of x in its place */
2093     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2094     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2095
2096     PUT_32BIT(blob, keylen + pointlen);
2097     blob += 4;
2098     for (i = 0; i < keylen; ++i) {
2099          *blob++ = bignum_byte(ec->privateKey, i);
2100     }
2101     /* Now encode an extra copy of the public point as the second half
2102      * of the private key string, as the OpenSSH format for some
2103      * reason requires */
2104     for (i = 0; i < pointlen - 1; ++i) {
2105          *blob++ = bignum_byte(ec->publicKey.y, i);
2106     }
2107     /* Unset last bit of y and set first bit of x in its place */
2108     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2109     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2110
2111     return bloblen;
2112 }
2113
2114 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2115                                      const unsigned char **blob, int *len)
2116 {
2117     const struct ecsign_extra *extra =
2118         (const struct ecsign_extra *)self->extra;
2119     const char **b = (const char **) blob;
2120     const char *p;
2121     int slen;
2122     struct ec_key *ec;
2123     struct ec_curve *curve;
2124     struct ec_point *publicKey;
2125
2126     getstring(b, len, &p, &slen);
2127
2128     if (!p) {
2129         return NULL;
2130     }
2131     curve = extra->curve();
2132     assert(curve->type == EC_WEIERSTRASS);
2133
2134     ec = snew(struct ec_key);
2135
2136     ec->signalg = self;
2137     ec->publicKey.curve = curve;
2138     ec->publicKey.infinity = 0;
2139     ec->publicKey.x = NULL;
2140     ec->publicKey.y = NULL;
2141     ec->publicKey.z = NULL;
2142     if (!getmppoint(b, len, &ec->publicKey)) {
2143         ecdsa_freekey(ec);
2144         return NULL;
2145     }
2146     ec->privateKey = NULL;
2147
2148     if (!ec->publicKey.x || !ec->publicKey.y ||
2149         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2150         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2151     {
2152         ecdsa_freekey(ec);
2153         return NULL;
2154     }
2155
2156     ec->privateKey = getmp(b, len);
2157     if (ec->privateKey == NULL)
2158     {
2159         ecdsa_freekey(ec);
2160         return NULL;
2161     }
2162
2163     /* Now check that the private key makes the public key */
2164     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2165     if (!publicKey)
2166     {
2167         ecdsa_freekey(ec);
2168         return NULL;
2169     }
2170
2171     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2172         bignum_cmp(ec->publicKey.y, publicKey->y))
2173     {
2174         /* Private key doesn't make the public key on the given curve */
2175         ecdsa_freekey(ec);
2176         ec_point_free(publicKey);
2177         return NULL;
2178     }
2179
2180     ec_point_free(publicKey);
2181
2182     return ec;
2183 }
2184
2185 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2186 {
2187     struct ec_key *ec = (struct ec_key *) key;
2188
2189     int pointlen;
2190     int namelen;
2191     int bloblen;
2192     int i;
2193
2194     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2195         return 0;
2196     }
2197
2198     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2199     namelen = strlen(ec->publicKey.curve->name);
2200     bloblen =
2201         4 + namelen /* <LEN> nistpXXX */
2202         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2203         + ssh2_bignum_length(ec->privateKey);
2204
2205     if (bloblen > len)
2206         return bloblen;
2207
2208     bloblen = 0;
2209
2210     PUT_32BIT(blob+bloblen, namelen);
2211     bloblen += 4;
2212     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2213     bloblen += namelen;
2214
2215     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2216     bloblen += 4;
2217     blob[bloblen++] = 0x04;
2218     for (i = pointlen; i--; )
2219         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2220     for (i = pointlen; i--; )
2221         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2222
2223     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2224     PUT_32BIT(blob+bloblen, pointlen);
2225     bloblen += 4;
2226     for (i = pointlen; i--; )
2227         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2228
2229     return bloblen;
2230 }
2231
2232 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2233                              const void *blob, int len)
2234 {
2235     struct ec_key *ec;
2236     int ret;
2237
2238     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2239     if (!ec)
2240         return -1;
2241     ret = ec->publicKey.curve->fieldBits;
2242     ecdsa_freekey(ec);
2243
2244     return ret;
2245 }
2246
2247 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2248                            const char *data, int datalen)
2249 {
2250     struct ec_key *ec = (struct ec_key *) key;
2251     const struct ecsign_extra *extra =
2252         (const struct ecsign_extra *)ec->signalg->extra;
2253     const char *p;
2254     int slen;
2255     int digestLen;
2256     int ret;
2257
2258     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2259         return 0;
2260
2261     /* Check the signature starts with the algorithm name */
2262     getstring(&sig, &siglen, &p, &slen);
2263     if (!p) {
2264         return 0;
2265     }
2266     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2267         return 0;
2268     }
2269
2270     getstring(&sig, &siglen, &p, &slen);
2271     if (ec->publicKey.curve->type == EC_EDWARDS) {
2272         struct ec_point *r;
2273         Bignum s, h;
2274
2275         /* Check that the signature is two times the length of a point */
2276         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2277             return 0;
2278         }
2279
2280         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2281         if (ec->publicKey.curve->fieldBits != 256) {
2282             return 0;
2283         }
2284
2285         /* Get the signature */
2286         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2287         if (!r) {
2288             return 0;
2289         }
2290         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2291             ec_point_free(r);
2292             return 0;
2293         }
2294         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2295                                  ec->publicKey.curve->fieldBits / 8);
2296
2297         /* Get the hash of the encoded value of R + encoded value of pk + message */
2298         {
2299             int i, pointlen;
2300             unsigned char b;
2301             unsigned char digest[512 / 8];
2302             SHA512_State hs;
2303             SHA512_Init(&hs);
2304
2305             /* Add encoded r (no need to encode it again, it was in the signature) */
2306             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2307
2308             /* Encode pk and add it */
2309             pointlen = ec->publicKey.curve->fieldBits / 8;
2310             for (i = 0; i < pointlen - 1; ++i) {
2311                 b = bignum_byte(ec->publicKey.y, i);
2312                 SHA512_Bytes(&hs, &b, 1);
2313             }
2314             /* Unset last bit of y and set first bit of x in its place */
2315             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2316             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2317             SHA512_Bytes(&hs, &b, 1);
2318
2319             /* Add the message itself */
2320             SHA512_Bytes(&hs, data, datalen);
2321
2322             /* Get the hash */
2323             SHA512_Final(&hs, digest);
2324
2325             /* Convert to Bignum */
2326             h = bignum_from_bytes_le(digest, sizeof(digest));
2327         }
2328
2329         /* Verify sB == r + h*publicKey */
2330         {
2331             struct ec_point *lhs, *rhs, *tmp;
2332
2333             /* lhs = sB */
2334             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2335             freebn(s);
2336             if (!lhs) {
2337                 ec_point_free(r);
2338                 freebn(h);
2339                 return 0;
2340             }
2341
2342             /* rhs = r + h*publicKey */
2343             tmp = ecp_mul(&ec->publicKey, h);
2344             freebn(h);
2345             if (!tmp) {
2346                 ec_point_free(lhs);
2347                 ec_point_free(r);
2348                 return 0;
2349             }
2350             rhs = ecp_add(r, tmp, 0);
2351             ec_point_free(r);
2352             ec_point_free(tmp);
2353             if (!rhs) {
2354                 ec_point_free(lhs);
2355                 return 0;
2356             }
2357
2358             /* Check the point is the same */
2359             ret = !bignum_cmp(lhs->x, rhs->x);
2360             if (ret) {
2361                 ret = !bignum_cmp(lhs->y, rhs->y);
2362                 if (ret) {
2363                     ret = 1;
2364                 }
2365             }
2366             ec_point_free(lhs);
2367             ec_point_free(rhs);
2368         }
2369     } else {
2370         Bignum r, s;
2371         unsigned char digest[512 / 8];
2372         void *hashctx;
2373
2374         r = getmp(&p, &slen);
2375         if (!r) return 0;
2376         s = getmp(&p, &slen);
2377         if (!s) {
2378             freebn(r);
2379             return 0;
2380         }
2381
2382         digestLen = extra->hash->hlen;
2383         assert(digestLen <= sizeof(digest));
2384         hashctx = extra->hash->init();
2385         extra->hash->bytes(hashctx, data, datalen);
2386         extra->hash->final(hashctx, digest);
2387
2388         /* Verify the signature */
2389         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2390
2391         freebn(r);
2392         freebn(s);
2393     }
2394
2395     return ret;
2396 }
2397
2398 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2399                                  int *siglen)
2400 {
2401     struct ec_key *ec = (struct ec_key *) key;
2402     const struct ecsign_extra *extra =
2403         (const struct ecsign_extra *)ec->signalg->extra;
2404     unsigned char digest[512 / 8];
2405     int digestLen;
2406     Bignum r = NULL, s = NULL;
2407     unsigned char *buf, *p;
2408     int rlen, slen, namelen;
2409     int i;
2410
2411     if (!ec->privateKey || !ec->publicKey.curve) {
2412         return NULL;
2413     }
2414
2415     if (ec->publicKey.curve->type == EC_EDWARDS) {
2416         struct ec_point *rp;
2417         int pointlen = ec->publicKey.curve->fieldBits / 8;
2418
2419         /* hash = H(sk) (where hash creates 2 * fieldBits)
2420          * b = fieldBits
2421          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2422          * r = H(h[b/8:b/4] + m)
2423          * R = rB
2424          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2425         {
2426             unsigned char hash[512/8];
2427             unsigned char b;
2428             Bignum a;
2429             SHA512_State hs;
2430             SHA512_Init(&hs);
2431
2432             for (i = 0; i < pointlen; ++i) {
2433                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2434                 SHA512_Bytes(&hs, &b, 1);
2435             }
2436
2437             SHA512_Final(&hs, hash);
2438
2439             /* The second part is simply turning the hash into a
2440              * Bignum, however the 2^(b-2) bit *must* be set, and the
2441              * bottom 3 bits *must* not be */
2442             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2443             hash[31] &= 0x7f; /* Unset above (b-2) */
2444             hash[31] |= 0x40; /* Set 2^(b-2) */
2445             /* Chop off the top part and convert to int */
2446             a = bignum_from_bytes_le(hash, 32);
2447
2448             SHA512_Init(&hs);
2449             SHA512_Bytes(&hs,
2450                          hash+(ec->publicKey.curve->fieldBits / 8),
2451                          (ec->publicKey.curve->fieldBits / 4)
2452                          - (ec->publicKey.curve->fieldBits / 8));
2453             SHA512_Bytes(&hs, data, datalen);
2454             SHA512_Final(&hs, hash);
2455
2456             r = bignum_from_bytes_le(hash, 512/8);
2457             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2458             if (!rp) {
2459                 freebn(r);
2460                 freebn(a);
2461                 return NULL;
2462             }
2463
2464             /* Now calculate s */
2465             SHA512_Init(&hs);
2466             /* Encode the point R */
2467             for (i = 0; i < pointlen - 1; ++i) {
2468                 b = bignum_byte(rp->y, i);
2469                 SHA512_Bytes(&hs, &b, 1);
2470             }
2471             /* Unset last bit of y and set first bit of x in its place */
2472             b = bignum_byte(rp->y, i) & 0x7f;
2473             b |= bignum_bit(rp->x, 0) << 7;
2474             SHA512_Bytes(&hs, &b, 1);
2475
2476             /* Encode the point pk */
2477             for (i = 0; i < pointlen - 1; ++i) {
2478                 b = bignum_byte(ec->publicKey.y, i);
2479                 SHA512_Bytes(&hs, &b, 1);
2480             }
2481             /* Unset last bit of y and set first bit of x in its place */
2482             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2483             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2484             SHA512_Bytes(&hs, &b, 1);
2485
2486             /* Add the message */
2487             SHA512_Bytes(&hs, data, datalen);
2488             SHA512_Final(&hs, hash);
2489
2490             {
2491                 Bignum tmp, tmp2;
2492
2493                 tmp = bignum_from_bytes_le(hash, 512/8);
2494                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2495                 freebn(a);
2496                 freebn(tmp);
2497                 tmp = bigadd(r, tmp2);
2498                 freebn(r);
2499                 freebn(tmp2);
2500                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2501                 freebn(tmp);
2502             }
2503         }
2504
2505         /* Format the output */
2506         namelen = strlen(ec->signalg->name);
2507         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2508         buf = snewn(*siglen, unsigned char);
2509         p = buf;
2510         PUT_32BIT(p, namelen);
2511         p += 4;
2512         memcpy(p, ec->signalg->name, namelen);
2513         p += namelen;
2514         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2515         p += 4;
2516
2517         /* Encode the point */
2518         pointlen = ec->publicKey.curve->fieldBits / 8;
2519         for (i = 0; i < pointlen - 1; ++i) {
2520             *p++ = bignum_byte(rp->y, i);
2521         }
2522         /* Unset last bit of y and set first bit of x in its place */
2523         *p = bignum_byte(rp->y, i) & 0x7f;
2524         *p++ |= bignum_bit(rp->x, 0) << 7;
2525         ec_point_free(rp);
2526
2527         /* Encode the int */
2528         for (i = 0; i < pointlen; ++i) {
2529             *p++ = bignum_byte(s, i);
2530         }
2531         freebn(s);
2532     } else {
2533         void *hashctx;
2534
2535         digestLen = extra->hash->hlen;
2536         assert(digestLen <= sizeof(digest));
2537         hashctx = extra->hash->init();
2538         extra->hash->bytes(hashctx, data, datalen);
2539         extra->hash->final(hashctx, digest);
2540
2541         /* Do the signature */
2542         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2543         if (!r || !s) {
2544             if (r) freebn(r);
2545             if (s) freebn(s);
2546             return NULL;
2547         }
2548
2549         rlen = (bignum_bitcount(r) + 8) / 8;
2550         slen = (bignum_bitcount(s) + 8) / 8;
2551
2552         namelen = strlen(ec->signalg->name);
2553
2554         /* Format the output */
2555         *siglen = 8+namelen+rlen+slen+8;
2556         buf = snewn(*siglen, unsigned char);
2557         p = buf;
2558         PUT_32BIT(p, namelen);
2559         p += 4;
2560         memcpy(p, ec->signalg->name, namelen);
2561         p += namelen;
2562         PUT_32BIT(p, rlen + slen + 8);
2563         p += 4;
2564         PUT_32BIT(p, rlen);
2565         p += 4;
2566         for (i = rlen; i--;)
2567             *p++ = bignum_byte(r, i);
2568         PUT_32BIT(p, slen);
2569         p += 4;
2570         for (i = slen; i--;)
2571             *p++ = bignum_byte(s, i);
2572
2573         freebn(r);
2574         freebn(s);
2575     }
2576
2577     return buf;
2578 }
2579
2580 const struct ecsign_extra sign_extra_ed25519 = {
2581     ec_ed25519, NULL,
2582     NULL, 0,
2583 };
2584 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2585     ecdsa_newkey,
2586     ecdsa_freekey,
2587     ecdsa_fmtkey,
2588     ecdsa_public_blob,
2589     ecdsa_private_blob,
2590     ecdsa_createkey,
2591     ed25519_openssh_createkey,
2592     ed25519_openssh_fmtkey,
2593     2 /* point, private exponent */,
2594     ecdsa_pubkey_bits,
2595     ecdsa_verifysig,
2596     ecdsa_sign,
2597     "ssh-ed25519",
2598     "ssh-ed25519",
2599     &sign_extra_ed25519,
2600 };
2601
2602 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2603 static const unsigned char nistp256_oid[] = {
2604     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2605 };
2606 const struct ecsign_extra sign_extra_nistp256 = {
2607     ec_p256, &ssh_sha256,
2608     nistp256_oid, lenof(nistp256_oid),
2609 };
2610 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2611     ecdsa_newkey,
2612     ecdsa_freekey,
2613     ecdsa_fmtkey,
2614     ecdsa_public_blob,
2615     ecdsa_private_blob,
2616     ecdsa_createkey,
2617     ecdsa_openssh_createkey,
2618     ecdsa_openssh_fmtkey,
2619     3 /* curve name, point, private exponent */,
2620     ecdsa_pubkey_bits,
2621     ecdsa_verifysig,
2622     ecdsa_sign,
2623     "ecdsa-sha2-nistp256",
2624     "ecdsa-sha2-nistp256",
2625     &sign_extra_nistp256,
2626 };
2627
2628 /* OID: 1.3.132.0.34 (secp384r1) */
2629 static const unsigned char nistp384_oid[] = {
2630     0x2b, 0x81, 0x04, 0x00, 0x22
2631 };
2632 const struct ecsign_extra sign_extra_nistp384 = {
2633     ec_p384, &ssh_sha384,
2634     nistp384_oid, lenof(nistp384_oid),
2635 };
2636 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2637     ecdsa_newkey,
2638     ecdsa_freekey,
2639     ecdsa_fmtkey,
2640     ecdsa_public_blob,
2641     ecdsa_private_blob,
2642     ecdsa_createkey,
2643     ecdsa_openssh_createkey,
2644     ecdsa_openssh_fmtkey,
2645     3 /* curve name, point, private exponent */,
2646     ecdsa_pubkey_bits,
2647     ecdsa_verifysig,
2648     ecdsa_sign,
2649     "ecdsa-sha2-nistp384",
2650     "ecdsa-sha2-nistp384",
2651     &sign_extra_nistp384,
2652 };
2653
2654 /* OID: 1.3.132.0.35 (secp521r1) */
2655 static const unsigned char nistp521_oid[] = {
2656     0x2b, 0x81, 0x04, 0x00, 0x23
2657 };
2658 const struct ecsign_extra sign_extra_nistp521 = {
2659     ec_p521, &ssh_sha512,
2660     nistp521_oid, lenof(nistp521_oid),
2661 };
2662 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2663     ecdsa_newkey,
2664     ecdsa_freekey,
2665     ecdsa_fmtkey,
2666     ecdsa_public_blob,
2667     ecdsa_private_blob,
2668     ecdsa_createkey,
2669     ecdsa_openssh_createkey,
2670     ecdsa_openssh_fmtkey,
2671     3 /* curve name, point, private exponent */,
2672     ecdsa_pubkey_bits,
2673     ecdsa_verifysig,
2674     ecdsa_sign,
2675     "ecdsa-sha2-nistp521",
2676     "ecdsa-sha2-nistp521",
2677     &sign_extra_nistp521,
2678 };
2679
2680 /* ----------------------------------------------------------------------
2681  * Exposed ECDH interface
2682  */
2683
2684 struct eckex_extra {
2685     struct ec_curve *(*curve)(void);
2686 };
2687
2688 static Bignum ecdh_calculate(const Bignum private,
2689                              const struct ec_point *public)
2690 {
2691     struct ec_point *p;
2692     Bignum ret;
2693     p = ecp_mul(public, private);
2694     if (!p) return NULL;
2695     ret = p->x;
2696     p->x = NULL;
2697
2698     if (p->curve->type == EC_MONTGOMERY) {
2699         /*
2700          * Endianness-swap. The Curve25519 algorithm definition
2701          * assumes you were doing your computation in arrays of 32
2702          * little-endian bytes, and now specifies that you take your
2703          * final one of those and convert it into a bignum in
2704          * _network_ byte order, i.e. big-endian.
2705          *
2706          * In particular, the spec says, you convert the _whole_ 32
2707          * bytes into a bignum. That is, on the rare occasions that
2708          * p->x has come out with the most significant 8 bits zero, we
2709          * have to imagine that being represented by a 32-byte string
2710          * with the last byte being zero, so that has to be converted
2711          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2712          * multiple of 256.
2713          */
2714         int i;
2715         int bytes = (p->curve->fieldBits+7) / 8;
2716         unsigned char *byteorder = snewn(bytes, unsigned char);
2717         for (i = 0; i < bytes; ++i) {
2718             byteorder[i] = bignum_byte(ret, i);
2719         }
2720         freebn(ret);
2721         ret = bignum_from_bytes(byteorder, bytes);
2722         smemclr(byteorder, bytes);
2723         sfree(byteorder);
2724     }
2725
2726     ec_point_free(p);
2727     return ret;
2728 }
2729
2730 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2731 {
2732     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2733     struct ec_curve *curve = extra->curve();
2734     return curve->textname;
2735 }
2736
2737 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2738 {
2739     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2740     struct ec_curve *curve;
2741     struct ec_key *key;
2742     struct ec_point *publicKey;
2743
2744     curve = extra->curve();
2745
2746     key = snew(struct ec_key);
2747
2748     key->signalg = NULL;
2749     key->publicKey.curve = curve;
2750
2751     if (curve->type == EC_MONTGOMERY) {
2752         unsigned char bytes[32] = {0};
2753         int i;
2754
2755         for (i = 0; i < sizeof(bytes); ++i)
2756         {
2757             bytes[i] = (unsigned char)random_byte();
2758         }
2759         bytes[0] &= 248;
2760         bytes[31] &= 127;
2761         bytes[31] |= 64;
2762         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2763         for (i = 0; i < sizeof(bytes); ++i)
2764         {
2765             ((volatile char*)bytes)[i] = 0;
2766         }
2767         if (!key->privateKey) {
2768             sfree(key);
2769             return NULL;
2770         }
2771         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2772         if (!publicKey) {
2773             freebn(key->privateKey);
2774             sfree(key);
2775             return NULL;
2776         }
2777         key->publicKey.x = publicKey->x;
2778         key->publicKey.y = publicKey->y;
2779         key->publicKey.z = NULL;
2780         sfree(publicKey);
2781     } else {
2782         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2783         if (!key->privateKey) {
2784             sfree(key);
2785             return NULL;
2786         }
2787         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2788         if (!publicKey) {
2789             freebn(key->privateKey);
2790             sfree(key);
2791             return NULL;
2792         }
2793         key->publicKey.x = publicKey->x;
2794         key->publicKey.y = publicKey->y;
2795         key->publicKey.z = NULL;
2796         sfree(publicKey);
2797     }
2798     return key;
2799 }
2800
2801 char *ssh_ecdhkex_getpublic(void *key, int *len)
2802 {
2803     struct ec_key *ec = (struct ec_key*)key;
2804     char *point, *p;
2805     int i;
2806     int pointlen;
2807
2808     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2809
2810     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2811         *len = 1 + pointlen * 2;
2812     } else {
2813         *len = pointlen;
2814     }
2815     point = (char*)snewn(*len, char);
2816
2817     p = point;
2818     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2819         *p++ = 0x04;
2820         for (i = pointlen; i--;) {
2821             *p++ = bignum_byte(ec->publicKey.x, i);
2822         }
2823         for (i = pointlen; i--;) {
2824             *p++ = bignum_byte(ec->publicKey.y, i);
2825         }
2826     } else {
2827         for (i = 0; i < pointlen; ++i) {
2828             *p++ = bignum_byte(ec->publicKey.x, i);
2829         }
2830     }
2831
2832     return point;
2833 }
2834
2835 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2836 {
2837     struct ec_key *ec = (struct ec_key*) key;
2838     struct ec_point remote;
2839     Bignum ret;
2840
2841     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2842         remote.curve = ec->publicKey.curve;
2843         remote.infinity = 0;
2844         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2845             return NULL;
2846         }
2847     } else {
2848         /* Point length has to be the same length */
2849         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2850             return NULL;
2851         }
2852
2853         remote.curve = ec->publicKey.curve;
2854         remote.infinity = 0;
2855         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2856         remote.y = NULL;
2857         remote.z = NULL;
2858     }
2859
2860     ret = ecdh_calculate(ec->privateKey, &remote);
2861     if (remote.x) freebn(remote.x);
2862     if (remote.y) freebn(remote.y);
2863     return ret;
2864 }
2865
2866 void ssh_ecdhkex_freekey(void *key)
2867 {
2868     ecdsa_freekey(key);
2869 }
2870
2871 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2872 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2873     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2874     &ssh_sha256, &kex_extra_curve25519,
2875 };
2876
2877 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2878 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2879     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2880     &ssh_sha256, &kex_extra_nistp256,
2881 };
2882
2883 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2884 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2885     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2886     &ssh_sha384, &kex_extra_nistp384,
2887 };
2888
2889 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2890 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2891     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2892     &ssh_sha512, &kex_extra_nistp521,
2893 };
2894
2895 static const struct ssh_kex *const ec_kex_list[] = {
2896     &ssh_ec_kex_curve25519,
2897     &ssh_ec_kex_nistp256,
2898     &ssh_ec_kex_nistp384,
2899     &ssh_ec_kex_nistp521,
2900 };
2901
2902 const struct ssh_kexes ssh_ecdh_kex = {
2903     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2904     ec_kex_list
2905 };
2906
2907 /* ----------------------------------------------------------------------
2908  * Helper functions for finding key algorithms and returning auxiliary
2909  * data.
2910  */
2911
2912 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2913                                         const struct ec_curve **curve)
2914 {
2915     static const struct ssh_signkey *algs_with_oid[] = {
2916         &ssh_ecdsa_nistp256,
2917         &ssh_ecdsa_nistp384,
2918         &ssh_ecdsa_nistp521,
2919     };
2920     int i;
2921
2922     for (i = 0; i < lenof(algs_with_oid); i++) {
2923         const struct ssh_signkey *alg = algs_with_oid[i];
2924         const struct ecsign_extra *extra =
2925             (const struct ecsign_extra *)alg->extra;
2926         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2927             *curve = extra->curve();
2928             return alg;
2929         }
2930     }
2931     return NULL;
2932 }
2933
2934 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2935                                 int *oidlen)
2936 {
2937     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2938     *oidlen = extra->oidlen;
2939     return extra->oid;
2940 }
2941
2942 const int ec_nist_alg_and_curve_by_bits(int bits,
2943                                         const struct ec_curve **curve,
2944                                         const struct ssh_signkey **alg)
2945 {
2946     switch (bits) {
2947       case 256: *alg = &ssh_ecdsa_nistp256; break;
2948       case 384: *alg = &ssh_ecdsa_nistp384; break;
2949       case 521: *alg = &ssh_ecdsa_nistp521; break;
2950       default: return FALSE;
2951     }
2952     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2953     return TRUE;
2954 }
2955
2956 const int ec_ed_alg_and_curve_by_bits(int bits,
2957                                       const struct ec_curve **curve,
2958                                       const struct ssh_signkey **alg)
2959 {
2960     switch (bits) {
2961       case 256: *alg = &ssh_ecdsa_ed25519; break;
2962       default: return FALSE;
2963     }
2964     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2965     return TRUE;
2966 }