]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Implement "curve448-sha512" kex, from draft-ietf-curdle-ssh-curves-00.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Curve25519 spec from libssh (with reference to other things in the
28  * libssh code):
29  *   https://git.libssh.org/users/aris/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
30  *
31  * Edwards DSA:
32  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
33  */
34
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "ssh.h"
39
40 /* ----------------------------------------------------------------------
41  * Elliptic curve definitions
42  */
43
44 static void initialise_wcurve(struct ec_curve *curve, int bits,
45                               const unsigned char *p,
46                               const unsigned char *a, const unsigned char *b,
47                               const unsigned char *n, const unsigned char *Gx,
48                               const unsigned char *Gy)
49 {
50     int length = bits / 8;
51     if (bits % 8) ++length;
52
53     curve->type = EC_WEIERSTRASS;
54
55     curve->fieldBits = bits;
56     curve->p = bignum_from_bytes(p, length);
57
58     /* Curve co-efficients */
59     curve->w.a = bignum_from_bytes(a, length);
60     curve->w.b = bignum_from_bytes(b, length);
61
62     /* Group order and generator */
63     curve->w.n = bignum_from_bytes(n, length);
64     curve->w.G.x = bignum_from_bytes(Gx, length);
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     curve->w.G.curve = curve;
67     curve->w.G.infinity = 0;
68 }
69
70 static void initialise_mcurve(struct ec_curve *curve, int bits,
71                               const unsigned char *p,
72                               const unsigned char *a, const unsigned char *b,
73                               const unsigned char *Gx)
74 {
75     int length = bits / 8;
76     if (bits % 8) ++length;
77
78     curve->type = EC_MONTGOMERY;
79
80     curve->fieldBits = bits;
81     curve->p = bignum_from_bytes(p, length);
82
83     /* Curve co-efficients */
84     curve->m.a = bignum_from_bytes(a, length);
85     curve->m.b = bignum_from_bytes(b, length);
86
87     /* Generator */
88     curve->m.G.x = bignum_from_bytes(Gx, length);
89     curve->m.G.y = NULL;
90     curve->m.G.z = NULL;
91     curve->m.G.curve = curve;
92     curve->m.G.infinity = 0;
93 }
94
95 static void initialise_ecurve(struct ec_curve *curve, int bits,
96                               const unsigned char *p,
97                               const unsigned char *l, const unsigned char *d,
98                               const unsigned char *Bx, const unsigned char *By)
99 {
100     int length = bits / 8;
101     if (bits % 8) ++length;
102
103     curve->type = EC_EDWARDS;
104
105     curve->fieldBits = bits;
106     curve->p = bignum_from_bytes(p, length);
107
108     /* Curve co-efficients */
109     curve->e.l = bignum_from_bytes(l, length);
110     curve->e.d = bignum_from_bytes(d, length);
111
112     /* Group order and generator */
113     curve->e.B.x = bignum_from_bytes(Bx, length);
114     curve->e.B.y = bignum_from_bytes(By, length);
115     curve->e.B.curve = curve;
116     curve->e.B.infinity = 0;
117 }
118
119 static struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         static const unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         static const unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         static const unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         static const unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         static const unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         static const unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
164         curve.textname = curve.name = "nistp256";
165
166         /* Now initialised, no need to do it again */
167         initialised = 1;
168     }
169
170     return &curve;
171 }
172
173 static struct ec_curve *ec_p384(void)
174 {
175     static struct ec_curve curve = { 0 };
176     static unsigned char initialised = 0;
177
178     if (!initialised)
179     {
180         static const unsigned char p[] = {
181             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
185             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
186             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
187         };
188         static const unsigned char a[] = {
189             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
193             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
194             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
195         };
196         static const unsigned char b[] = {
197             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
198             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
199             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
200             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
201             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
202             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
203         };
204         static const unsigned char n[] = {
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
209             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
210             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
211         };
212         static const unsigned char Gx[] = {
213             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
214             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
215             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
216             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
217             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
218             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
219         };
220         static const unsigned char Gy[] = {
221             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
222             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
223             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
224             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
225             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
226             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
227         };
228
229         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
230         curve.textname = curve.name = "nistp384";
231
232         /* Now initialised, no need to do it again */
233         initialised = 1;
234     }
235
236     return &curve;
237 }
238
239 static struct ec_curve *ec_p521(void)
240 {
241     static struct ec_curve curve = { 0 };
242     static unsigned char initialised = 0;
243
244     if (!initialised)
245     {
246         static const unsigned char p[] = {
247             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff
256         };
257         static const unsigned char a[] = {
258             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xfc
267         };
268         static const unsigned char b[] = {
269             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
270             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
271             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
272             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
273             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
274             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
275             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
276             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
277             0x3f, 0x00
278         };
279         static const unsigned char n[] = {
280             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
281             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
282             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
285             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
286             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
287             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
288             0x64, 0x09
289         };
290         static const unsigned char Gx[] = {
291             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
292             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
293             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
294             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
295             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
296             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
297             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
298             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
299             0xbd, 0x66
300         };
301         static const unsigned char Gy[] = {
302             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
303             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
304             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
305             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
306             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
307             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
308             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
309             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
310             0x66, 0x50
311         };
312
313         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
314         curve.textname = curve.name = "nistp521";
315
316         /* Now initialised, no need to do it again */
317         initialised = 1;
318     }
319
320     return &curve;
321 }
322
323 static struct ec_curve *ec_curve25519(void)
324 {
325     static struct ec_curve curve = { 0 };
326     static unsigned char initialised = 0;
327
328     if (!initialised)
329     {
330         static const unsigned char p[] = {
331             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
332             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
333             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
334             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
335         };
336         static const unsigned char a[] = {
337             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
338             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
341         };
342         static const unsigned char b[] = {
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
344             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
347         };
348         static const unsigned char gx[32] = {
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
350             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
351             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
353         };
354
355         initialise_mcurve(&curve, 256, p, a, b, gx);
356         /* This curve doesn't need a name, because it's never used in
357          * any format that embeds the curve name */
358         curve.name = NULL;
359         curve.textname = "Curve25519";
360
361         /* Now initialised, no need to do it again */
362         initialised = 1;
363     }
364
365     return &curve;
366 }
367
368 static struct ec_curve *ec_curve448(void)
369 {
370     static struct ec_curve curve = { 0 };
371     static unsigned char initialised = 0;
372
373     if (!initialised)
374     {
375         static const unsigned char p[56] = {
376             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
379             0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff,
380             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
381             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
382             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
383         };
384         static const unsigned char a[56] = {
385             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
386             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
387             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
388             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
389             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
390             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
391             0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x62, 0xa6,
392         };
393         static const unsigned char b[56] = {
394             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
395             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
396             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
397             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
398             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
399             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
400             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
401         };
402         static const unsigned char gx[56] = {
403             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
404             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
405             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
406             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
407             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
408             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
409             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
410         };
411
412         initialise_mcurve(&curve, 448, p, a, b, gx);
413         /* This curve doesn't need a name, because it's never used in
414          * any format that embeds the curve name */
415         curve.name = NULL;
416         curve.textname = "Curve448";
417
418         /* Now initialised, no need to do it again */
419         initialised = 1;
420     }
421
422     return &curve;
423 }
424
425 static struct ec_curve *ec_ed25519(void)
426 {
427     static struct ec_curve curve = { 0 };
428     static unsigned char initialised = 0;
429
430     if (!initialised)
431     {
432         static const unsigned char q[] = {
433             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
434             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
435             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
436             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
437         };
438         static const unsigned char l[32] = {
439             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
440             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
441             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
442             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
443         };
444         static const unsigned char d[32] = {
445             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
446             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
447             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
448             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
449         };
450         static const unsigned char Bx[32] = {
451             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
452             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
453             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
454             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
455         };
456         static const unsigned char By[32] = {
457             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
458             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
459             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
460             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
461         };
462
463         /* This curve doesn't need a name, because it's never used in
464          * any format that embeds the curve name */
465         curve.name = NULL;
466
467         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
468         curve.textname = "Ed25519";
469
470         /* Now initialised, no need to do it again */
471         initialised = 1;
472     }
473
474     return &curve;
475 }
476
477 /* Return 1 if a is -3 % p, otherwise return 0
478  * This is used because there are some maths optimisations */
479 static int ec_aminus3(const struct ec_curve *curve)
480 {
481     int ret;
482     Bignum _p;
483
484     if (curve->type != EC_WEIERSTRASS) {
485         return 0;
486     }
487
488     _p = bignum_add_long(curve->w.a, 3);
489
490     ret = !bignum_cmp(curve->p, _p);
491     freebn(_p);
492     return ret;
493 }
494
495 /* ----------------------------------------------------------------------
496  * Elliptic curve field maths
497  */
498
499 static Bignum ecf_add(const Bignum a, const Bignum b,
500                       const struct ec_curve *curve)
501 {
502     Bignum a1, b1, ab, ret;
503
504     a1 = bigmod(a, curve->p);
505     b1 = bigmod(b, curve->p);
506
507     ab = bigadd(a1, b1);
508     freebn(a1);
509     freebn(b1);
510
511     ret = bigmod(ab, curve->p);
512     freebn(ab);
513
514     return ret;
515 }
516
517 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
518 {
519     return modmul(a, a, curve->p);
520 }
521
522 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
523 {
524     Bignum ret, tmp;
525
526     /* Double */
527     tmp = bignum_lshift(a, 1);
528
529     /* Add itself (i.e. treble) */
530     ret = bigadd(tmp, a);
531     freebn(tmp);
532
533     /* Normalise */
534     while (bignum_cmp(ret, curve->p) >= 0)
535     {
536         tmp = bigsub(ret, curve->p);
537         assert(tmp);
538         freebn(ret);
539         ret = tmp;
540     }
541
542     return ret;
543 }
544
545 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
546 {
547     Bignum ret = bignum_lshift(a, 1);
548     if (bignum_cmp(ret, curve->p) >= 0)
549     {
550         Bignum tmp = bigsub(ret, curve->p);
551         assert(tmp);
552         freebn(ret);
553         return tmp;
554     }
555     else
556     {
557         return ret;
558     }
559 }
560
561 /* ----------------------------------------------------------------------
562  * Memory functions
563  */
564
565 void ec_point_free(struct ec_point *point)
566 {
567     if (point == NULL) return;
568     point->curve = 0;
569     if (point->x) freebn(point->x);
570     if (point->y) freebn(point->y);
571     if (point->z) freebn(point->z);
572     point->infinity = 0;
573     sfree(point);
574 }
575
576 static struct ec_point *ec_point_new(const struct ec_curve *curve,
577                                      const Bignum x, const Bignum y, const Bignum z,
578                                      unsigned char infinity)
579 {
580     struct ec_point *point = snewn(1, struct ec_point);
581     point->curve = curve;
582     point->x = x;
583     point->y = y;
584     point->z = z;
585     point->infinity = infinity ? 1 : 0;
586     return point;
587 }
588
589 static struct ec_point *ec_point_copy(const struct ec_point *a)
590 {
591     if (a == NULL) return NULL;
592     return ec_point_new(a->curve,
593                         a->x ? copybn(a->x) : NULL,
594                         a->y ? copybn(a->y) : NULL,
595                         a->z ? copybn(a->z) : NULL,
596                         a->infinity);
597 }
598
599 static int ec_point_verify(const struct ec_point *a)
600 {
601     if (a->infinity) {
602         return 1;
603     } else if (a->curve->type == EC_EDWARDS) {
604         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
605         Bignum y2, x2, tmp, tmp2, tmp3;
606         int ret;
607
608         y2 = ecf_square(a->y, a->curve);
609         x2 = ecf_square(a->x, a->curve);
610         tmp = modmul(a->curve->e.d, x2, a->curve->p);
611         tmp2 = modmul(tmp, y2, a->curve->p);
612         freebn(tmp);
613         tmp = modsub(y2, x2, a->curve->p);
614         freebn(y2);
615         freebn(x2);
616         tmp3 = modsub(tmp, tmp2, a->curve->p);
617         freebn(tmp);
618         freebn(tmp2);
619         ret = !bignum_cmp(tmp3, One);
620         freebn(tmp3);
621         return ret;
622     } else if (a->curve->type == EC_WEIERSTRASS) {
623         /* Verify y^2 = x^3 + ax + b */
624         int ret = 0;
625
626         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
627
628         Bignum Three = bignum_from_long(3);
629
630         lhs = modmul(a->y, a->y, a->curve->p);
631
632         /* This uses montgomery multiplication to optimise */
633         x3 = modpow(a->x, Three, a->curve->p);
634         freebn(Three);
635         ax = modmul(a->curve->w.a, a->x, a->curve->p);
636         x3ax = bigadd(x3, ax);
637         freebn(x3); x3 = NULL;
638         freebn(ax); ax = NULL;
639         x3axm = bigmod(x3ax, a->curve->p);
640         freebn(x3ax); x3ax = NULL;
641         x3axb = bigadd(x3axm, a->curve->w.b);
642         freebn(x3axm); x3axm = NULL;
643         rhs = bigmod(x3axb, a->curve->p);
644         freebn(x3axb);
645
646         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
647         freebn(lhs);
648         freebn(rhs);
649
650         return ret;
651     } else {
652         return 0;
653     }
654 }
655
656 /* ----------------------------------------------------------------------
657  * Elliptic curve point maths
658  */
659
660 /* Returns 1 on success and 0 on memory error */
661 static int ecp_normalise(struct ec_point *a)
662 {
663     if (!a) {
664         /* No point */
665         return 0;
666     }
667
668     if (a->infinity) {
669         /* Point is at infinity - i.e. normalised */
670         return 1;
671     }
672
673     if (a->curve->type == EC_WEIERSTRASS) {
674         /* In Jacobian Coordinates the triple (X, Y, Z) represents
675            the affine point (X / Z^2, Y / Z^3) */
676
677         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
678
679         if (!a->x || !a->y) {
680             /* No point defined */
681             return 0;
682         } else if (!a->z) {
683             /* Already normalised */
684             return 1;
685         }
686
687         Z2 = ecf_square(a->z, a->curve);
688         Z2inv = modinv(Z2, a->curve->p);
689         if (!Z2inv) {
690             freebn(Z2);
691             return 0;
692         }
693         tx = modmul(a->x, Z2inv, a->curve->p);
694         freebn(Z2inv);
695
696         Z3 = modmul(Z2, a->z, a->curve->p);
697         freebn(Z2);
698         Z3inv = modinv(Z3, a->curve->p);
699         freebn(Z3);
700         if (!Z3inv) {
701             freebn(tx);
702             return 0;
703         }
704         ty = modmul(a->y, Z3inv, a->curve->p);
705         freebn(Z3inv);
706
707         freebn(a->x);
708         a->x = tx;
709         freebn(a->y);
710         a->y = ty;
711         freebn(a->z);
712         a->z = NULL;
713         return 1;
714     } else if (a->curve->type == EC_MONTGOMERY) {
715         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
716
717         Bignum tmp, tmp2;
718
719         if (!a->x) {
720             /* No point defined */
721             return 0;
722         } else if (!a->z) {
723             /* Already normalised */
724             return 1;
725         }
726
727         tmp = modinv(a->z, a->curve->p);
728         if (!tmp) {
729             return 0;
730         }
731         tmp2 = modmul(a->x, tmp, a->curve->p);
732         freebn(tmp);
733
734         freebn(a->z);
735         a->z = NULL;
736         freebn(a->x);
737         a->x = tmp2;
738         return 1;
739     } else if (a->curve->type == EC_EDWARDS) {
740         /* Always normalised */
741         return 1;
742     } else {
743         return 0;
744     }
745 }
746
747 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
748 {
749     Bignum S, M, outx, outy, outz;
750
751     if (bignum_cmp(a->y, Zero) == 0)
752     {
753         /* Identity */
754         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
755     }
756
757     /* S = 4*X*Y^2 */
758     {
759         Bignum Y2, XY2, _2XY2;
760
761         Y2 = ecf_square(a->y, a->curve);
762         XY2 = modmul(a->x, Y2, a->curve->p);
763         freebn(Y2);
764
765         _2XY2 = ecf_double(XY2, a->curve);
766         freebn(XY2);
767         S = ecf_double(_2XY2, a->curve);
768         freebn(_2XY2);
769     }
770
771     /* Faster calculation if a = -3 */
772     if (aminus3) {
773         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
774         Bignum Z2, XpZ2, XmZ2, second;
775
776         if (a->z == NULL) {
777             Z2 = copybn(One);
778         } else {
779             Z2 = ecf_square(a->z, a->curve);
780         }
781
782         XpZ2 = ecf_add(a->x, Z2, a->curve);
783         XmZ2 = modsub(a->x, Z2, a->curve->p);
784         freebn(Z2);
785
786         second = modmul(XpZ2, XmZ2, a->curve->p);
787         freebn(XpZ2);
788         freebn(XmZ2);
789
790         M = ecf_treble(second, a->curve);
791         freebn(second);
792     } else {
793         /* M = 3*X^2 + a*Z^4 */
794         Bignum _3X2, X2, aZ4;
795
796         if (a->z == NULL) {
797             aZ4 = copybn(a->curve->w.a);
798         } else {
799             Bignum Z2, Z4;
800
801             Z2 = ecf_square(a->z, a->curve);
802             Z4 = ecf_square(Z2, a->curve);
803             freebn(Z2);
804             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
805             freebn(Z4);
806         }
807
808         X2 = modmul(a->x, a->x, a->curve->p);
809         _3X2 = ecf_treble(X2, a->curve);
810         freebn(X2);
811         M = ecf_add(_3X2, aZ4, a->curve);
812         freebn(_3X2);
813         freebn(aZ4);
814     }
815
816     /* X' = M^2 - 2*S */
817     {
818         Bignum M2, _2S;
819
820         M2 = ecf_square(M, a->curve);
821         _2S = ecf_double(S, a->curve);
822         outx = modsub(M2, _2S, a->curve->p);
823         freebn(M2);
824         freebn(_2S);
825     }
826
827     /* Y' = M*(S - X') - 8*Y^4 */
828     {
829         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
830
831         SX = modsub(S, outx, a->curve->p);
832         freebn(S);
833         MSX = modmul(M, SX, a->curve->p);
834         freebn(SX);
835         freebn(M);
836         Y2 = ecf_square(a->y, a->curve);
837         Y4 = ecf_square(Y2, a->curve);
838         freebn(Y2);
839         Eight = bignum_from_long(8);
840         _8Y4 = modmul(Eight, Y4, a->curve->p);
841         freebn(Eight);
842         freebn(Y4);
843         outy = modsub(MSX, _8Y4, a->curve->p);
844         freebn(MSX);
845         freebn(_8Y4);
846     }
847
848     /* Z' = 2*Y*Z */
849     {
850         Bignum YZ;
851
852         if (a->z == NULL) {
853             YZ = copybn(a->y);
854         } else {
855             YZ = modmul(a->y, a->z, a->curve->p);
856         }
857
858         outz = ecf_double(YZ, a->curve);
859         freebn(YZ);
860     }
861
862     return ec_point_new(a->curve, outx, outy, outz, 0);
863 }
864
865 static struct ec_point *ecp_doublem(const struct ec_point *a)
866 {
867     Bignum z, outx, outz, xpz, xmz;
868
869     z = a->z;
870     if (!z) {
871         z = One;
872     }
873
874     /* 4xz = (x + z)^2 - (x - z)^2 */
875     {
876         Bignum tmp;
877
878         tmp = ecf_add(a->x, z, a->curve);
879         xpz = ecf_square(tmp, a->curve);
880         freebn(tmp);
881
882         tmp = modsub(a->x, z, a->curve->p);
883         xmz = ecf_square(tmp, a->curve);
884         freebn(tmp);
885     }
886
887     /* outx = (x + z)^2 * (x - z)^2 */
888     outx = modmul(xpz, xmz, a->curve->p);
889
890     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
891     {
892         Bignum _4xz, tmp, tmp2, tmp3;
893
894         tmp = bignum_from_long(2);
895         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
896         freebn(tmp);
897
898         _4xz = modsub(xpz, xmz, a->curve->p);
899         freebn(xpz);
900         tmp = modmul(tmp2, _4xz, a->curve->p);
901         freebn(tmp2);
902
903         tmp2 = bignum_from_long(4);
904         tmp3 = modinv(tmp2, a->curve->p);
905         freebn(tmp2);
906         if (!tmp3) {
907             freebn(tmp);
908             freebn(_4xz);
909             freebn(outx);
910             freebn(xmz);
911             return NULL;
912         }
913         tmp2 = modmul(tmp, tmp3, a->curve->p);
914         freebn(tmp);
915         freebn(tmp3);
916
917         tmp = ecf_add(xmz, tmp2, a->curve);
918         freebn(xmz);
919         freebn(tmp2);
920         outz = modmul(_4xz, tmp, a->curve->p);
921         freebn(_4xz);
922         freebn(tmp);
923     }
924
925     return ec_point_new(a->curve, outx, NULL, outz, 0);
926 }
927
928 /* Forward declaration for Edwards curve doubling */
929 static struct ec_point *ecp_add(const struct ec_point *a,
930                                 const struct ec_point *b,
931                                 const int aminus3);
932
933 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
934 {
935     if (a->infinity)
936     {
937         /* Identity */
938         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
939     }
940
941     if (a->curve->type == EC_EDWARDS)
942     {
943         return ecp_add(a, a, aminus3);
944     }
945     else if (a->curve->type == EC_WEIERSTRASS)
946     {
947         return ecp_doublew(a, aminus3);
948     }
949     else
950     {
951         return ecp_doublem(a);
952     }
953 }
954
955 static struct ec_point *ecp_addw(const struct ec_point *a,
956                                  const struct ec_point *b,
957                                  const int aminus3)
958 {
959     Bignum U1, U2, S1, S2, outx, outy, outz;
960
961     /* U1 = X1*Z2^2 */
962     /* S1 = Y1*Z2^3 */
963     if (b->z) {
964         Bignum Z2, Z3;
965
966         Z2 = ecf_square(b->z, a->curve);
967         U1 = modmul(a->x, Z2, a->curve->p);
968         Z3 = modmul(Z2, b->z, a->curve->p);
969         freebn(Z2);
970         S1 = modmul(a->y, Z3, a->curve->p);
971         freebn(Z3);
972     } else {
973         U1 = copybn(a->x);
974         S1 = copybn(a->y);
975     }
976
977     /* U2 = X2*Z1^2 */
978     /* S2 = Y2*Z1^3 */
979     if (a->z) {
980         Bignum Z2, Z3;
981
982         Z2 = ecf_square(a->z, b->curve);
983         U2 = modmul(b->x, Z2, b->curve->p);
984         Z3 = modmul(Z2, a->z, b->curve->p);
985         freebn(Z2);
986         S2 = modmul(b->y, Z3, b->curve->p);
987         freebn(Z3);
988     } else {
989         U2 = copybn(b->x);
990         S2 = copybn(b->y);
991     }
992
993     /* Check if multiplying by self */
994     if (bignum_cmp(U1, U2) == 0)
995     {
996         freebn(U1);
997         freebn(U2);
998         if (bignum_cmp(S1, S2) == 0)
999         {
1000             freebn(S1);
1001             freebn(S2);
1002             return ecp_double(a, aminus3);
1003         }
1004         else
1005         {
1006             freebn(S1);
1007             freebn(S2);
1008             /* Infinity */
1009             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1010         }
1011     }
1012
1013     {
1014         Bignum H, R, UH2, H3;
1015
1016         /* H = U2 - U1 */
1017         H = modsub(U2, U1, a->curve->p);
1018         freebn(U2);
1019
1020         /* R = S2 - S1 */
1021         R = modsub(S2, S1, a->curve->p);
1022         freebn(S2);
1023
1024         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1025         {
1026             Bignum R2, H2, _2UH2, first;
1027
1028             H2 = ecf_square(H, a->curve);
1029             UH2 = modmul(U1, H2, a->curve->p);
1030             freebn(U1);
1031             H3 = modmul(H2, H, a->curve->p);
1032             freebn(H2);
1033             R2 = ecf_square(R, a->curve);
1034             _2UH2 = ecf_double(UH2, a->curve);
1035             first = modsub(R2, H3, a->curve->p);
1036             freebn(R2);
1037             outx = modsub(first, _2UH2, a->curve->p);
1038             freebn(first);
1039             freebn(_2UH2);
1040         }
1041
1042         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1043         {
1044             Bignum RUH2mX, UH2mX, SH3;
1045
1046             UH2mX = modsub(UH2, outx, a->curve->p);
1047             freebn(UH2);
1048             RUH2mX = modmul(R, UH2mX, a->curve->p);
1049             freebn(UH2mX);
1050             freebn(R);
1051             SH3 = modmul(S1, H3, a->curve->p);
1052             freebn(S1);
1053             freebn(H3);
1054
1055             outy = modsub(RUH2mX, SH3, a->curve->p);
1056             freebn(RUH2mX);
1057             freebn(SH3);
1058         }
1059
1060         /* Z3 = H*Z1*Z2 */
1061         if (a->z && b->z) {
1062             Bignum ZZ;
1063
1064             ZZ = modmul(a->z, b->z, a->curve->p);
1065             outz = modmul(H, ZZ, a->curve->p);
1066             freebn(H);
1067             freebn(ZZ);
1068         } else if (a->z) {
1069             outz = modmul(H, a->z, a->curve->p);
1070             freebn(H);
1071         } else if (b->z) {
1072             outz = modmul(H, b->z, a->curve->p);
1073             freebn(H);
1074         } else {
1075             outz = H;
1076         }
1077     }
1078
1079     return ec_point_new(a->curve, outx, outy, outz, 0);
1080 }
1081
1082 static struct ec_point *ecp_addm(const struct ec_point *a,
1083                                  const struct ec_point *b,
1084                                  const struct ec_point *base)
1085 {
1086     Bignum outx, outz, az, bz;
1087
1088     az = a->z;
1089     if (!az) {
1090         az = One;
1091     }
1092     bz = b->z;
1093     if (!bz) {
1094         bz = One;
1095     }
1096
1097     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1098     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1099     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1100     {
1101         Bignum tmp, tmp2, tmp3, tmp4;
1102
1103         /* (Xa + Za) * (Xb - Zb) */
1104         tmp = ecf_add(a->x, az, a->curve);
1105         tmp2 = modsub(b->x, bz, a->curve->p);
1106         tmp3 = modmul(tmp, tmp2, a->curve->p);
1107         freebn(tmp);
1108         freebn(tmp2);
1109
1110         /* (Xa - Za) * (Xb + Zb) */
1111         tmp = modsub(a->x, az, a->curve->p);
1112         tmp2 = ecf_add(b->x, bz, a->curve);
1113         tmp4 = modmul(tmp, tmp2, a->curve->p);
1114         freebn(tmp);
1115         freebn(tmp2);
1116
1117         tmp = ecf_add(tmp3, tmp4, a->curve);
1118         outx = ecf_square(tmp, a->curve);
1119         freebn(tmp);
1120
1121         tmp = modsub(tmp3, tmp4, a->curve->p);
1122         freebn(tmp3);
1123         freebn(tmp4);
1124         tmp2 = ecf_square(tmp, a->curve);
1125         freebn(tmp);
1126         outz = modmul(base->x, tmp2, a->curve->p);
1127         freebn(tmp2);
1128     }
1129
1130     return ec_point_new(a->curve, outx, NULL, outz, 0);
1131 }
1132
1133 static struct ec_point *ecp_adde(const struct ec_point *a,
1134                                  const struct ec_point *b)
1135 {
1136     Bignum outx, outy, dmul;
1137
1138     /* outx = (a->x * b->y + b->x * a->y) /
1139      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1140     {
1141         Bignum tmp, tmp2, tmp3, tmp4;
1142
1143         tmp = modmul(a->x, b->y, a->curve->p);
1144         tmp2 = modmul(b->x, a->y, a->curve->p);
1145         tmp3 = ecf_add(tmp, tmp2, a->curve);
1146
1147         tmp4 = modmul(tmp, tmp2, a->curve->p);
1148         freebn(tmp);
1149         freebn(tmp2);
1150         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1151         freebn(tmp4);
1152
1153         tmp = ecf_add(One, dmul, a->curve);
1154         tmp2 = modinv(tmp, a->curve->p);
1155         freebn(tmp);
1156         if (!tmp2)
1157         {
1158             freebn(tmp3);
1159             freebn(dmul);
1160             return NULL;
1161         }
1162
1163         outx = modmul(tmp3, tmp2, a->curve->p);
1164         freebn(tmp3);
1165         freebn(tmp2);
1166     }
1167
1168     /* outy = (a->y * b->y + a->x * b->x) /
1169      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1170     {
1171         Bignum tmp, tmp2, tmp3, tmp4;
1172
1173         tmp = modsub(One, dmul, a->curve->p);
1174         freebn(dmul);
1175
1176         tmp2 = modinv(tmp, a->curve->p);
1177         freebn(tmp);
1178         if (!tmp2)
1179         {
1180             freebn(outx);
1181             return NULL;
1182         }
1183
1184         tmp = modmul(a->y, b->y, a->curve->p);
1185         tmp3 = modmul(a->x, b->x, a->curve->p);
1186         tmp4 = ecf_add(tmp, tmp3, a->curve);
1187         freebn(tmp);
1188         freebn(tmp3);
1189
1190         outy = modmul(tmp4, tmp2, a->curve->p);
1191         freebn(tmp4);
1192         freebn(tmp2);
1193     }
1194
1195     return ec_point_new(a->curve, outx, outy, NULL, 0);
1196 }
1197
1198 static struct ec_point *ecp_add(const struct ec_point *a,
1199                                 const struct ec_point *b,
1200                                 const int aminus3)
1201 {
1202     if (a->curve != b->curve) {
1203         return NULL;
1204     }
1205
1206     /* Check if multiplying by infinity */
1207     if (a->infinity) return ec_point_copy(b);
1208     if (b->infinity) return ec_point_copy(a);
1209
1210     if (a->curve->type == EC_EDWARDS)
1211     {
1212         return ecp_adde(a, b);
1213     }
1214
1215     if (a->curve->type == EC_WEIERSTRASS)
1216     {
1217         return ecp_addw(a, b, aminus3);
1218     }
1219
1220     return NULL;
1221 }
1222
1223 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1224 {
1225     struct ec_point *A, *ret;
1226     int bits, i;
1227
1228     A = ec_point_copy(a);
1229     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1230
1231     bits = bignum_bitcount(b);
1232     for (i = 0; i < bits; ++i)
1233     {
1234         if (bignum_bit(b, i))
1235         {
1236             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1237             ec_point_free(ret);
1238             ret = tmp;
1239         }
1240         if (i+1 != bits)
1241         {
1242             struct ec_point *tmp = ecp_double(A, aminus3);
1243             ec_point_free(A);
1244             A = tmp;
1245         }
1246     }
1247
1248     ec_point_free(A);
1249     return ret;
1250 }
1251
1252 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1253 {
1254     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1255
1256     if (!ecp_normalise(ret)) {
1257         ec_point_free(ret);
1258         return NULL;
1259     }
1260
1261     return ret;
1262 }
1263
1264 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1265 {
1266     int i;
1267     struct ec_point *ret;
1268
1269     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1270
1271     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1272     {
1273         {
1274             struct ec_point *tmp = ecp_double(ret, 0);
1275             ec_point_free(ret);
1276             ret = tmp;
1277         }
1278         if (ret && bignum_bit(b, i))
1279         {
1280             struct ec_point *tmp = ecp_add(ret, a, 0);
1281             ec_point_free(ret);
1282             ret = tmp;
1283         }
1284     }
1285
1286     return ret;
1287 }
1288
1289 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1290 {
1291     struct ec_point *P1, *P2;
1292     int bits, i;
1293
1294     /* P1 <- P and P2 <- [2]P */
1295     P2 = ecp_double(p, 0);
1296     P1 = ec_point_copy(p);
1297
1298     /* for i = bits âˆ’ 2 down to 0 */
1299     bits = bignum_bitcount(n);
1300     for (i = bits - 2; i >= 0; --i)
1301     {
1302         if (!bignum_bit(n, i))
1303         {
1304             /* P2 <- P1 + P2 */
1305             struct ec_point *tmp = ecp_addm(P1, P2, p);
1306             ec_point_free(P2);
1307             P2 = tmp;
1308
1309             /* P1 <- [2]P1 */
1310             tmp = ecp_double(P1, 0);
1311             ec_point_free(P1);
1312             P1 = tmp;
1313         }
1314         else
1315         {
1316             /* P1 <- P1 + P2 */
1317             struct ec_point *tmp = ecp_addm(P1, P2, p);
1318             ec_point_free(P1);
1319             P1 = tmp;
1320
1321             /* P2 <- [2]P2 */
1322             tmp = ecp_double(P2, 0);
1323             ec_point_free(P2);
1324             P2 = tmp;
1325         }
1326     }
1327
1328     ec_point_free(P2);
1329
1330     if (!ecp_normalise(P1)) {
1331         ec_point_free(P1);
1332         return NULL;
1333     }
1334
1335     return P1;
1336 }
1337
1338 /* Not static because it is used by sshecdsag.c to generate a new key */
1339 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1340 {
1341     if (a->curve->type == EC_WEIERSTRASS) {
1342         return ecp_mulw(a, b);
1343     } else if (a->curve->type == EC_EDWARDS) {
1344         return ecp_mule(a, b);
1345     } else {
1346         return ecp_mulm(a, b);
1347     }
1348 }
1349
1350 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1351                                    const struct ec_point *point)
1352 {
1353     struct ec_point *aG, *bP, *ret;
1354     int aminus3;
1355
1356     if (point->curve->type != EC_WEIERSTRASS) {
1357         return NULL;
1358     }
1359
1360     aminus3 = ec_aminus3(point->curve);
1361
1362     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1363     if (!aG) return NULL;
1364     bP = ecp_mul_(point, b, aminus3);
1365     if (!bP) {
1366         ec_point_free(aG);
1367         return NULL;
1368     }
1369
1370     ret = ecp_add(aG, bP, aminus3);
1371
1372     ec_point_free(aG);
1373     ec_point_free(bP);
1374
1375     if (!ecp_normalise(ret)) {
1376         ec_point_free(ret);
1377         return NULL;
1378     }
1379
1380     return ret;
1381 }
1382 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1383 {
1384     /* Get the x value on the given Edwards curve for a given y */
1385     Bignum x, xx;
1386
1387     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1388     {
1389         Bignum tmp, tmp2, tmp3;
1390
1391         tmp = ecf_square(y, curve);
1392         tmp2 = modmul(curve->e.d, tmp, curve->p);
1393         tmp3 = ecf_add(tmp2, One, curve);
1394         freebn(tmp2);
1395         tmp2 = modinv(tmp3, curve->p);
1396         freebn(tmp3);
1397         if (!tmp2) {
1398             freebn(tmp);
1399             return NULL;
1400         }
1401
1402         tmp3 = modsub(tmp, One, curve->p);
1403         freebn(tmp);
1404         xx = modmul(tmp3, tmp2, curve->p);
1405         freebn(tmp3);
1406         freebn(tmp2);
1407     }
1408
1409     /* x = xx^((p + 3) / 8) */
1410     {
1411         Bignum tmp, tmp2;
1412
1413         tmp = bignum_add_long(curve->p, 3);
1414         tmp2 = bignum_rshift(tmp, 3);
1415         freebn(tmp);
1416         x = modpow(xx, tmp2, curve->p);
1417         freebn(tmp2);
1418     }
1419
1420     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1421     {
1422         Bignum tmp, tmp2;
1423
1424         tmp = ecf_square(x, curve);
1425         tmp2 = modsub(tmp, xx, curve->p);
1426         freebn(tmp);
1427         freebn(xx);
1428         if (bignum_cmp(tmp2, Zero)) {
1429             Bignum tmp3;
1430
1431             freebn(tmp2);
1432
1433             tmp = modsub(curve->p, One, curve->p);
1434             tmp2 = bignum_rshift(tmp, 2);
1435             freebn(tmp);
1436             tmp = bignum_from_long(2);
1437             tmp3 = modpow(tmp, tmp2, curve->p);
1438             freebn(tmp);
1439             freebn(tmp2);
1440
1441             tmp = modmul(x, tmp3, curve->p);
1442             freebn(x);
1443             freebn(tmp3);
1444             x = tmp;
1445         } else {
1446             freebn(tmp2);
1447         }
1448     }
1449
1450     /* if x % 2 != 0 then x = p - x */
1451     if (bignum_bit(x, 0)) {
1452         Bignum tmp = modsub(curve->p, x, curve->p);
1453         freebn(x);
1454         x = tmp;
1455     }
1456
1457     return x;
1458 }
1459
1460 /* ----------------------------------------------------------------------
1461  * Public point from private
1462  */
1463
1464 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1465 {
1466     if (curve->type == EC_WEIERSTRASS) {
1467         return ecp_mul(&curve->w.G, privateKey);
1468     } else if (curve->type == EC_EDWARDS) {
1469         /* hash = H(sk) (where hash creates 2 * fieldBits)
1470          * b = fieldBits
1471          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1472          * publicKey = aB */
1473         struct ec_point *ret;
1474         unsigned char hash[512/8];
1475         Bignum a;
1476         int i, keylen;
1477         SHA512_State s;
1478         SHA512_Init(&s);
1479
1480         keylen = curve->fieldBits / 8;
1481         for (i = 0; i < keylen; ++i) {
1482             unsigned char b = bignum_byte(privateKey, i);
1483             SHA512_Bytes(&s, &b, 1);
1484         }
1485         SHA512_Final(&s, hash);
1486
1487         /* The second part is simply turning the hash into a Bignum,
1488          * however the 2^(b-2) bit *must* be set, and the bottom 3
1489          * bits *must* not be */
1490         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1491         hash[31] &= 0x7f; /* Unset above (b-2) */
1492         hash[31] |= 0x40; /* Set 2^(b-2) */
1493         /* Chop off the top part and convert to int */
1494         a = bignum_from_bytes_le(hash, 32);
1495
1496         ret = ecp_mul(&curve->e.B, a);
1497         freebn(a);
1498         return ret;
1499     } else {
1500         return NULL;
1501     }
1502 }
1503
1504 /* ----------------------------------------------------------------------
1505  * Basic sign and verify routines
1506  */
1507
1508 static int _ecdsa_verify(const struct ec_point *publicKey,
1509                          const unsigned char *data, const int dataLen,
1510                          const Bignum r, const Bignum s)
1511 {
1512     int z_bits, n_bits;
1513     Bignum z;
1514     int valid = 0;
1515
1516     if (publicKey->curve->type != EC_WEIERSTRASS) {
1517         return 0;
1518     }
1519
1520     /* Sanity checks */
1521     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1522         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1523     {
1524         return 0;
1525     }
1526
1527     /* z = left most bitlen(curve->n) of data */
1528     z = bignum_from_bytes(data, dataLen);
1529     n_bits = bignum_bitcount(publicKey->curve->w.n);
1530     z_bits = bignum_bitcount(z);
1531     if (z_bits > n_bits)
1532     {
1533         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1534         freebn(z);
1535         z = tmp;
1536     }
1537
1538     /* Ensure z in range of n */
1539     {
1540         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1541         freebn(z);
1542         z = tmp;
1543     }
1544
1545     /* Calculate signature */
1546     {
1547         Bignum w, x, u1, u2;
1548         struct ec_point *tmp;
1549
1550         w = modinv(s, publicKey->curve->w.n);
1551         if (!w) {
1552             freebn(z);
1553             return 0;
1554         }
1555         u1 = modmul(z, w, publicKey->curve->w.n);
1556         u2 = modmul(r, w, publicKey->curve->w.n);
1557         freebn(w);
1558
1559         tmp = ecp_summul(u1, u2, publicKey);
1560         freebn(u1);
1561         freebn(u2);
1562         if (!tmp) {
1563             freebn(z);
1564             return 0;
1565         }
1566
1567         x = bigmod(tmp->x, publicKey->curve->w.n);
1568         ec_point_free(tmp);
1569
1570         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1571         freebn(x);
1572     }
1573
1574     freebn(z);
1575
1576     return valid;
1577 }
1578
1579 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1580                         const unsigned char *data, const int dataLen,
1581                         Bignum *r, Bignum *s)
1582 {
1583     unsigned char digest[20];
1584     int z_bits, n_bits;
1585     Bignum z, k;
1586     struct ec_point *kG;
1587
1588     *r = NULL;
1589     *s = NULL;
1590
1591     if (curve->type != EC_WEIERSTRASS) {
1592         return;
1593     }
1594
1595     /* z = left most bitlen(curve->n) of data */
1596     z = bignum_from_bytes(data, dataLen);
1597     n_bits = bignum_bitcount(curve->w.n);
1598     z_bits = bignum_bitcount(z);
1599     if (z_bits > n_bits)
1600     {
1601         Bignum tmp;
1602         tmp = bignum_rshift(z, z_bits - n_bits);
1603         freebn(z);
1604         z = tmp;
1605     }
1606
1607     /* Generate k between 1 and curve->n, using the same deterministic
1608      * k generation system we use for conventional DSA. */
1609     SHA_Simple(data, dataLen, digest);
1610     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1611                   digest, sizeof(digest));
1612
1613     kG = ecp_mul(&curve->w.G, k);
1614     if (!kG) {
1615         freebn(z);
1616         freebn(k);
1617         return;
1618     }
1619
1620     /* r = kG.x mod n */
1621     *r = bigmod(kG->x, curve->w.n);
1622     ec_point_free(kG);
1623
1624     /* s = (z + r * priv)/k mod n */
1625     {
1626         Bignum rPriv, zMod, first, firstMod, kInv;
1627         rPriv = modmul(*r, privateKey, curve->w.n);
1628         zMod = bigmod(z, curve->w.n);
1629         freebn(z);
1630         first = bigadd(rPriv, zMod);
1631         freebn(rPriv);
1632         freebn(zMod);
1633         firstMod = bigmod(first, curve->w.n);
1634         freebn(first);
1635         kInv = modinv(k, curve->w.n);
1636         freebn(k);
1637         if (!kInv) {
1638             freebn(firstMod);
1639             freebn(*r);
1640             return;
1641         }
1642         *s = modmul(firstMod, kInv, curve->w.n);
1643         freebn(firstMod);
1644         freebn(kInv);
1645     }
1646 }
1647
1648 /* ----------------------------------------------------------------------
1649  * Misc functions
1650  */
1651
1652 static void getstring(const char **data, int *datalen,
1653                       const char **p, int *length)
1654 {
1655     *p = NULL;
1656     if (*datalen < 4)
1657         return;
1658     *length = toint(GET_32BIT(*data));
1659     if (*length < 0)
1660         return;
1661     *datalen -= 4;
1662     *data += 4;
1663     if (*datalen < *length)
1664         return;
1665     *p = *data;
1666     *data += *length;
1667     *datalen -= *length;
1668 }
1669
1670 static Bignum getmp(const char **data, int *datalen)
1671 {
1672     const char *p;
1673     int length;
1674
1675     getstring(data, datalen, &p, &length);
1676     if (!p)
1677         return NULL;
1678     if (p[0] & 0x80)
1679         return NULL;                   /* negative mp */
1680     return bignum_from_bytes((unsigned char *)p, length);
1681 }
1682
1683 static Bignum getmp_le(const char **data, int *datalen)
1684 {
1685     const char *p;
1686     int length;
1687
1688     getstring(data, datalen, &p, &length);
1689     if (!p)
1690         return NULL;
1691     return bignum_from_bytes_le((const unsigned char *)p, length);
1692 }
1693
1694 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1695 {
1696     /* Got some conversion to do, first read in the y co-ord */
1697     int negative;
1698
1699     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1700     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1701         freebn(point->y);
1702         point->y = NULL;
1703         return 0;
1704     }
1705     /* Read x bit and then reset it */
1706     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1707     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1708     bn_restore_invariant(point->y);
1709
1710     /* Get the x from the y */
1711     point->x = ecp_edx(point->curve, point->y);
1712     if (!point->x) {
1713         freebn(point->y);
1714         point->y = NULL;
1715         return 0;
1716     }
1717     if (negative) {
1718         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1719         freebn(point->x);
1720         point->x = tmp;
1721     }
1722
1723     /* Verify the point is on the curve */
1724     if (!ec_point_verify(point)) {
1725         freebn(point->x);
1726         point->x = NULL;
1727         freebn(point->y);
1728         point->y = NULL;
1729         return 0;
1730     }
1731
1732     return 1;
1733 }
1734
1735 static int decodepoint(const char *p, int length, struct ec_point *point)
1736 {
1737     if (point->curve->type == EC_EDWARDS) {
1738         return decodepoint_ed(p, length, point);
1739     }
1740
1741     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1742         return 0;
1743     /* Skip compression flag */
1744     ++p;
1745     --length;
1746     /* The two values must be equal length */
1747     if (length % 2 != 0) {
1748         point->x = NULL;
1749         point->y = NULL;
1750         point->z = NULL;
1751         return 0;
1752     }
1753     length = length / 2;
1754     point->x = bignum_from_bytes((const unsigned char *)p, length);
1755     p += length;
1756     point->y = bignum_from_bytes((const unsigned char *)p, length);
1757     point->z = NULL;
1758
1759     /* Verify the point is on the curve */
1760     if (!ec_point_verify(point)) {
1761         freebn(point->x);
1762         point->x = NULL;
1763         freebn(point->y);
1764         point->y = NULL;
1765         return 0;
1766     }
1767
1768     return 1;
1769 }
1770
1771 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1772 {
1773     const char *p;
1774     int length;
1775
1776     getstring(data, datalen, &p, &length);
1777     if (!p) return 0;
1778     return decodepoint(p, length, point);
1779 }
1780
1781 /* ----------------------------------------------------------------------
1782  * Exposed ECDSA interface
1783  */
1784
1785 struct ecsign_extra {
1786     struct ec_curve *(*curve)(void);
1787     const struct ssh_hash *hash;
1788
1789     /* These fields are used by the OpenSSH PEM format importer/exporter */
1790     const unsigned char *oid;
1791     int oidlen;
1792 };
1793
1794 static void ecdsa_freekey(void *key)
1795 {
1796     struct ec_key *ec = (struct ec_key *) key;
1797     if (!ec) return;
1798
1799     if (ec->publicKey.x)
1800         freebn(ec->publicKey.x);
1801     if (ec->publicKey.y)
1802         freebn(ec->publicKey.y);
1803     if (ec->publicKey.z)
1804         freebn(ec->publicKey.z);
1805     if (ec->privateKey)
1806         freebn(ec->privateKey);
1807     sfree(ec);
1808 }
1809
1810 static void *ecdsa_newkey(const struct ssh_signkey *self,
1811                           const char *data, int len)
1812 {
1813     const struct ecsign_extra *extra =
1814         (const struct ecsign_extra *)self->extra;
1815     const char *p;
1816     int slen;
1817     struct ec_key *ec;
1818     struct ec_curve *curve;
1819
1820     getstring(&data, &len, &p, &slen);
1821
1822     if (!p) {
1823         return NULL;
1824     }
1825     curve = extra->curve();
1826     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1827
1828     /* Curve name is duplicated for Weierstrass form */
1829     if (curve->type == EC_WEIERSTRASS) {
1830         getstring(&data, &len, &p, &slen);
1831         if (!p) return NULL;
1832         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1833     }
1834
1835     ec = snew(struct ec_key);
1836
1837     ec->signalg = self;
1838     ec->publicKey.curve = curve;
1839     ec->publicKey.infinity = 0;
1840     ec->publicKey.x = NULL;
1841     ec->publicKey.y = NULL;
1842     ec->publicKey.z = NULL;
1843     ec->privateKey = NULL;
1844     if (!getmppoint(&data, &len, &ec->publicKey)) {
1845         ecdsa_freekey(ec);
1846         return NULL;
1847     }
1848
1849     if (!ec->publicKey.x || !ec->publicKey.y ||
1850         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1851         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1852     {
1853         ecdsa_freekey(ec);
1854         ec = NULL;
1855     }
1856
1857     return ec;
1858 }
1859
1860 static char *ecdsa_fmtkey(void *key)
1861 {
1862     struct ec_key *ec = (struct ec_key *) key;
1863     char *p;
1864     int len, i, pos, nibbles;
1865     static const char hex[] = "0123456789abcdef";
1866     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1867         return NULL;
1868
1869     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1870     if (ec->publicKey.curve->name)
1871         len += strlen(ec->publicKey.curve->name); /* Curve name */
1872     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1873     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1874     p = snewn(len, char);
1875
1876     pos = 0;
1877     if (ec->publicKey.curve->name)
1878         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1879     pos += sprintf(p + pos, "0x");
1880     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1881     if (nibbles < 1)
1882         nibbles = 1;
1883     for (i = nibbles; i--;) {
1884         p[pos++] =
1885             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1886     }
1887     pos += sprintf(p + pos, ",0x");
1888     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1889     if (nibbles < 1)
1890         nibbles = 1;
1891     for (i = nibbles; i--;) {
1892         p[pos++] =
1893             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1894     }
1895     p[pos] = '\0';
1896     return p;
1897 }
1898
1899 static unsigned char *ecdsa_public_blob(void *key, int *len)
1900 {
1901     struct ec_key *ec = (struct ec_key *) key;
1902     int pointlen, bloblen, fullnamelen, namelen;
1903     int i;
1904     unsigned char *blob, *p;
1905
1906     fullnamelen = strlen(ec->signalg->name);
1907
1908     if (ec->publicKey.curve->type == EC_EDWARDS) {
1909         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1910
1911         pointlen = ec->publicKey.curve->fieldBits / 8;
1912
1913         /* Can't handle this in our loop */
1914         if (pointlen < 2) return NULL;
1915
1916         bloblen = 4 + fullnamelen + 4 + pointlen;
1917         blob = snewn(bloblen, unsigned char);
1918
1919         p = blob;
1920         PUT_32BIT(p, fullnamelen);
1921         p += 4;
1922         memcpy(p, ec->signalg->name, fullnamelen);
1923         p += fullnamelen;
1924         PUT_32BIT(p, pointlen);
1925         p += 4;
1926
1927         /* Unset last bit of y and set first bit of x in its place */
1928         for (i = 0; i < pointlen - 1; ++i) {
1929             *p++ = bignum_byte(ec->publicKey.y, i);
1930         }
1931         /* Unset last bit of y and set first bit of x in its place */
1932         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1933         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1934     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1935         assert(ec->publicKey.curve->name);
1936         namelen = strlen(ec->publicKey.curve->name);
1937
1938         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1939
1940         /*
1941          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1942          */
1943         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1944         blob = snewn(bloblen, unsigned char);
1945
1946         p = blob;
1947         PUT_32BIT(p, fullnamelen);
1948         p += 4;
1949         memcpy(p, ec->signalg->name, fullnamelen);
1950         p += fullnamelen;
1951         PUT_32BIT(p, namelen);
1952         p += 4;
1953         memcpy(p, ec->publicKey.curve->name, namelen);
1954         p += namelen;
1955         PUT_32BIT(p, (2 * pointlen) + 1);
1956         p += 4;
1957         *p++ = 0x04;
1958         for (i = pointlen; i--;) {
1959             *p++ = bignum_byte(ec->publicKey.x, i);
1960         }
1961         for (i = pointlen; i--;) {
1962             *p++ = bignum_byte(ec->publicKey.y, i);
1963         }
1964     } else {
1965         return NULL;
1966     }
1967
1968     assert(p == blob + bloblen);
1969     *len = bloblen;
1970
1971     return blob;
1972 }
1973
1974 static unsigned char *ecdsa_private_blob(void *key, int *len)
1975 {
1976     struct ec_key *ec = (struct ec_key *) key;
1977     int keylen, bloblen;
1978     int i;
1979     unsigned char *blob, *p;
1980
1981     if (!ec->privateKey) return NULL;
1982
1983     if (ec->publicKey.curve->type == EC_EDWARDS) {
1984         /* Unsigned */
1985         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1986     } else {
1987         /* Signed */
1988         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1989     }
1990
1991     /*
1992      * mpint privateKey. Total 4 + keylen.
1993      */
1994     bloblen = 4 + keylen;
1995     blob = snewn(bloblen, unsigned char);
1996
1997     p = blob;
1998     PUT_32BIT(p, keylen);
1999     p += 4;
2000     if (ec->publicKey.curve->type == EC_EDWARDS) {
2001         /* Little endian */
2002         for (i = 0; i < keylen; ++i)
2003             *p++ = bignum_byte(ec->privateKey, i);
2004     } else {
2005         for (i = keylen; i--;)
2006             *p++ = bignum_byte(ec->privateKey, i);
2007     }
2008
2009     assert(p == blob + bloblen);
2010     *len = bloblen;
2011     return blob;
2012 }
2013
2014 static void *ecdsa_createkey(const struct ssh_signkey *self,
2015                              const unsigned char *pub_blob, int pub_len,
2016                              const unsigned char *priv_blob, int priv_len)
2017 {
2018     struct ec_key *ec;
2019     struct ec_point *publicKey;
2020     const char *pb = (const char *) priv_blob;
2021
2022     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
2023     if (!ec) {
2024         return NULL;
2025     }
2026
2027     if (ec->publicKey.curve->type != EC_WEIERSTRASS
2028         && ec->publicKey.curve->type != EC_EDWARDS) {
2029         ecdsa_freekey(ec);
2030         return NULL;
2031     }
2032
2033     if (ec->publicKey.curve->type == EC_EDWARDS) {
2034         ec->privateKey = getmp_le(&pb, &priv_len);
2035     } else {
2036         ec->privateKey = getmp(&pb, &priv_len);
2037     }
2038     if (!ec->privateKey) {
2039         ecdsa_freekey(ec);
2040         return NULL;
2041     }
2042
2043     /* Check that private key generates public key */
2044     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2045
2046     if (!publicKey ||
2047         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2048         bignum_cmp(publicKey->y, ec->publicKey.y))
2049     {
2050         ecdsa_freekey(ec);
2051         ec = NULL;
2052     }
2053     ec_point_free(publicKey);
2054
2055     return ec;
2056 }
2057
2058 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2059                                        const unsigned char **blob, int *len)
2060 {
2061     struct ec_key *ec;
2062     struct ec_point *publicKey;
2063     const char *p, *q;
2064     int plen, qlen;
2065
2066     getstring((const char**)blob, len, &p, &plen);
2067     if (!p)
2068     {
2069         return NULL;
2070     }
2071
2072     ec = snew(struct ec_key);
2073
2074     ec->signalg = self;
2075     ec->publicKey.curve = ec_ed25519();
2076     ec->publicKey.infinity = 0;
2077     ec->privateKey = NULL;
2078     ec->publicKey.x = NULL;
2079     ec->publicKey.z = NULL;
2080     ec->publicKey.y = NULL;
2081
2082     if (!decodepoint_ed(p, plen, &ec->publicKey))
2083     {
2084         ecdsa_freekey(ec);
2085         return NULL;
2086     }
2087
2088     getstring((const char**)blob, len, &q, &qlen);
2089     if (!q)
2090         return NULL;
2091     if (qlen != 64)
2092         return NULL;
2093
2094     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2095
2096     /* Check that private key generates public key */
2097     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2098
2099     if (!publicKey ||
2100         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2101         bignum_cmp(publicKey->y, ec->publicKey.y))
2102     {
2103         ecdsa_freekey(ec);
2104         ec = NULL;
2105     }
2106     ec_point_free(publicKey);
2107
2108     /* The OpenSSH format for ed25519 private keys also for some
2109      * reason encodes an extra copy of the public key in the second
2110      * half of the secret-key string. Check that that's present and
2111      * correct as well, otherwise the key we think we've imported
2112      * won't behave identically to the way OpenSSH would have treated
2113      * it. */
2114     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2115         ecdsa_freekey(ec);
2116         return NULL;
2117     }
2118
2119     return ec;
2120 }
2121
2122 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2123 {
2124     struct ec_key *ec = (struct ec_key *) key;
2125
2126     int pointlen;
2127     int keylen;
2128     int bloblen;
2129     int i;
2130
2131     if (ec->publicKey.curve->type != EC_EDWARDS) {
2132         return 0;
2133     }
2134
2135     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2136     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2137     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2138
2139     if (bloblen > len)
2140         return bloblen;
2141
2142     /* Encode the public point */
2143     PUT_32BIT(blob, pointlen);
2144     blob += 4;
2145
2146     for (i = 0; i < pointlen - 1; ++i) {
2147          *blob++ = bignum_byte(ec->publicKey.y, i);
2148     }
2149     /* Unset last bit of y and set first bit of x in its place */
2150     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2151     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2152
2153     PUT_32BIT(blob, keylen + pointlen);
2154     blob += 4;
2155     for (i = 0; i < keylen; ++i) {
2156          *blob++ = bignum_byte(ec->privateKey, i);
2157     }
2158     /* Now encode an extra copy of the public point as the second half
2159      * of the private key string, as the OpenSSH format for some
2160      * reason requires */
2161     for (i = 0; i < pointlen - 1; ++i) {
2162          *blob++ = bignum_byte(ec->publicKey.y, i);
2163     }
2164     /* Unset last bit of y and set first bit of x in its place */
2165     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2166     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2167
2168     return bloblen;
2169 }
2170
2171 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2172                                      const unsigned char **blob, int *len)
2173 {
2174     const struct ecsign_extra *extra =
2175         (const struct ecsign_extra *)self->extra;
2176     const char **b = (const char **) blob;
2177     const char *p;
2178     int slen;
2179     struct ec_key *ec;
2180     struct ec_curve *curve;
2181     struct ec_point *publicKey;
2182
2183     getstring(b, len, &p, &slen);
2184
2185     if (!p) {
2186         return NULL;
2187     }
2188     curve = extra->curve();
2189     assert(curve->type == EC_WEIERSTRASS);
2190
2191     ec = snew(struct ec_key);
2192
2193     ec->signalg = self;
2194     ec->publicKey.curve = curve;
2195     ec->publicKey.infinity = 0;
2196     ec->publicKey.x = NULL;
2197     ec->publicKey.y = NULL;
2198     ec->publicKey.z = NULL;
2199     if (!getmppoint(b, len, &ec->publicKey)) {
2200         ecdsa_freekey(ec);
2201         return NULL;
2202     }
2203     ec->privateKey = NULL;
2204
2205     if (!ec->publicKey.x || !ec->publicKey.y ||
2206         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2207         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2208     {
2209         ecdsa_freekey(ec);
2210         return NULL;
2211     }
2212
2213     ec->privateKey = getmp(b, len);
2214     if (ec->privateKey == NULL)
2215     {
2216         ecdsa_freekey(ec);
2217         return NULL;
2218     }
2219
2220     /* Now check that the private key makes the public key */
2221     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2222     if (!publicKey)
2223     {
2224         ecdsa_freekey(ec);
2225         return NULL;
2226     }
2227
2228     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2229         bignum_cmp(ec->publicKey.y, publicKey->y))
2230     {
2231         /* Private key doesn't make the public key on the given curve */
2232         ecdsa_freekey(ec);
2233         ec_point_free(publicKey);
2234         return NULL;
2235     }
2236
2237     ec_point_free(publicKey);
2238
2239     return ec;
2240 }
2241
2242 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2243 {
2244     struct ec_key *ec = (struct ec_key *) key;
2245
2246     int pointlen;
2247     int namelen;
2248     int bloblen;
2249     int i;
2250
2251     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2252         return 0;
2253     }
2254
2255     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2256     namelen = strlen(ec->publicKey.curve->name);
2257     bloblen =
2258         4 + namelen /* <LEN> nistpXXX */
2259         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2260         + ssh2_bignum_length(ec->privateKey);
2261
2262     if (bloblen > len)
2263         return bloblen;
2264
2265     bloblen = 0;
2266
2267     PUT_32BIT(blob+bloblen, namelen);
2268     bloblen += 4;
2269     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2270     bloblen += namelen;
2271
2272     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2273     bloblen += 4;
2274     blob[bloblen++] = 0x04;
2275     for (i = pointlen; i--; )
2276         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2277     for (i = pointlen; i--; )
2278         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2279
2280     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2281     PUT_32BIT(blob+bloblen, pointlen);
2282     bloblen += 4;
2283     for (i = pointlen; i--; )
2284         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2285
2286     return bloblen;
2287 }
2288
2289 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2290                              const void *blob, int len)
2291 {
2292     struct ec_key *ec;
2293     int ret;
2294
2295     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2296     if (!ec)
2297         return -1;
2298     ret = ec->publicKey.curve->fieldBits;
2299     ecdsa_freekey(ec);
2300
2301     return ret;
2302 }
2303
2304 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2305                            const char *data, int datalen)
2306 {
2307     struct ec_key *ec = (struct ec_key *) key;
2308     const struct ecsign_extra *extra =
2309         (const struct ecsign_extra *)ec->signalg->extra;
2310     const char *p;
2311     int slen;
2312     int digestLen;
2313     int ret;
2314
2315     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2316         return 0;
2317
2318     /* Check the signature starts with the algorithm name */
2319     getstring(&sig, &siglen, &p, &slen);
2320     if (!p) {
2321         return 0;
2322     }
2323     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2324         return 0;
2325     }
2326
2327     getstring(&sig, &siglen, &p, &slen);
2328     if (!p) return 0;
2329     if (ec->publicKey.curve->type == EC_EDWARDS) {
2330         struct ec_point *r;
2331         Bignum s, h;
2332
2333         /* Check that the signature is two times the length of a point */
2334         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2335             return 0;
2336         }
2337
2338         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2339         if (ec->publicKey.curve->fieldBits != 256) {
2340             return 0;
2341         }
2342
2343         /* Get the signature */
2344         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2345         if (!r) {
2346             return 0;
2347         }
2348         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2349             ec_point_free(r);
2350             return 0;
2351         }
2352         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2353                                  ec->publicKey.curve->fieldBits / 8);
2354
2355         /* Get the hash of the encoded value of R + encoded value of pk + message */
2356         {
2357             int i, pointlen;
2358             unsigned char b;
2359             unsigned char digest[512 / 8];
2360             SHA512_State hs;
2361             SHA512_Init(&hs);
2362
2363             /* Add encoded r (no need to encode it again, it was in the signature) */
2364             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2365
2366             /* Encode pk and add it */
2367             pointlen = ec->publicKey.curve->fieldBits / 8;
2368             for (i = 0; i < pointlen - 1; ++i) {
2369                 b = bignum_byte(ec->publicKey.y, i);
2370                 SHA512_Bytes(&hs, &b, 1);
2371             }
2372             /* Unset last bit of y and set first bit of x in its place */
2373             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2374             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2375             SHA512_Bytes(&hs, &b, 1);
2376
2377             /* Add the message itself */
2378             SHA512_Bytes(&hs, data, datalen);
2379
2380             /* Get the hash */
2381             SHA512_Final(&hs, digest);
2382
2383             /* Convert to Bignum */
2384             h = bignum_from_bytes_le(digest, sizeof(digest));
2385         }
2386
2387         /* Verify sB == r + h*publicKey */
2388         {
2389             struct ec_point *lhs, *rhs, *tmp;
2390
2391             /* lhs = sB */
2392             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2393             freebn(s);
2394             if (!lhs) {
2395                 ec_point_free(r);
2396                 freebn(h);
2397                 return 0;
2398             }
2399
2400             /* rhs = r + h*publicKey */
2401             tmp = ecp_mul(&ec->publicKey, h);
2402             freebn(h);
2403             if (!tmp) {
2404                 ec_point_free(lhs);
2405                 ec_point_free(r);
2406                 return 0;
2407             }
2408             rhs = ecp_add(r, tmp, 0);
2409             ec_point_free(r);
2410             ec_point_free(tmp);
2411             if (!rhs) {
2412                 ec_point_free(lhs);
2413                 return 0;
2414             }
2415
2416             /* Check the point is the same */
2417             ret = !bignum_cmp(lhs->x, rhs->x);
2418             if (ret) {
2419                 ret = !bignum_cmp(lhs->y, rhs->y);
2420                 if (ret) {
2421                     ret = 1;
2422                 }
2423             }
2424             ec_point_free(lhs);
2425             ec_point_free(rhs);
2426         }
2427     } else {
2428         Bignum r, s;
2429         unsigned char digest[512 / 8];
2430         void *hashctx;
2431
2432         r = getmp(&p, &slen);
2433         if (!r) return 0;
2434         s = getmp(&p, &slen);
2435         if (!s) {
2436             freebn(r);
2437             return 0;
2438         }
2439
2440         digestLen = extra->hash->hlen;
2441         assert(digestLen <= sizeof(digest));
2442         hashctx = extra->hash->init();
2443         extra->hash->bytes(hashctx, data, datalen);
2444         extra->hash->final(hashctx, digest);
2445
2446         /* Verify the signature */
2447         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2448
2449         freebn(r);
2450         freebn(s);
2451     }
2452
2453     return ret;
2454 }
2455
2456 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2457                                  int *siglen)
2458 {
2459     struct ec_key *ec = (struct ec_key *) key;
2460     const struct ecsign_extra *extra =
2461         (const struct ecsign_extra *)ec->signalg->extra;
2462     unsigned char digest[512 / 8];
2463     int digestLen;
2464     Bignum r = NULL, s = NULL;
2465     unsigned char *buf, *p;
2466     int rlen, slen, namelen;
2467     int i;
2468
2469     if (!ec->privateKey || !ec->publicKey.curve) {
2470         return NULL;
2471     }
2472
2473     if (ec->publicKey.curve->type == EC_EDWARDS) {
2474         struct ec_point *rp;
2475         int pointlen = ec->publicKey.curve->fieldBits / 8;
2476
2477         /* hash = H(sk) (where hash creates 2 * fieldBits)
2478          * b = fieldBits
2479          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2480          * r = H(h[b/8:b/4] + m)
2481          * R = rB
2482          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2483         {
2484             unsigned char hash[512/8];
2485             unsigned char b;
2486             Bignum a;
2487             SHA512_State hs;
2488             SHA512_Init(&hs);
2489
2490             for (i = 0; i < pointlen; ++i) {
2491                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2492                 SHA512_Bytes(&hs, &b, 1);
2493             }
2494
2495             SHA512_Final(&hs, hash);
2496
2497             /* The second part is simply turning the hash into a
2498              * Bignum, however the 2^(b-2) bit *must* be set, and the
2499              * bottom 3 bits *must* not be */
2500             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2501             hash[31] &= 0x7f; /* Unset above (b-2) */
2502             hash[31] |= 0x40; /* Set 2^(b-2) */
2503             /* Chop off the top part and convert to int */
2504             a = bignum_from_bytes_le(hash, 32);
2505
2506             SHA512_Init(&hs);
2507             SHA512_Bytes(&hs,
2508                          hash+(ec->publicKey.curve->fieldBits / 8),
2509                          (ec->publicKey.curve->fieldBits / 4)
2510                          - (ec->publicKey.curve->fieldBits / 8));
2511             SHA512_Bytes(&hs, data, datalen);
2512             SHA512_Final(&hs, hash);
2513
2514             r = bignum_from_bytes_le(hash, 512/8);
2515             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2516             if (!rp) {
2517                 freebn(r);
2518                 freebn(a);
2519                 return NULL;
2520             }
2521
2522             /* Now calculate s */
2523             SHA512_Init(&hs);
2524             /* Encode the point R */
2525             for (i = 0; i < pointlen - 1; ++i) {
2526                 b = bignum_byte(rp->y, i);
2527                 SHA512_Bytes(&hs, &b, 1);
2528             }
2529             /* Unset last bit of y and set first bit of x in its place */
2530             b = bignum_byte(rp->y, i) & 0x7f;
2531             b |= bignum_bit(rp->x, 0) << 7;
2532             SHA512_Bytes(&hs, &b, 1);
2533
2534             /* Encode the point pk */
2535             for (i = 0; i < pointlen - 1; ++i) {
2536                 b = bignum_byte(ec->publicKey.y, i);
2537                 SHA512_Bytes(&hs, &b, 1);
2538             }
2539             /* Unset last bit of y and set first bit of x in its place */
2540             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2541             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2542             SHA512_Bytes(&hs, &b, 1);
2543
2544             /* Add the message */
2545             SHA512_Bytes(&hs, data, datalen);
2546             SHA512_Final(&hs, hash);
2547
2548             {
2549                 Bignum tmp, tmp2;
2550
2551                 tmp = bignum_from_bytes_le(hash, 512/8);
2552                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2553                 freebn(a);
2554                 freebn(tmp);
2555                 tmp = bigadd(r, tmp2);
2556                 freebn(r);
2557                 freebn(tmp2);
2558                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2559                 freebn(tmp);
2560             }
2561         }
2562
2563         /* Format the output */
2564         namelen = strlen(ec->signalg->name);
2565         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2566         buf = snewn(*siglen, unsigned char);
2567         p = buf;
2568         PUT_32BIT(p, namelen);
2569         p += 4;
2570         memcpy(p, ec->signalg->name, namelen);
2571         p += namelen;
2572         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2573         p += 4;
2574
2575         /* Encode the point */
2576         pointlen = ec->publicKey.curve->fieldBits / 8;
2577         for (i = 0; i < pointlen - 1; ++i) {
2578             *p++ = bignum_byte(rp->y, i);
2579         }
2580         /* Unset last bit of y and set first bit of x in its place */
2581         *p = bignum_byte(rp->y, i) & 0x7f;
2582         *p++ |= bignum_bit(rp->x, 0) << 7;
2583         ec_point_free(rp);
2584
2585         /* Encode the int */
2586         for (i = 0; i < pointlen; ++i) {
2587             *p++ = bignum_byte(s, i);
2588         }
2589         freebn(s);
2590     } else {
2591         void *hashctx;
2592
2593         digestLen = extra->hash->hlen;
2594         assert(digestLen <= sizeof(digest));
2595         hashctx = extra->hash->init();
2596         extra->hash->bytes(hashctx, data, datalen);
2597         extra->hash->final(hashctx, digest);
2598
2599         /* Do the signature */
2600         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2601         if (!r || !s) {
2602             if (r) freebn(r);
2603             if (s) freebn(s);
2604             return NULL;
2605         }
2606
2607         rlen = (bignum_bitcount(r) + 8) / 8;
2608         slen = (bignum_bitcount(s) + 8) / 8;
2609
2610         namelen = strlen(ec->signalg->name);
2611
2612         /* Format the output */
2613         *siglen = 8+namelen+rlen+slen+8;
2614         buf = snewn(*siglen, unsigned char);
2615         p = buf;
2616         PUT_32BIT(p, namelen);
2617         p += 4;
2618         memcpy(p, ec->signalg->name, namelen);
2619         p += namelen;
2620         PUT_32BIT(p, rlen + slen + 8);
2621         p += 4;
2622         PUT_32BIT(p, rlen);
2623         p += 4;
2624         for (i = rlen; i--;)
2625             *p++ = bignum_byte(r, i);
2626         PUT_32BIT(p, slen);
2627         p += 4;
2628         for (i = slen; i--;)
2629             *p++ = bignum_byte(s, i);
2630
2631         freebn(r);
2632         freebn(s);
2633     }
2634
2635     return buf;
2636 }
2637
2638 const struct ecsign_extra sign_extra_ed25519 = {
2639     ec_ed25519, NULL,
2640     NULL, 0,
2641 };
2642 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2643     ecdsa_newkey,
2644     ecdsa_freekey,
2645     ecdsa_fmtkey,
2646     ecdsa_public_blob,
2647     ecdsa_private_blob,
2648     ecdsa_createkey,
2649     ed25519_openssh_createkey,
2650     ed25519_openssh_fmtkey,
2651     2 /* point, private exponent */,
2652     ecdsa_pubkey_bits,
2653     ecdsa_verifysig,
2654     ecdsa_sign,
2655     "ssh-ed25519",
2656     "ssh-ed25519",
2657     &sign_extra_ed25519,
2658 };
2659
2660 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2661 static const unsigned char nistp256_oid[] = {
2662     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2663 };
2664 const struct ecsign_extra sign_extra_nistp256 = {
2665     ec_p256, &ssh_sha256,
2666     nistp256_oid, lenof(nistp256_oid),
2667 };
2668 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2669     ecdsa_newkey,
2670     ecdsa_freekey,
2671     ecdsa_fmtkey,
2672     ecdsa_public_blob,
2673     ecdsa_private_blob,
2674     ecdsa_createkey,
2675     ecdsa_openssh_createkey,
2676     ecdsa_openssh_fmtkey,
2677     3 /* curve name, point, private exponent */,
2678     ecdsa_pubkey_bits,
2679     ecdsa_verifysig,
2680     ecdsa_sign,
2681     "ecdsa-sha2-nistp256",
2682     "ecdsa-sha2-nistp256",
2683     &sign_extra_nistp256,
2684 };
2685
2686 /* OID: 1.3.132.0.34 (secp384r1) */
2687 static const unsigned char nistp384_oid[] = {
2688     0x2b, 0x81, 0x04, 0x00, 0x22
2689 };
2690 const struct ecsign_extra sign_extra_nistp384 = {
2691     ec_p384, &ssh_sha384,
2692     nistp384_oid, lenof(nistp384_oid),
2693 };
2694 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2695     ecdsa_newkey,
2696     ecdsa_freekey,
2697     ecdsa_fmtkey,
2698     ecdsa_public_blob,
2699     ecdsa_private_blob,
2700     ecdsa_createkey,
2701     ecdsa_openssh_createkey,
2702     ecdsa_openssh_fmtkey,
2703     3 /* curve name, point, private exponent */,
2704     ecdsa_pubkey_bits,
2705     ecdsa_verifysig,
2706     ecdsa_sign,
2707     "ecdsa-sha2-nistp384",
2708     "ecdsa-sha2-nistp384",
2709     &sign_extra_nistp384,
2710 };
2711
2712 /* OID: 1.3.132.0.35 (secp521r1) */
2713 static const unsigned char nistp521_oid[] = {
2714     0x2b, 0x81, 0x04, 0x00, 0x23
2715 };
2716 const struct ecsign_extra sign_extra_nistp521 = {
2717     ec_p521, &ssh_sha512,
2718     nistp521_oid, lenof(nistp521_oid),
2719 };
2720 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2721     ecdsa_newkey,
2722     ecdsa_freekey,
2723     ecdsa_fmtkey,
2724     ecdsa_public_blob,
2725     ecdsa_private_blob,
2726     ecdsa_createkey,
2727     ecdsa_openssh_createkey,
2728     ecdsa_openssh_fmtkey,
2729     3 /* curve name, point, private exponent */,
2730     ecdsa_pubkey_bits,
2731     ecdsa_verifysig,
2732     ecdsa_sign,
2733     "ecdsa-sha2-nistp521",
2734     "ecdsa-sha2-nistp521",
2735     &sign_extra_nistp521,
2736 };
2737
2738 /* ----------------------------------------------------------------------
2739  * Exposed ECDH interface
2740  */
2741
2742 struct eckex_extra {
2743     struct ec_curve *(*curve)(void);
2744     int low_byte_mask, high_byte_top_bit;
2745 };
2746
2747 static Bignum ecdh_calculate(const Bignum private,
2748                              const struct ec_point *public)
2749 {
2750     struct ec_point *p;
2751     Bignum ret;
2752     p = ecp_mul(public, private);
2753     if (!p) return NULL;
2754     ret = p->x;
2755     p->x = NULL;
2756
2757     if (p->curve->type == EC_MONTGOMERY) {
2758         /*
2759          * Endianness-swap. The Curve25519 algorithm definition
2760          * assumes you were doing your computation in arrays of 32
2761          * little-endian bytes, and now specifies that you take your
2762          * final one of those and convert it into a bignum in
2763          * _network_ byte order, i.e. big-endian.
2764          *
2765          * In particular, the spec says, you convert the _whole_ 32
2766          * bytes into a bignum. That is, on the rare occasions that
2767          * p->x has come out with the most significant 8 bits zero, we
2768          * have to imagine that being represented by a 32-byte string
2769          * with the last byte being zero, so that has to be converted
2770          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2771          * multiple of 256.
2772          */
2773         int i;
2774         int bytes = (p->curve->fieldBits+7) / 8;
2775         unsigned char *byteorder = snewn(bytes, unsigned char);
2776         for (i = 0; i < bytes; ++i) {
2777             byteorder[i] = bignum_byte(ret, i);
2778         }
2779         freebn(ret);
2780         ret = bignum_from_bytes(byteorder, bytes);
2781         smemclr(byteorder, bytes);
2782         sfree(byteorder);
2783     }
2784
2785     ec_point_free(p);
2786     return ret;
2787 }
2788
2789 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2790 {
2791     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2792     struct ec_curve *curve = extra->curve();
2793     return curve->textname;
2794 }
2795
2796 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2797 {
2798     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2799     struct ec_curve *curve;
2800     struct ec_key *key;
2801     struct ec_point *publicKey;
2802
2803     curve = extra->curve();
2804
2805     key = snew(struct ec_key);
2806
2807     key->signalg = NULL;
2808     key->publicKey.curve = curve;
2809
2810     if (curve->type == EC_MONTGOMERY) {
2811         int nbytes = (curve->fieldBits+7) / 8;
2812         unsigned char bytes[56] = {0};
2813         int i;
2814
2815         assert(nbytes <= lenof(bytes));
2816
2817         for (i = 0; i < nbytes; ++i)
2818         {
2819             bytes[i] = (unsigned char)random_byte();
2820         }
2821         bytes[0] &= extra->low_byte_mask;
2822         bytes[nbytes-1] &= extra->high_byte_top_bit - 1;
2823         bytes[nbytes-1] |= extra->high_byte_top_bit;
2824         key->privateKey = bignum_from_bytes_le(bytes, nbytes);
2825         smemclr(bytes, nbytes);
2826         if (!key->privateKey) {
2827             sfree(key);
2828             return NULL;
2829         }
2830         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2831         if (!publicKey) {
2832             freebn(key->privateKey);
2833             sfree(key);
2834             return NULL;
2835         }
2836         key->publicKey.x = publicKey->x;
2837         key->publicKey.y = publicKey->y;
2838         key->publicKey.z = NULL;
2839         sfree(publicKey);
2840     } else {
2841         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2842         if (!key->privateKey) {
2843             sfree(key);
2844             return NULL;
2845         }
2846         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2847         if (!publicKey) {
2848             freebn(key->privateKey);
2849             sfree(key);
2850             return NULL;
2851         }
2852         key->publicKey.x = publicKey->x;
2853         key->publicKey.y = publicKey->y;
2854         key->publicKey.z = NULL;
2855         sfree(publicKey);
2856     }
2857     return key;
2858 }
2859
2860 char *ssh_ecdhkex_getpublic(void *key, int *len)
2861 {
2862     struct ec_key *ec = (struct ec_key*)key;
2863     char *point, *p;
2864     int i;
2865     int pointlen;
2866
2867     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2868
2869     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2870         *len = 1 + pointlen * 2;
2871     } else {
2872         *len = pointlen;
2873     }
2874     point = (char*)snewn(*len, char);
2875
2876     p = point;
2877     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2878         *p++ = 0x04;
2879         for (i = pointlen; i--;) {
2880             *p++ = bignum_byte(ec->publicKey.x, i);
2881         }
2882         for (i = pointlen; i--;) {
2883             *p++ = bignum_byte(ec->publicKey.y, i);
2884         }
2885     } else {
2886         for (i = 0; i < pointlen; ++i) {
2887             *p++ = bignum_byte(ec->publicKey.x, i);
2888         }
2889     }
2890
2891     return point;
2892 }
2893
2894 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2895 {
2896     struct ec_key *ec = (struct ec_key*) key;
2897     struct ec_point remote;
2898     Bignum ret;
2899
2900     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2901         remote.curve = ec->publicKey.curve;
2902         remote.infinity = 0;
2903         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2904             return NULL;
2905         }
2906     } else {
2907         /* Point length has to be the same length */
2908         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2909             return NULL;
2910         }
2911
2912         remote.curve = ec->publicKey.curve;
2913         remote.infinity = 0;
2914         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2915         remote.y = NULL;
2916         remote.z = NULL;
2917     }
2918
2919     ret = ecdh_calculate(ec->privateKey, &remote);
2920     if (remote.x) freebn(remote.x);
2921     if (remote.y) freebn(remote.y);
2922     return ret;
2923 }
2924
2925 void ssh_ecdhkex_freekey(void *key)
2926 {
2927     ecdsa_freekey(key);
2928 }
2929
2930 static const struct eckex_extra kex_extra_curve25519 = {
2931     ec_curve25519, 0xF8, 0x40,
2932 };
2933 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2934     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2935     &ssh_sha256, &kex_extra_curve25519,
2936 };
2937
2938 static const struct eckex_extra kex_extra_curve448 = {
2939     ec_curve448, 0xFC, 0x80,
2940 };
2941 static const struct ssh_kex ssh_ec_kex_curve448 = {
2942     "curve448-sha512", NULL, KEXTYPE_ECDH,
2943     &ssh_sha512, &kex_extra_curve448,
2944 };
2945
2946 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2947 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2948     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2949     &ssh_sha256, &kex_extra_nistp256,
2950 };
2951
2952 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2953 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2954     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2955     &ssh_sha384, &kex_extra_nistp384,
2956 };
2957
2958 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2959 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2960     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2961     &ssh_sha512, &kex_extra_nistp521,
2962 };
2963
2964 static const struct ssh_kex *const ec_kex_list[] = {
2965     &ssh_ec_kex_curve448,
2966     &ssh_ec_kex_curve25519,
2967     &ssh_ec_kex_nistp256,
2968     &ssh_ec_kex_nistp384,
2969     &ssh_ec_kex_nistp521,
2970 };
2971
2972 const struct ssh_kexes ssh_ecdh_kex = {
2973     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2974     ec_kex_list
2975 };
2976
2977 /* ----------------------------------------------------------------------
2978  * Helper functions for finding key algorithms and returning auxiliary
2979  * data.
2980  */
2981
2982 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2983                                         const struct ec_curve **curve)
2984 {
2985     static const struct ssh_signkey *algs_with_oid[] = {
2986         &ssh_ecdsa_nistp256,
2987         &ssh_ecdsa_nistp384,
2988         &ssh_ecdsa_nistp521,
2989     };
2990     int i;
2991
2992     for (i = 0; i < lenof(algs_with_oid); i++) {
2993         const struct ssh_signkey *alg = algs_with_oid[i];
2994         const struct ecsign_extra *extra =
2995             (const struct ecsign_extra *)alg->extra;
2996         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2997             *curve = extra->curve();
2998             return alg;
2999         }
3000     }
3001     return NULL;
3002 }
3003
3004 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
3005                                 int *oidlen)
3006 {
3007     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
3008     *oidlen = extra->oidlen;
3009     return extra->oid;
3010 }
3011
3012 const int ec_nist_curve_lengths[] = { 256, 384, 521 };
3013 const int n_ec_nist_curve_lengths = lenof(ec_nist_curve_lengths);
3014
3015 const int ec_nist_alg_and_curve_by_bits(int bits,
3016                                         const struct ec_curve **curve,
3017                                         const struct ssh_signkey **alg)
3018 {
3019     switch (bits) {
3020       case 256: *alg = &ssh_ecdsa_nistp256; break;
3021       case 384: *alg = &ssh_ecdsa_nistp384; break;
3022       case 521: *alg = &ssh_ecdsa_nistp521; break;
3023       default: return FALSE;
3024     }
3025     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3026     return TRUE;
3027 }
3028
3029 const int ec_ed_alg_and_curve_by_bits(int bits,
3030                                       const struct ec_curve **curve,
3031                                       const struct ssh_signkey **alg)
3032 {
3033     switch (bits) {
3034       case 256: *alg = &ssh_ecdsa_ed25519; break;
3035       default: return FALSE;
3036     }
3037     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
3038     return TRUE;
3039 }