]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Another ecdsa_newkey crash: initialise ec->privateKey earlier.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Curve25519 spec from libssh (with reference to other things in the
28  * libssh code):
29  *   https://git.libssh.org/users/aris/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
30  *
31  * Edwards DSA:
32  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
33  */
34
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "ssh.h"
39
40 /* ----------------------------------------------------------------------
41  * Elliptic curve definitions
42  */
43
44 static void initialise_wcurve(struct ec_curve *curve, int bits,
45                               const unsigned char *p,
46                               const unsigned char *a, const unsigned char *b,
47                               const unsigned char *n, const unsigned char *Gx,
48                               const unsigned char *Gy)
49 {
50     int length = bits / 8;
51     if (bits % 8) ++length;
52
53     curve->type = EC_WEIERSTRASS;
54
55     curve->fieldBits = bits;
56     curve->p = bignum_from_bytes(p, length);
57
58     /* Curve co-efficients */
59     curve->w.a = bignum_from_bytes(a, length);
60     curve->w.b = bignum_from_bytes(b, length);
61
62     /* Group order and generator */
63     curve->w.n = bignum_from_bytes(n, length);
64     curve->w.G.x = bignum_from_bytes(Gx, length);
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     curve->w.G.curve = curve;
67     curve->w.G.infinity = 0;
68 }
69
70 static void initialise_mcurve(struct ec_curve *curve, int bits,
71                               const unsigned char *p,
72                               const unsigned char *a, const unsigned char *b,
73                               const unsigned char *Gx)
74 {
75     int length = bits / 8;
76     if (bits % 8) ++length;
77
78     curve->type = EC_MONTGOMERY;
79
80     curve->fieldBits = bits;
81     curve->p = bignum_from_bytes(p, length);
82
83     /* Curve co-efficients */
84     curve->m.a = bignum_from_bytes(a, length);
85     curve->m.b = bignum_from_bytes(b, length);
86
87     /* Generator */
88     curve->m.G.x = bignum_from_bytes(Gx, length);
89     curve->m.G.y = NULL;
90     curve->m.G.z = NULL;
91     curve->m.G.curve = curve;
92     curve->m.G.infinity = 0;
93 }
94
95 static void initialise_ecurve(struct ec_curve *curve, int bits,
96                               const unsigned char *p,
97                               const unsigned char *l, const unsigned char *d,
98                               const unsigned char *Bx, const unsigned char *By)
99 {
100     int length = bits / 8;
101     if (bits % 8) ++length;
102
103     curve->type = EC_EDWARDS;
104
105     curve->fieldBits = bits;
106     curve->p = bignum_from_bytes(p, length);
107
108     /* Curve co-efficients */
109     curve->e.l = bignum_from_bytes(l, length);
110     curve->e.d = bignum_from_bytes(d, length);
111
112     /* Group order and generator */
113     curve->e.B.x = bignum_from_bytes(Bx, length);
114     curve->e.B.y = bignum_from_bytes(By, length);
115     curve->e.B.curve = curve;
116     curve->e.B.infinity = 0;
117 }
118
119 static struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         static const unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         static const unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         static const unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         static const unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         static const unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         static const unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
164         curve.textname = curve.name = "nistp256";
165
166         /* Now initialised, no need to do it again */
167         initialised = 1;
168     }
169
170     return &curve;
171 }
172
173 static struct ec_curve *ec_p384(void)
174 {
175     static struct ec_curve curve = { 0 };
176     static unsigned char initialised = 0;
177
178     if (!initialised)
179     {
180         static const unsigned char p[] = {
181             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
185             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
186             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
187         };
188         static const unsigned char a[] = {
189             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
193             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
194             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
195         };
196         static const unsigned char b[] = {
197             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
198             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
199             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
200             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
201             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
202             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
203         };
204         static const unsigned char n[] = {
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
209             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
210             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
211         };
212         static const unsigned char Gx[] = {
213             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
214             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
215             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
216             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
217             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
218             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
219         };
220         static const unsigned char Gy[] = {
221             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
222             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
223             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
224             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
225             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
226             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
227         };
228
229         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
230         curve.textname = curve.name = "nistp384";
231
232         /* Now initialised, no need to do it again */
233         initialised = 1;
234     }
235
236     return &curve;
237 }
238
239 static struct ec_curve *ec_p521(void)
240 {
241     static struct ec_curve curve = { 0 };
242     static unsigned char initialised = 0;
243
244     if (!initialised)
245     {
246         static const unsigned char p[] = {
247             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff
256         };
257         static const unsigned char a[] = {
258             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xfc
267         };
268         static const unsigned char b[] = {
269             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
270             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
271             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
272             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
273             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
274             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
275             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
276             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
277             0x3f, 0x00
278         };
279         static const unsigned char n[] = {
280             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
281             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
282             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
285             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
286             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
287             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
288             0x64, 0x09
289         };
290         static const unsigned char Gx[] = {
291             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
292             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
293             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
294             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
295             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
296             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
297             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
298             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
299             0xbd, 0x66
300         };
301         static const unsigned char Gy[] = {
302             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
303             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
304             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
305             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
306             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
307             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
308             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
309             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
310             0x66, 0x50
311         };
312
313         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
314         curve.textname = curve.name = "nistp521";
315
316         /* Now initialised, no need to do it again */
317         initialised = 1;
318     }
319
320     return &curve;
321 }
322
323 static struct ec_curve *ec_curve25519(void)
324 {
325     static struct ec_curve curve = { 0 };
326     static unsigned char initialised = 0;
327
328     if (!initialised)
329     {
330         static const unsigned char p[] = {
331             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
332             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
333             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
334             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
335         };
336         static const unsigned char a[] = {
337             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
338             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
341         };
342         static const unsigned char b[] = {
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
344             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
347         };
348         static const unsigned char gx[32] = {
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
350             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
351             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
353         };
354
355         initialise_mcurve(&curve, 256, p, a, b, gx);
356         /* This curve doesn't need a name, because it's never used in
357          * any format that embeds the curve name */
358         curve.name = NULL;
359         curve.textname = "Curve25519";
360
361         /* Now initialised, no need to do it again */
362         initialised = 1;
363     }
364
365     return &curve;
366 }
367
368 static struct ec_curve *ec_ed25519(void)
369 {
370     static struct ec_curve curve = { 0 };
371     static unsigned char initialised = 0;
372
373     if (!initialised)
374     {
375         static const unsigned char q[] = {
376             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
379             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
380         };
381         static const unsigned char l[32] = {
382             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
385             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
386         };
387         static const unsigned char d[32] = {
388             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
389             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
390             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
391             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
392         };
393         static const unsigned char Bx[32] = {
394             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
395             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
396             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
397             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
398         };
399         static const unsigned char By[32] = {
400             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
401             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
402             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
403             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
404         };
405
406         /* This curve doesn't need a name, because it's never used in
407          * any format that embeds the curve name */
408         curve.name = NULL;
409
410         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
411         curve.textname = "Ed25519";
412
413         /* Now initialised, no need to do it again */
414         initialised = 1;
415     }
416
417     return &curve;
418 }
419
420 /* Return 1 if a is -3 % p, otherwise return 0
421  * This is used because there are some maths optimisations */
422 static int ec_aminus3(const struct ec_curve *curve)
423 {
424     int ret;
425     Bignum _p;
426
427     if (curve->type != EC_WEIERSTRASS) {
428         return 0;
429     }
430
431     _p = bignum_add_long(curve->w.a, 3);
432
433     ret = !bignum_cmp(curve->p, _p);
434     freebn(_p);
435     return ret;
436 }
437
438 /* ----------------------------------------------------------------------
439  * Elliptic curve field maths
440  */
441
442 static Bignum ecf_add(const Bignum a, const Bignum b,
443                       const struct ec_curve *curve)
444 {
445     Bignum a1, b1, ab, ret;
446
447     a1 = bigmod(a, curve->p);
448     b1 = bigmod(b, curve->p);
449
450     ab = bigadd(a1, b1);
451     freebn(a1);
452     freebn(b1);
453
454     ret = bigmod(ab, curve->p);
455     freebn(ab);
456
457     return ret;
458 }
459
460 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
461 {
462     return modmul(a, a, curve->p);
463 }
464
465 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
466 {
467     Bignum ret, tmp;
468
469     /* Double */
470     tmp = bignum_lshift(a, 1);
471
472     /* Add itself (i.e. treble) */
473     ret = bigadd(tmp, a);
474     freebn(tmp);
475
476     /* Normalise */
477     while (bignum_cmp(ret, curve->p) >= 0)
478     {
479         tmp = bigsub(ret, curve->p);
480         assert(tmp);
481         freebn(ret);
482         ret = tmp;
483     }
484
485     return ret;
486 }
487
488 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
489 {
490     Bignum ret = bignum_lshift(a, 1);
491     if (bignum_cmp(ret, curve->p) >= 0)
492     {
493         Bignum tmp = bigsub(ret, curve->p);
494         assert(tmp);
495         freebn(ret);
496         return tmp;
497     }
498     else
499     {
500         return ret;
501     }
502 }
503
504 /* ----------------------------------------------------------------------
505  * Memory functions
506  */
507
508 void ec_point_free(struct ec_point *point)
509 {
510     if (point == NULL) return;
511     point->curve = 0;
512     if (point->x) freebn(point->x);
513     if (point->y) freebn(point->y);
514     if (point->z) freebn(point->z);
515     point->infinity = 0;
516     sfree(point);
517 }
518
519 static struct ec_point *ec_point_new(const struct ec_curve *curve,
520                                      const Bignum x, const Bignum y, const Bignum z,
521                                      unsigned char infinity)
522 {
523     struct ec_point *point = snewn(1, struct ec_point);
524     point->curve = curve;
525     point->x = x;
526     point->y = y;
527     point->z = z;
528     point->infinity = infinity ? 1 : 0;
529     return point;
530 }
531
532 static struct ec_point *ec_point_copy(const struct ec_point *a)
533 {
534     if (a == NULL) return NULL;
535     return ec_point_new(a->curve,
536                         a->x ? copybn(a->x) : NULL,
537                         a->y ? copybn(a->y) : NULL,
538                         a->z ? copybn(a->z) : NULL,
539                         a->infinity);
540 }
541
542 static int ec_point_verify(const struct ec_point *a)
543 {
544     if (a->infinity) {
545         return 1;
546     } else if (a->curve->type == EC_EDWARDS) {
547         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
548         Bignum y2, x2, tmp, tmp2, tmp3;
549         int ret;
550
551         y2 = ecf_square(a->y, a->curve);
552         x2 = ecf_square(a->x, a->curve);
553         tmp = modmul(a->curve->e.d, x2, a->curve->p);
554         tmp2 = modmul(tmp, y2, a->curve->p);
555         freebn(tmp);
556         tmp = modsub(y2, x2, a->curve->p);
557         freebn(y2);
558         freebn(x2);
559         tmp3 = modsub(tmp, tmp2, a->curve->p);
560         freebn(tmp);
561         freebn(tmp2);
562         ret = !bignum_cmp(tmp3, One);
563         freebn(tmp3);
564         return ret;
565     } else if (a->curve->type == EC_WEIERSTRASS) {
566         /* Verify y^2 = x^3 + ax + b */
567         int ret = 0;
568
569         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
570
571         Bignum Three = bignum_from_long(3);
572
573         lhs = modmul(a->y, a->y, a->curve->p);
574
575         /* This uses montgomery multiplication to optimise */
576         x3 = modpow(a->x, Three, a->curve->p);
577         freebn(Three);
578         ax = modmul(a->curve->w.a, a->x, a->curve->p);
579         x3ax = bigadd(x3, ax);
580         freebn(x3); x3 = NULL;
581         freebn(ax); ax = NULL;
582         x3axm = bigmod(x3ax, a->curve->p);
583         freebn(x3ax); x3ax = NULL;
584         x3axb = bigadd(x3axm, a->curve->w.b);
585         freebn(x3axm); x3axm = NULL;
586         rhs = bigmod(x3axb, a->curve->p);
587         freebn(x3axb);
588
589         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
590         freebn(lhs);
591         freebn(rhs);
592
593         return ret;
594     } else {
595         return 0;
596     }
597 }
598
599 /* ----------------------------------------------------------------------
600  * Elliptic curve point maths
601  */
602
603 /* Returns 1 on success and 0 on memory error */
604 static int ecp_normalise(struct ec_point *a)
605 {
606     if (!a) {
607         /* No point */
608         return 0;
609     }
610
611     if (a->infinity) {
612         /* Point is at infinity - i.e. normalised */
613         return 1;
614     }
615
616     if (a->curve->type == EC_WEIERSTRASS) {
617         /* In Jacobian Coordinates the triple (X, Y, Z) represents
618            the affine point (X / Z^2, Y / Z^3) */
619
620         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
621
622         if (!a->x || !a->y) {
623             /* No point defined */
624             return 0;
625         } else if (!a->z) {
626             /* Already normalised */
627             return 1;
628         }
629
630         Z2 = ecf_square(a->z, a->curve);
631         Z2inv = modinv(Z2, a->curve->p);
632         if (!Z2inv) {
633             freebn(Z2);
634             return 0;
635         }
636         tx = modmul(a->x, Z2inv, a->curve->p);
637         freebn(Z2inv);
638
639         Z3 = modmul(Z2, a->z, a->curve->p);
640         freebn(Z2);
641         Z3inv = modinv(Z3, a->curve->p);
642         freebn(Z3);
643         if (!Z3inv) {
644             freebn(tx);
645             return 0;
646         }
647         ty = modmul(a->y, Z3inv, a->curve->p);
648         freebn(Z3inv);
649
650         freebn(a->x);
651         a->x = tx;
652         freebn(a->y);
653         a->y = ty;
654         freebn(a->z);
655         a->z = NULL;
656         return 1;
657     } else if (a->curve->type == EC_MONTGOMERY) {
658         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
659
660         Bignum tmp, tmp2;
661
662         if (!a->x) {
663             /* No point defined */
664             return 0;
665         } else if (!a->z) {
666             /* Already normalised */
667             return 1;
668         }
669
670         tmp = modinv(a->z, a->curve->p);
671         if (!tmp) {
672             return 0;
673         }
674         tmp2 = modmul(a->x, tmp, a->curve->p);
675         freebn(tmp);
676
677         freebn(a->z);
678         a->z = NULL;
679         freebn(a->x);
680         a->x = tmp2;
681         return 1;
682     } else if (a->curve->type == EC_EDWARDS) {
683         /* Always normalised */
684         return 1;
685     } else {
686         return 0;
687     }
688 }
689
690 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
691 {
692     Bignum S, M, outx, outy, outz;
693
694     if (bignum_cmp(a->y, Zero) == 0)
695     {
696         /* Identity */
697         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
698     }
699
700     /* S = 4*X*Y^2 */
701     {
702         Bignum Y2, XY2, _2XY2;
703
704         Y2 = ecf_square(a->y, a->curve);
705         XY2 = modmul(a->x, Y2, a->curve->p);
706         freebn(Y2);
707
708         _2XY2 = ecf_double(XY2, a->curve);
709         freebn(XY2);
710         S = ecf_double(_2XY2, a->curve);
711         freebn(_2XY2);
712     }
713
714     /* Faster calculation if a = -3 */
715     if (aminus3) {
716         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
717         Bignum Z2, XpZ2, XmZ2, second;
718
719         if (a->z == NULL) {
720             Z2 = copybn(One);
721         } else {
722             Z2 = ecf_square(a->z, a->curve);
723         }
724
725         XpZ2 = ecf_add(a->x, Z2, a->curve);
726         XmZ2 = modsub(a->x, Z2, a->curve->p);
727         freebn(Z2);
728
729         second = modmul(XpZ2, XmZ2, a->curve->p);
730         freebn(XpZ2);
731         freebn(XmZ2);
732
733         M = ecf_treble(second, a->curve);
734         freebn(second);
735     } else {
736         /* M = 3*X^2 + a*Z^4 */
737         Bignum _3X2, X2, aZ4;
738
739         if (a->z == NULL) {
740             aZ4 = copybn(a->curve->w.a);
741         } else {
742             Bignum Z2, Z4;
743
744             Z2 = ecf_square(a->z, a->curve);
745             Z4 = ecf_square(Z2, a->curve);
746             freebn(Z2);
747             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
748             freebn(Z4);
749         }
750
751         X2 = modmul(a->x, a->x, a->curve->p);
752         _3X2 = ecf_treble(X2, a->curve);
753         freebn(X2);
754         M = ecf_add(_3X2, aZ4, a->curve);
755         freebn(_3X2);
756         freebn(aZ4);
757     }
758
759     /* X' = M^2 - 2*S */
760     {
761         Bignum M2, _2S;
762
763         M2 = ecf_square(M, a->curve);
764         _2S = ecf_double(S, a->curve);
765         outx = modsub(M2, _2S, a->curve->p);
766         freebn(M2);
767         freebn(_2S);
768     }
769
770     /* Y' = M*(S - X') - 8*Y^4 */
771     {
772         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
773
774         SX = modsub(S, outx, a->curve->p);
775         freebn(S);
776         MSX = modmul(M, SX, a->curve->p);
777         freebn(SX);
778         freebn(M);
779         Y2 = ecf_square(a->y, a->curve);
780         Y4 = ecf_square(Y2, a->curve);
781         freebn(Y2);
782         Eight = bignum_from_long(8);
783         _8Y4 = modmul(Eight, Y4, a->curve->p);
784         freebn(Eight);
785         freebn(Y4);
786         outy = modsub(MSX, _8Y4, a->curve->p);
787         freebn(MSX);
788         freebn(_8Y4);
789     }
790
791     /* Z' = 2*Y*Z */
792     {
793         Bignum YZ;
794
795         if (a->z == NULL) {
796             YZ = copybn(a->y);
797         } else {
798             YZ = modmul(a->y, a->z, a->curve->p);
799         }
800
801         outz = ecf_double(YZ, a->curve);
802         freebn(YZ);
803     }
804
805     return ec_point_new(a->curve, outx, outy, outz, 0);
806 }
807
808 static struct ec_point *ecp_doublem(const struct ec_point *a)
809 {
810     Bignum z, outx, outz, xpz, xmz;
811
812     z = a->z;
813     if (!z) {
814         z = One;
815     }
816
817     /* 4xz = (x + z)^2 - (x - z)^2 */
818     {
819         Bignum tmp;
820
821         tmp = ecf_add(a->x, z, a->curve);
822         xpz = ecf_square(tmp, a->curve);
823         freebn(tmp);
824
825         tmp = modsub(a->x, z, a->curve->p);
826         xmz = ecf_square(tmp, a->curve);
827         freebn(tmp);
828     }
829
830     /* outx = (x + z)^2 * (x - z)^2 */
831     outx = modmul(xpz, xmz, a->curve->p);
832
833     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
834     {
835         Bignum _4xz, tmp, tmp2, tmp3;
836
837         tmp = bignum_from_long(2);
838         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
839         freebn(tmp);
840
841         _4xz = modsub(xpz, xmz, a->curve->p);
842         freebn(xpz);
843         tmp = modmul(tmp2, _4xz, a->curve->p);
844         freebn(tmp2);
845
846         tmp2 = bignum_from_long(4);
847         tmp3 = modinv(tmp2, a->curve->p);
848         freebn(tmp2);
849         if (!tmp3) {
850             freebn(tmp);
851             freebn(_4xz);
852             freebn(outx);
853             freebn(xmz);
854             return NULL;
855         }
856         tmp2 = modmul(tmp, tmp3, a->curve->p);
857         freebn(tmp);
858         freebn(tmp3);
859
860         tmp = ecf_add(xmz, tmp2, a->curve);
861         freebn(xmz);
862         freebn(tmp2);
863         outz = modmul(_4xz, tmp, a->curve->p);
864         freebn(_4xz);
865         freebn(tmp);
866     }
867
868     return ec_point_new(a->curve, outx, NULL, outz, 0);
869 }
870
871 /* Forward declaration for Edwards curve doubling */
872 static struct ec_point *ecp_add(const struct ec_point *a,
873                                 const struct ec_point *b,
874                                 const int aminus3);
875
876 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
877 {
878     if (a->infinity)
879     {
880         /* Identity */
881         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
882     }
883
884     if (a->curve->type == EC_EDWARDS)
885     {
886         return ecp_add(a, a, aminus3);
887     }
888     else if (a->curve->type == EC_WEIERSTRASS)
889     {
890         return ecp_doublew(a, aminus3);
891     }
892     else
893     {
894         return ecp_doublem(a);
895     }
896 }
897
898 static struct ec_point *ecp_addw(const struct ec_point *a,
899                                  const struct ec_point *b,
900                                  const int aminus3)
901 {
902     Bignum U1, U2, S1, S2, outx, outy, outz;
903
904     /* U1 = X1*Z2^2 */
905     /* S1 = Y1*Z2^3 */
906     if (b->z) {
907         Bignum Z2, Z3;
908
909         Z2 = ecf_square(b->z, a->curve);
910         U1 = modmul(a->x, Z2, a->curve->p);
911         Z3 = modmul(Z2, b->z, a->curve->p);
912         freebn(Z2);
913         S1 = modmul(a->y, Z3, a->curve->p);
914         freebn(Z3);
915     } else {
916         U1 = copybn(a->x);
917         S1 = copybn(a->y);
918     }
919
920     /* U2 = X2*Z1^2 */
921     /* S2 = Y2*Z1^3 */
922     if (a->z) {
923         Bignum Z2, Z3;
924
925         Z2 = ecf_square(a->z, b->curve);
926         U2 = modmul(b->x, Z2, b->curve->p);
927         Z3 = modmul(Z2, a->z, b->curve->p);
928         freebn(Z2);
929         S2 = modmul(b->y, Z3, b->curve->p);
930         freebn(Z3);
931     } else {
932         U2 = copybn(b->x);
933         S2 = copybn(b->y);
934     }
935
936     /* Check if multiplying by self */
937     if (bignum_cmp(U1, U2) == 0)
938     {
939         freebn(U1);
940         freebn(U2);
941         if (bignum_cmp(S1, S2) == 0)
942         {
943             freebn(S1);
944             freebn(S2);
945             return ecp_double(a, aminus3);
946         }
947         else
948         {
949             freebn(S1);
950             freebn(S2);
951             /* Infinity */
952             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
953         }
954     }
955
956     {
957         Bignum H, R, UH2, H3;
958
959         /* H = U2 - U1 */
960         H = modsub(U2, U1, a->curve->p);
961         freebn(U2);
962
963         /* R = S2 - S1 */
964         R = modsub(S2, S1, a->curve->p);
965         freebn(S2);
966
967         /* X3 = R^2 - H^3 - 2*U1*H^2 */
968         {
969             Bignum R2, H2, _2UH2, first;
970
971             H2 = ecf_square(H, a->curve);
972             UH2 = modmul(U1, H2, a->curve->p);
973             freebn(U1);
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             R2 = ecf_square(R, a->curve);
977             _2UH2 = ecf_double(UH2, a->curve);
978             first = modsub(R2, H3, a->curve->p);
979             freebn(R2);
980             outx = modsub(first, _2UH2, a->curve->p);
981             freebn(first);
982             freebn(_2UH2);
983         }
984
985         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
986         {
987             Bignum RUH2mX, UH2mX, SH3;
988
989             UH2mX = modsub(UH2, outx, a->curve->p);
990             freebn(UH2);
991             RUH2mX = modmul(R, UH2mX, a->curve->p);
992             freebn(UH2mX);
993             freebn(R);
994             SH3 = modmul(S1, H3, a->curve->p);
995             freebn(S1);
996             freebn(H3);
997
998             outy = modsub(RUH2mX, SH3, a->curve->p);
999             freebn(RUH2mX);
1000             freebn(SH3);
1001         }
1002
1003         /* Z3 = H*Z1*Z2 */
1004         if (a->z && b->z) {
1005             Bignum ZZ;
1006
1007             ZZ = modmul(a->z, b->z, a->curve->p);
1008             outz = modmul(H, ZZ, a->curve->p);
1009             freebn(H);
1010             freebn(ZZ);
1011         } else if (a->z) {
1012             outz = modmul(H, a->z, a->curve->p);
1013             freebn(H);
1014         } else if (b->z) {
1015             outz = modmul(H, b->z, a->curve->p);
1016             freebn(H);
1017         } else {
1018             outz = H;
1019         }
1020     }
1021
1022     return ec_point_new(a->curve, outx, outy, outz, 0);
1023 }
1024
1025 static struct ec_point *ecp_addm(const struct ec_point *a,
1026                                  const struct ec_point *b,
1027                                  const struct ec_point *base)
1028 {
1029     Bignum outx, outz, az, bz;
1030
1031     az = a->z;
1032     if (!az) {
1033         az = One;
1034     }
1035     bz = b->z;
1036     if (!bz) {
1037         bz = One;
1038     }
1039
1040     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1041     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1042     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1043     {
1044         Bignum tmp, tmp2, tmp3, tmp4;
1045
1046         /* (Xa + Za) * (Xb - Zb) */
1047         tmp = ecf_add(a->x, az, a->curve);
1048         tmp2 = modsub(b->x, bz, a->curve->p);
1049         tmp3 = modmul(tmp, tmp2, a->curve->p);
1050         freebn(tmp);
1051         freebn(tmp2);
1052
1053         /* (Xa - Za) * (Xb + Zb) */
1054         tmp = modsub(a->x, az, a->curve->p);
1055         tmp2 = ecf_add(b->x, bz, a->curve);
1056         tmp4 = modmul(tmp, tmp2, a->curve->p);
1057         freebn(tmp);
1058         freebn(tmp2);
1059
1060         tmp = ecf_add(tmp3, tmp4, a->curve);
1061         outx = ecf_square(tmp, a->curve);
1062         freebn(tmp);
1063
1064         tmp = modsub(tmp3, tmp4, a->curve->p);
1065         freebn(tmp3);
1066         freebn(tmp4);
1067         tmp2 = ecf_square(tmp, a->curve);
1068         freebn(tmp);
1069         outz = modmul(base->x, tmp2, a->curve->p);
1070         freebn(tmp2);
1071     }
1072
1073     return ec_point_new(a->curve, outx, NULL, outz, 0);
1074 }
1075
1076 static struct ec_point *ecp_adde(const struct ec_point *a,
1077                                  const struct ec_point *b)
1078 {
1079     Bignum outx, outy, dmul;
1080
1081     /* outx = (a->x * b->y + b->x * a->y) /
1082      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1083     {
1084         Bignum tmp, tmp2, tmp3, tmp4;
1085
1086         tmp = modmul(a->x, b->y, a->curve->p);
1087         tmp2 = modmul(b->x, a->y, a->curve->p);
1088         tmp3 = ecf_add(tmp, tmp2, a->curve);
1089
1090         tmp4 = modmul(tmp, tmp2, a->curve->p);
1091         freebn(tmp);
1092         freebn(tmp2);
1093         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1094         freebn(tmp4);
1095
1096         tmp = ecf_add(One, dmul, a->curve);
1097         tmp2 = modinv(tmp, a->curve->p);
1098         freebn(tmp);
1099         if (!tmp2)
1100         {
1101             freebn(tmp3);
1102             freebn(dmul);
1103             return NULL;
1104         }
1105
1106         outx = modmul(tmp3, tmp2, a->curve->p);
1107         freebn(tmp3);
1108         freebn(tmp2);
1109     }
1110
1111     /* outy = (a->y * b->y + a->x * b->x) /
1112      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1113     {
1114         Bignum tmp, tmp2, tmp3, tmp4;
1115
1116         tmp = modsub(One, dmul, a->curve->p);
1117         freebn(dmul);
1118
1119         tmp2 = modinv(tmp, a->curve->p);
1120         freebn(tmp);
1121         if (!tmp2)
1122         {
1123             freebn(outx);
1124             return NULL;
1125         }
1126
1127         tmp = modmul(a->y, b->y, a->curve->p);
1128         tmp3 = modmul(a->x, b->x, a->curve->p);
1129         tmp4 = ecf_add(tmp, tmp3, a->curve);
1130         freebn(tmp);
1131         freebn(tmp3);
1132
1133         outy = modmul(tmp4, tmp2, a->curve->p);
1134         freebn(tmp4);
1135         freebn(tmp2);
1136     }
1137
1138     return ec_point_new(a->curve, outx, outy, NULL, 0);
1139 }
1140
1141 static struct ec_point *ecp_add(const struct ec_point *a,
1142                                 const struct ec_point *b,
1143                                 const int aminus3)
1144 {
1145     if (a->curve != b->curve) {
1146         return NULL;
1147     }
1148
1149     /* Check if multiplying by infinity */
1150     if (a->infinity) return ec_point_copy(b);
1151     if (b->infinity) return ec_point_copy(a);
1152
1153     if (a->curve->type == EC_EDWARDS)
1154     {
1155         return ecp_adde(a, b);
1156     }
1157
1158     if (a->curve->type == EC_WEIERSTRASS)
1159     {
1160         return ecp_addw(a, b, aminus3);
1161     }
1162
1163     return NULL;
1164 }
1165
1166 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1167 {
1168     struct ec_point *A, *ret;
1169     int bits, i;
1170
1171     A = ec_point_copy(a);
1172     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1173
1174     bits = bignum_bitcount(b);
1175     for (i = 0; i < bits; ++i)
1176     {
1177         if (bignum_bit(b, i))
1178         {
1179             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1180             ec_point_free(ret);
1181             ret = tmp;
1182         }
1183         if (i+1 != bits)
1184         {
1185             struct ec_point *tmp = ecp_double(A, aminus3);
1186             ec_point_free(A);
1187             A = tmp;
1188         }
1189     }
1190
1191     ec_point_free(A);
1192     return ret;
1193 }
1194
1195 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1196 {
1197     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1198
1199     if (!ecp_normalise(ret)) {
1200         ec_point_free(ret);
1201         return NULL;
1202     }
1203
1204     return ret;
1205 }
1206
1207 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1208 {
1209     int i;
1210     struct ec_point *ret;
1211
1212     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1213
1214     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1215     {
1216         {
1217             struct ec_point *tmp = ecp_double(ret, 0);
1218             ec_point_free(ret);
1219             ret = tmp;
1220         }
1221         if (ret && bignum_bit(b, i))
1222         {
1223             struct ec_point *tmp = ecp_add(ret, a, 0);
1224             ec_point_free(ret);
1225             ret = tmp;
1226         }
1227     }
1228
1229     return ret;
1230 }
1231
1232 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1233 {
1234     struct ec_point *P1, *P2;
1235     int bits, i;
1236
1237     /* P1 <- P and P2 <- [2]P */
1238     P2 = ecp_double(p, 0);
1239     P1 = ec_point_copy(p);
1240
1241     /* for i = bits âˆ’ 2 down to 0 */
1242     bits = bignum_bitcount(n);
1243     for (i = bits - 2; i >= 0; --i)
1244     {
1245         if (!bignum_bit(n, i))
1246         {
1247             /* P2 <- P1 + P2 */
1248             struct ec_point *tmp = ecp_addm(P1, P2, p);
1249             ec_point_free(P2);
1250             P2 = tmp;
1251
1252             /* P1 <- [2]P1 */
1253             tmp = ecp_double(P1, 0);
1254             ec_point_free(P1);
1255             P1 = tmp;
1256         }
1257         else
1258         {
1259             /* P1 <- P1 + P2 */
1260             struct ec_point *tmp = ecp_addm(P1, P2, p);
1261             ec_point_free(P1);
1262             P1 = tmp;
1263
1264             /* P2 <- [2]P2 */
1265             tmp = ecp_double(P2, 0);
1266             ec_point_free(P2);
1267             P2 = tmp;
1268         }
1269     }
1270
1271     ec_point_free(P2);
1272
1273     if (!ecp_normalise(P1)) {
1274         ec_point_free(P1);
1275         return NULL;
1276     }
1277
1278     return P1;
1279 }
1280
1281 /* Not static because it is used by sshecdsag.c to generate a new key */
1282 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1283 {
1284     if (a->curve->type == EC_WEIERSTRASS) {
1285         return ecp_mulw(a, b);
1286     } else if (a->curve->type == EC_EDWARDS) {
1287         return ecp_mule(a, b);
1288     } else {
1289         return ecp_mulm(a, b);
1290     }
1291 }
1292
1293 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1294                                    const struct ec_point *point)
1295 {
1296     struct ec_point *aG, *bP, *ret;
1297     int aminus3;
1298
1299     if (point->curve->type != EC_WEIERSTRASS) {
1300         return NULL;
1301     }
1302
1303     aminus3 = ec_aminus3(point->curve);
1304
1305     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1306     if (!aG) return NULL;
1307     bP = ecp_mul_(point, b, aminus3);
1308     if (!bP) {
1309         ec_point_free(aG);
1310         return NULL;
1311     }
1312
1313     ret = ecp_add(aG, bP, aminus3);
1314
1315     ec_point_free(aG);
1316     ec_point_free(bP);
1317
1318     if (!ecp_normalise(ret)) {
1319         ec_point_free(ret);
1320         return NULL;
1321     }
1322
1323     return ret;
1324 }
1325 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1326 {
1327     /* Get the x value on the given Edwards curve for a given y */
1328     Bignum x, xx;
1329
1330     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1331     {
1332         Bignum tmp, tmp2, tmp3;
1333
1334         tmp = ecf_square(y, curve);
1335         tmp2 = modmul(curve->e.d, tmp, curve->p);
1336         tmp3 = ecf_add(tmp2, One, curve);
1337         freebn(tmp2);
1338         tmp2 = modinv(tmp3, curve->p);
1339         freebn(tmp3);
1340         if (!tmp2) {
1341             freebn(tmp);
1342             return NULL;
1343         }
1344
1345         tmp3 = modsub(tmp, One, curve->p);
1346         freebn(tmp);
1347         xx = modmul(tmp3, tmp2, curve->p);
1348         freebn(tmp3);
1349         freebn(tmp2);
1350     }
1351
1352     /* x = xx^((p + 3) / 8) */
1353     {
1354         Bignum tmp, tmp2;
1355
1356         tmp = bignum_add_long(curve->p, 3);
1357         tmp2 = bignum_rshift(tmp, 3);
1358         freebn(tmp);
1359         x = modpow(xx, tmp2, curve->p);
1360         freebn(tmp2);
1361     }
1362
1363     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1364     {
1365         Bignum tmp, tmp2;
1366
1367         tmp = ecf_square(x, curve);
1368         tmp2 = modsub(tmp, xx, curve->p);
1369         freebn(tmp);
1370         freebn(xx);
1371         if (bignum_cmp(tmp2, Zero)) {
1372             Bignum tmp3;
1373
1374             freebn(tmp2);
1375
1376             tmp = modsub(curve->p, One, curve->p);
1377             tmp2 = bignum_rshift(tmp, 2);
1378             freebn(tmp);
1379             tmp = bignum_from_long(2);
1380             tmp3 = modpow(tmp, tmp2, curve->p);
1381             freebn(tmp);
1382             freebn(tmp2);
1383
1384             tmp = modmul(x, tmp3, curve->p);
1385             freebn(x);
1386             freebn(tmp3);
1387             x = tmp;
1388         } else {
1389             freebn(tmp2);
1390         }
1391     }
1392
1393     /* if x % 2 != 0 then x = p - x */
1394     if (bignum_bit(x, 0)) {
1395         Bignum tmp = modsub(curve->p, x, curve->p);
1396         freebn(x);
1397         x = tmp;
1398     }
1399
1400     return x;
1401 }
1402
1403 /* ----------------------------------------------------------------------
1404  * Public point from private
1405  */
1406
1407 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1408 {
1409     if (curve->type == EC_WEIERSTRASS) {
1410         return ecp_mul(&curve->w.G, privateKey);
1411     } else if (curve->type == EC_EDWARDS) {
1412         /* hash = H(sk) (where hash creates 2 * fieldBits)
1413          * b = fieldBits
1414          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1415          * publicKey = aB */
1416         struct ec_point *ret;
1417         unsigned char hash[512/8];
1418         Bignum a;
1419         int i, keylen;
1420         SHA512_State s;
1421         SHA512_Init(&s);
1422
1423         keylen = curve->fieldBits / 8;
1424         for (i = 0; i < keylen; ++i) {
1425             unsigned char b = bignum_byte(privateKey, i);
1426             SHA512_Bytes(&s, &b, 1);
1427         }
1428         SHA512_Final(&s, hash);
1429
1430         /* The second part is simply turning the hash into a Bignum,
1431          * however the 2^(b-2) bit *must* be set, and the bottom 3
1432          * bits *must* not be */
1433         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1434         hash[31] &= 0x7f; /* Unset above (b-2) */
1435         hash[31] |= 0x40; /* Set 2^(b-2) */
1436         /* Chop off the top part and convert to int */
1437         a = bignum_from_bytes_le(hash, 32);
1438
1439         ret = ecp_mul(&curve->e.B, a);
1440         freebn(a);
1441         return ret;
1442     } else {
1443         return NULL;
1444     }
1445 }
1446
1447 /* ----------------------------------------------------------------------
1448  * Basic sign and verify routines
1449  */
1450
1451 static int _ecdsa_verify(const struct ec_point *publicKey,
1452                          const unsigned char *data, const int dataLen,
1453                          const Bignum r, const Bignum s)
1454 {
1455     int z_bits, n_bits;
1456     Bignum z;
1457     int valid = 0;
1458
1459     if (publicKey->curve->type != EC_WEIERSTRASS) {
1460         return 0;
1461     }
1462
1463     /* Sanity checks */
1464     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1465         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1466     {
1467         return 0;
1468     }
1469
1470     /* z = left most bitlen(curve->n) of data */
1471     z = bignum_from_bytes(data, dataLen);
1472     n_bits = bignum_bitcount(publicKey->curve->w.n);
1473     z_bits = bignum_bitcount(z);
1474     if (z_bits > n_bits)
1475     {
1476         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1477         freebn(z);
1478         z = tmp;
1479     }
1480
1481     /* Ensure z in range of n */
1482     {
1483         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1484         freebn(z);
1485         z = tmp;
1486     }
1487
1488     /* Calculate signature */
1489     {
1490         Bignum w, x, u1, u2;
1491         struct ec_point *tmp;
1492
1493         w = modinv(s, publicKey->curve->w.n);
1494         if (!w) {
1495             freebn(z);
1496             return 0;
1497         }
1498         u1 = modmul(z, w, publicKey->curve->w.n);
1499         u2 = modmul(r, w, publicKey->curve->w.n);
1500         freebn(w);
1501
1502         tmp = ecp_summul(u1, u2, publicKey);
1503         freebn(u1);
1504         freebn(u2);
1505         if (!tmp) {
1506             freebn(z);
1507             return 0;
1508         }
1509
1510         x = bigmod(tmp->x, publicKey->curve->w.n);
1511         ec_point_free(tmp);
1512
1513         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1514         freebn(x);
1515     }
1516
1517     freebn(z);
1518
1519     return valid;
1520 }
1521
1522 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1523                         const unsigned char *data, const int dataLen,
1524                         Bignum *r, Bignum *s)
1525 {
1526     unsigned char digest[20];
1527     int z_bits, n_bits;
1528     Bignum z, k;
1529     struct ec_point *kG;
1530
1531     *r = NULL;
1532     *s = NULL;
1533
1534     if (curve->type != EC_WEIERSTRASS) {
1535         return;
1536     }
1537
1538     /* z = left most bitlen(curve->n) of data */
1539     z = bignum_from_bytes(data, dataLen);
1540     n_bits = bignum_bitcount(curve->w.n);
1541     z_bits = bignum_bitcount(z);
1542     if (z_bits > n_bits)
1543     {
1544         Bignum tmp;
1545         tmp = bignum_rshift(z, z_bits - n_bits);
1546         freebn(z);
1547         z = tmp;
1548     }
1549
1550     /* Generate k between 1 and curve->n, using the same deterministic
1551      * k generation system we use for conventional DSA. */
1552     SHA_Simple(data, dataLen, digest);
1553     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1554                   digest, sizeof(digest));
1555
1556     kG = ecp_mul(&curve->w.G, k);
1557     if (!kG) {
1558         freebn(z);
1559         freebn(k);
1560         return;
1561     }
1562
1563     /* r = kG.x mod n */
1564     *r = bigmod(kG->x, curve->w.n);
1565     ec_point_free(kG);
1566
1567     /* s = (z + r * priv)/k mod n */
1568     {
1569         Bignum rPriv, zMod, first, firstMod, kInv;
1570         rPriv = modmul(*r, privateKey, curve->w.n);
1571         zMod = bigmod(z, curve->w.n);
1572         freebn(z);
1573         first = bigadd(rPriv, zMod);
1574         freebn(rPriv);
1575         freebn(zMod);
1576         firstMod = bigmod(first, curve->w.n);
1577         freebn(first);
1578         kInv = modinv(k, curve->w.n);
1579         freebn(k);
1580         if (!kInv) {
1581             freebn(firstMod);
1582             freebn(*r);
1583             return;
1584         }
1585         *s = modmul(firstMod, kInv, curve->w.n);
1586         freebn(firstMod);
1587         freebn(kInv);
1588     }
1589 }
1590
1591 /* ----------------------------------------------------------------------
1592  * Misc functions
1593  */
1594
1595 static void getstring(const char **data, int *datalen,
1596                       const char **p, int *length)
1597 {
1598     *p = NULL;
1599     if (*datalen < 4)
1600         return;
1601     *length = toint(GET_32BIT(*data));
1602     if (*length < 0)
1603         return;
1604     *datalen -= 4;
1605     *data += 4;
1606     if (*datalen < *length)
1607         return;
1608     *p = *data;
1609     *data += *length;
1610     *datalen -= *length;
1611 }
1612
1613 static Bignum getmp(const char **data, int *datalen)
1614 {
1615     const char *p;
1616     int length;
1617
1618     getstring(data, datalen, &p, &length);
1619     if (!p)
1620         return NULL;
1621     if (p[0] & 0x80)
1622         return NULL;                   /* negative mp */
1623     return bignum_from_bytes((unsigned char *)p, length);
1624 }
1625
1626 static Bignum getmp_le(const char **data, int *datalen)
1627 {
1628     const char *p;
1629     int length;
1630
1631     getstring(data, datalen, &p, &length);
1632     if (!p)
1633         return NULL;
1634     return bignum_from_bytes_le((const unsigned char *)p, length);
1635 }
1636
1637 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1638 {
1639     /* Got some conversion to do, first read in the y co-ord */
1640     int negative;
1641
1642     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1643     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1644         freebn(point->y);
1645         point->y = NULL;
1646         return 0;
1647     }
1648     /* Read x bit and then reset it */
1649     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1650     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1651
1652     /* Get the x from the y */
1653     point->x = ecp_edx(point->curve, point->y);
1654     if (!point->x) {
1655         freebn(point->y);
1656         point->y = NULL;
1657         return 0;
1658     }
1659     if (negative) {
1660         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1661         freebn(point->x);
1662         point->x = tmp;
1663     }
1664
1665     /* Verify the point is on the curve */
1666     if (!ec_point_verify(point)) {
1667         freebn(point->x);
1668         point->x = NULL;
1669         freebn(point->y);
1670         point->y = NULL;
1671         return 0;
1672     }
1673
1674     return 1;
1675 }
1676
1677 static int decodepoint(const char *p, int length, struct ec_point *point)
1678 {
1679     if (point->curve->type == EC_EDWARDS) {
1680         return decodepoint_ed(p, length, point);
1681     }
1682
1683     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1684         return 0;
1685     /* Skip compression flag */
1686     ++p;
1687     --length;
1688     /* The two values must be equal length */
1689     if (length % 2 != 0) {
1690         point->x = NULL;
1691         point->y = NULL;
1692         point->z = NULL;
1693         return 0;
1694     }
1695     length = length / 2;
1696     point->x = bignum_from_bytes((const unsigned char *)p, length);
1697     p += length;
1698     point->y = bignum_from_bytes((const unsigned char *)p, length);
1699     point->z = NULL;
1700
1701     /* Verify the point is on the curve */
1702     if (!ec_point_verify(point)) {
1703         freebn(point->x);
1704         point->x = NULL;
1705         freebn(point->y);
1706         point->y = NULL;
1707         return 0;
1708     }
1709
1710     return 1;
1711 }
1712
1713 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1714 {
1715     const char *p;
1716     int length;
1717
1718     getstring(data, datalen, &p, &length);
1719     if (!p) return 0;
1720     return decodepoint(p, length, point);
1721 }
1722
1723 /* ----------------------------------------------------------------------
1724  * Exposed ECDSA interface
1725  */
1726
1727 struct ecsign_extra {
1728     struct ec_curve *(*curve)(void);
1729     const struct ssh_hash *hash;
1730
1731     /* These fields are used by the OpenSSH PEM format importer/exporter */
1732     const unsigned char *oid;
1733     int oidlen;
1734 };
1735
1736 static void ecdsa_freekey(void *key)
1737 {
1738     struct ec_key *ec = (struct ec_key *) key;
1739     if (!ec) return;
1740
1741     if (ec->publicKey.x)
1742         freebn(ec->publicKey.x);
1743     if (ec->publicKey.y)
1744         freebn(ec->publicKey.y);
1745     if (ec->publicKey.z)
1746         freebn(ec->publicKey.z);
1747     if (ec->privateKey)
1748         freebn(ec->privateKey);
1749     sfree(ec);
1750 }
1751
1752 static void *ecdsa_newkey(const struct ssh_signkey *self,
1753                           const char *data, int len)
1754 {
1755     const struct ecsign_extra *extra =
1756         (const struct ecsign_extra *)self->extra;
1757     const char *p;
1758     int slen;
1759     struct ec_key *ec;
1760     struct ec_curve *curve;
1761
1762     getstring(&data, &len, &p, &slen);
1763
1764     if (!p) {
1765         return NULL;
1766     }
1767     curve = extra->curve();
1768     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1769
1770     /* Curve name is duplicated for Weierstrass form */
1771     if (curve->type == EC_WEIERSTRASS) {
1772         getstring(&data, &len, &p, &slen);
1773         if (!p) return NULL;
1774         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1775     }
1776
1777     ec = snew(struct ec_key);
1778
1779     ec->signalg = self;
1780     ec->publicKey.curve = curve;
1781     ec->publicKey.infinity = 0;
1782     ec->publicKey.x = NULL;
1783     ec->publicKey.y = NULL;
1784     ec->publicKey.z = NULL;
1785     ec->privateKey = NULL;
1786     if (!getmppoint(&data, &len, &ec->publicKey)) {
1787         ecdsa_freekey(ec);
1788         return NULL;
1789     }
1790
1791     if (!ec->publicKey.x || !ec->publicKey.y ||
1792         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1793         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1794     {
1795         ecdsa_freekey(ec);
1796         ec = NULL;
1797     }
1798
1799     return ec;
1800 }
1801
1802 static char *ecdsa_fmtkey(void *key)
1803 {
1804     struct ec_key *ec = (struct ec_key *) key;
1805     char *p;
1806     int len, i, pos, nibbles;
1807     static const char hex[] = "0123456789abcdef";
1808     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1809         return NULL;
1810
1811     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1812     if (ec->publicKey.curve->name)
1813         len += strlen(ec->publicKey.curve->name); /* Curve name */
1814     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1815     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1816     p = snewn(len, char);
1817
1818     pos = 0;
1819     if (ec->publicKey.curve->name)
1820         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1821     pos += sprintf(p + pos, "0x");
1822     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1823     if (nibbles < 1)
1824         nibbles = 1;
1825     for (i = nibbles; i--;) {
1826         p[pos++] =
1827             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1828     }
1829     pos += sprintf(p + pos, ",0x");
1830     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1831     if (nibbles < 1)
1832         nibbles = 1;
1833     for (i = nibbles; i--;) {
1834         p[pos++] =
1835             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1836     }
1837     p[pos] = '\0';
1838     return p;
1839 }
1840
1841 static unsigned char *ecdsa_public_blob(void *key, int *len)
1842 {
1843     struct ec_key *ec = (struct ec_key *) key;
1844     int pointlen, bloblen, fullnamelen, namelen;
1845     int i;
1846     unsigned char *blob, *p;
1847
1848     fullnamelen = strlen(ec->signalg->name);
1849
1850     if (ec->publicKey.curve->type == EC_EDWARDS) {
1851         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1852
1853         pointlen = ec->publicKey.curve->fieldBits / 8;
1854
1855         /* Can't handle this in our loop */
1856         if (pointlen < 2) return NULL;
1857
1858         bloblen = 4 + fullnamelen + 4 + pointlen;
1859         blob = snewn(bloblen, unsigned char);
1860
1861         p = blob;
1862         PUT_32BIT(p, fullnamelen);
1863         p += 4;
1864         memcpy(p, ec->signalg->name, fullnamelen);
1865         p += fullnamelen;
1866         PUT_32BIT(p, pointlen);
1867         p += 4;
1868
1869         /* Unset last bit of y and set first bit of x in its place */
1870         for (i = 0; i < pointlen - 1; ++i) {
1871             *p++ = bignum_byte(ec->publicKey.y, i);
1872         }
1873         /* Unset last bit of y and set first bit of x in its place */
1874         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1875         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1876     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1877         assert(ec->publicKey.curve->name);
1878         namelen = strlen(ec->publicKey.curve->name);
1879
1880         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1881
1882         /*
1883          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1884          */
1885         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1886         blob = snewn(bloblen, unsigned char);
1887
1888         p = blob;
1889         PUT_32BIT(p, fullnamelen);
1890         p += 4;
1891         memcpy(p, ec->signalg->name, fullnamelen);
1892         p += fullnamelen;
1893         PUT_32BIT(p, namelen);
1894         p += 4;
1895         memcpy(p, ec->publicKey.curve->name, namelen);
1896         p += namelen;
1897         PUT_32BIT(p, (2 * pointlen) + 1);
1898         p += 4;
1899         *p++ = 0x04;
1900         for (i = pointlen; i--;) {
1901             *p++ = bignum_byte(ec->publicKey.x, i);
1902         }
1903         for (i = pointlen; i--;) {
1904             *p++ = bignum_byte(ec->publicKey.y, i);
1905         }
1906     } else {
1907         return NULL;
1908     }
1909
1910     assert(p == blob + bloblen);
1911     *len = bloblen;
1912
1913     return blob;
1914 }
1915
1916 static unsigned char *ecdsa_private_blob(void *key, int *len)
1917 {
1918     struct ec_key *ec = (struct ec_key *) key;
1919     int keylen, bloblen;
1920     int i;
1921     unsigned char *blob, *p;
1922
1923     if (!ec->privateKey) return NULL;
1924
1925     if (ec->publicKey.curve->type == EC_EDWARDS) {
1926         /* Unsigned */
1927         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1928     } else {
1929         /* Signed */
1930         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1931     }
1932
1933     /*
1934      * mpint privateKey. Total 4 + keylen.
1935      */
1936     bloblen = 4 + keylen;
1937     blob = snewn(bloblen, unsigned char);
1938
1939     p = blob;
1940     PUT_32BIT(p, keylen);
1941     p += 4;
1942     if (ec->publicKey.curve->type == EC_EDWARDS) {
1943         /* Little endian */
1944         for (i = 0; i < keylen; ++i)
1945             *p++ = bignum_byte(ec->privateKey, i);
1946     } else {
1947         for (i = keylen; i--;)
1948             *p++ = bignum_byte(ec->privateKey, i);
1949     }
1950
1951     assert(p == blob + bloblen);
1952     *len = bloblen;
1953     return blob;
1954 }
1955
1956 static void *ecdsa_createkey(const struct ssh_signkey *self,
1957                              const unsigned char *pub_blob, int pub_len,
1958                              const unsigned char *priv_blob, int priv_len)
1959 {
1960     struct ec_key *ec;
1961     struct ec_point *publicKey;
1962     const char *pb = (const char *) priv_blob;
1963
1964     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1965     if (!ec) {
1966         return NULL;
1967     }
1968
1969     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1970         && ec->publicKey.curve->type != EC_EDWARDS) {
1971         ecdsa_freekey(ec);
1972         return NULL;
1973     }
1974
1975     if (ec->publicKey.curve->type == EC_EDWARDS) {
1976         ec->privateKey = getmp_le(&pb, &priv_len);
1977     } else {
1978         ec->privateKey = getmp(&pb, &priv_len);
1979     }
1980     if (!ec->privateKey) {
1981         ecdsa_freekey(ec);
1982         return NULL;
1983     }
1984
1985     /* Check that private key generates public key */
1986     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1987
1988     if (!publicKey ||
1989         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1990         bignum_cmp(publicKey->y, ec->publicKey.y))
1991     {
1992         ecdsa_freekey(ec);
1993         ec = NULL;
1994     }
1995     ec_point_free(publicKey);
1996
1997     return ec;
1998 }
1999
2000 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2001                                        const unsigned char **blob, int *len)
2002 {
2003     struct ec_key *ec;
2004     struct ec_point *publicKey;
2005     const char *p, *q;
2006     int plen, qlen;
2007
2008     getstring((const char**)blob, len, &p, &plen);
2009     if (!p)
2010     {
2011         return NULL;
2012     }
2013
2014     ec = snew(struct ec_key);
2015
2016     ec->signalg = self;
2017     ec->publicKey.curve = ec_ed25519();
2018     ec->publicKey.infinity = 0;
2019     ec->privateKey = NULL;
2020     ec->publicKey.x = NULL;
2021     ec->publicKey.z = NULL;
2022     ec->publicKey.y = NULL;
2023
2024     if (!decodepoint_ed(p, plen, &ec->publicKey))
2025     {
2026         ecdsa_freekey(ec);
2027         return NULL;
2028     }
2029
2030     getstring((const char**)blob, len, &q, &qlen);
2031     if (!q)
2032         return NULL;
2033     if (qlen != 64)
2034         return NULL;
2035
2036     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2037
2038     /* Check that private key generates public key */
2039     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2040
2041     if (!publicKey ||
2042         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2043         bignum_cmp(publicKey->y, ec->publicKey.y))
2044     {
2045         ecdsa_freekey(ec);
2046         ec = NULL;
2047     }
2048     ec_point_free(publicKey);
2049
2050     /* The OpenSSH format for ed25519 private keys also for some
2051      * reason encodes an extra copy of the public key in the second
2052      * half of the secret-key string. Check that that's present and
2053      * correct as well, otherwise the key we think we've imported
2054      * won't behave identically to the way OpenSSH would have treated
2055      * it. */
2056     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2057         ecdsa_freekey(ec);
2058         return NULL;
2059     }
2060
2061     return ec;
2062 }
2063
2064 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2065 {
2066     struct ec_key *ec = (struct ec_key *) key;
2067
2068     int pointlen;
2069     int keylen;
2070     int bloblen;
2071     int i;
2072
2073     if (ec->publicKey.curve->type != EC_EDWARDS) {
2074         return 0;
2075     }
2076
2077     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2078     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2079     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2080
2081     if (bloblen > len)
2082         return bloblen;
2083
2084     /* Encode the public point */
2085     PUT_32BIT(blob, pointlen);
2086     blob += 4;
2087
2088     for (i = 0; i < pointlen - 1; ++i) {
2089          *blob++ = bignum_byte(ec->publicKey.y, i);
2090     }
2091     /* Unset last bit of y and set first bit of x in its place */
2092     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2093     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2094
2095     PUT_32BIT(blob, keylen + pointlen);
2096     blob += 4;
2097     for (i = 0; i < keylen; ++i) {
2098          *blob++ = bignum_byte(ec->privateKey, i);
2099     }
2100     /* Now encode an extra copy of the public point as the second half
2101      * of the private key string, as the OpenSSH format for some
2102      * reason requires */
2103     for (i = 0; i < pointlen - 1; ++i) {
2104          *blob++ = bignum_byte(ec->publicKey.y, i);
2105     }
2106     /* Unset last bit of y and set first bit of x in its place */
2107     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2108     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2109
2110     return bloblen;
2111 }
2112
2113 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2114                                      const unsigned char **blob, int *len)
2115 {
2116     const struct ecsign_extra *extra =
2117         (const struct ecsign_extra *)self->extra;
2118     const char **b = (const char **) blob;
2119     const char *p;
2120     int slen;
2121     struct ec_key *ec;
2122     struct ec_curve *curve;
2123     struct ec_point *publicKey;
2124
2125     getstring(b, len, &p, &slen);
2126
2127     if (!p) {
2128         return NULL;
2129     }
2130     curve = extra->curve();
2131     assert(curve->type == EC_WEIERSTRASS);
2132
2133     ec = snew(struct ec_key);
2134
2135     ec->signalg = self;
2136     ec->publicKey.curve = curve;
2137     ec->publicKey.infinity = 0;
2138     ec->publicKey.x = NULL;
2139     ec->publicKey.y = NULL;
2140     ec->publicKey.z = NULL;
2141     if (!getmppoint(b, len, &ec->publicKey)) {
2142         ecdsa_freekey(ec);
2143         return NULL;
2144     }
2145     ec->privateKey = NULL;
2146
2147     if (!ec->publicKey.x || !ec->publicKey.y ||
2148         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2149         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2150     {
2151         ecdsa_freekey(ec);
2152         return NULL;
2153     }
2154
2155     ec->privateKey = getmp(b, len);
2156     if (ec->privateKey == NULL)
2157     {
2158         ecdsa_freekey(ec);
2159         return NULL;
2160     }
2161
2162     /* Now check that the private key makes the public key */
2163     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2164     if (!publicKey)
2165     {
2166         ecdsa_freekey(ec);
2167         return NULL;
2168     }
2169
2170     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2171         bignum_cmp(ec->publicKey.y, publicKey->y))
2172     {
2173         /* Private key doesn't make the public key on the given curve */
2174         ecdsa_freekey(ec);
2175         ec_point_free(publicKey);
2176         return NULL;
2177     }
2178
2179     ec_point_free(publicKey);
2180
2181     return ec;
2182 }
2183
2184 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2185 {
2186     struct ec_key *ec = (struct ec_key *) key;
2187
2188     int pointlen;
2189     int namelen;
2190     int bloblen;
2191     int i;
2192
2193     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2194         return 0;
2195     }
2196
2197     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2198     namelen = strlen(ec->publicKey.curve->name);
2199     bloblen =
2200         4 + namelen /* <LEN> nistpXXX */
2201         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2202         + ssh2_bignum_length(ec->privateKey);
2203
2204     if (bloblen > len)
2205         return bloblen;
2206
2207     bloblen = 0;
2208
2209     PUT_32BIT(blob+bloblen, namelen);
2210     bloblen += 4;
2211     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2212     bloblen += namelen;
2213
2214     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2215     bloblen += 4;
2216     blob[bloblen++] = 0x04;
2217     for (i = pointlen; i--; )
2218         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2219     for (i = pointlen; i--; )
2220         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2221
2222     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2223     PUT_32BIT(blob+bloblen, pointlen);
2224     bloblen += 4;
2225     for (i = pointlen; i--; )
2226         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2227
2228     return bloblen;
2229 }
2230
2231 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2232                              const void *blob, int len)
2233 {
2234     struct ec_key *ec;
2235     int ret;
2236
2237     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2238     if (!ec)
2239         return -1;
2240     ret = ec->publicKey.curve->fieldBits;
2241     ecdsa_freekey(ec);
2242
2243     return ret;
2244 }
2245
2246 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2247                            const char *data, int datalen)
2248 {
2249     struct ec_key *ec = (struct ec_key *) key;
2250     const struct ecsign_extra *extra =
2251         (const struct ecsign_extra *)ec->signalg->extra;
2252     const char *p;
2253     int slen;
2254     int digestLen;
2255     int ret;
2256
2257     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2258         return 0;
2259
2260     /* Check the signature starts with the algorithm name */
2261     getstring(&sig, &siglen, &p, &slen);
2262     if (!p) {
2263         return 0;
2264     }
2265     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2266         return 0;
2267     }
2268
2269     getstring(&sig, &siglen, &p, &slen);
2270     if (ec->publicKey.curve->type == EC_EDWARDS) {
2271         struct ec_point *r;
2272         Bignum s, h;
2273
2274         /* Check that the signature is two times the length of a point */
2275         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2276             return 0;
2277         }
2278
2279         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2280         if (ec->publicKey.curve->fieldBits != 256) {
2281             return 0;
2282         }
2283
2284         /* Get the signature */
2285         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2286         if (!r) {
2287             return 0;
2288         }
2289         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2290             ec_point_free(r);
2291             return 0;
2292         }
2293         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2294                                  ec->publicKey.curve->fieldBits / 8);
2295
2296         /* Get the hash of the encoded value of R + encoded value of pk + message */
2297         {
2298             int i, pointlen;
2299             unsigned char b;
2300             unsigned char digest[512 / 8];
2301             SHA512_State hs;
2302             SHA512_Init(&hs);
2303
2304             /* Add encoded r (no need to encode it again, it was in the signature) */
2305             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2306
2307             /* Encode pk and add it */
2308             pointlen = ec->publicKey.curve->fieldBits / 8;
2309             for (i = 0; i < pointlen - 1; ++i) {
2310                 b = bignum_byte(ec->publicKey.y, i);
2311                 SHA512_Bytes(&hs, &b, 1);
2312             }
2313             /* Unset last bit of y and set first bit of x in its place */
2314             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2315             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2316             SHA512_Bytes(&hs, &b, 1);
2317
2318             /* Add the message itself */
2319             SHA512_Bytes(&hs, data, datalen);
2320
2321             /* Get the hash */
2322             SHA512_Final(&hs, digest);
2323
2324             /* Convert to Bignum */
2325             h = bignum_from_bytes_le(digest, sizeof(digest));
2326         }
2327
2328         /* Verify sB == r + h*publicKey */
2329         {
2330             struct ec_point *lhs, *rhs, *tmp;
2331
2332             /* lhs = sB */
2333             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2334             freebn(s);
2335             if (!lhs) {
2336                 ec_point_free(r);
2337                 freebn(h);
2338                 return 0;
2339             }
2340
2341             /* rhs = r + h*publicKey */
2342             tmp = ecp_mul(&ec->publicKey, h);
2343             freebn(h);
2344             if (!tmp) {
2345                 ec_point_free(lhs);
2346                 ec_point_free(r);
2347                 return 0;
2348             }
2349             rhs = ecp_add(r, tmp, 0);
2350             ec_point_free(r);
2351             ec_point_free(tmp);
2352             if (!rhs) {
2353                 ec_point_free(lhs);
2354                 return 0;
2355             }
2356
2357             /* Check the point is the same */
2358             ret = !bignum_cmp(lhs->x, rhs->x);
2359             if (ret) {
2360                 ret = !bignum_cmp(lhs->y, rhs->y);
2361                 if (ret) {
2362                     ret = 1;
2363                 }
2364             }
2365             ec_point_free(lhs);
2366             ec_point_free(rhs);
2367         }
2368     } else {
2369         Bignum r, s;
2370         unsigned char digest[512 / 8];
2371         void *hashctx;
2372
2373         r = getmp(&p, &slen);
2374         if (!r) return 0;
2375         s = getmp(&p, &slen);
2376         if (!s) {
2377             freebn(r);
2378             return 0;
2379         }
2380
2381         digestLen = extra->hash->hlen;
2382         assert(digestLen <= sizeof(digest));
2383         hashctx = extra->hash->init();
2384         extra->hash->bytes(hashctx, data, datalen);
2385         extra->hash->final(hashctx, digest);
2386
2387         /* Verify the signature */
2388         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2389
2390         freebn(r);
2391         freebn(s);
2392     }
2393
2394     return ret;
2395 }
2396
2397 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2398                                  int *siglen)
2399 {
2400     struct ec_key *ec = (struct ec_key *) key;
2401     const struct ecsign_extra *extra =
2402         (const struct ecsign_extra *)ec->signalg->extra;
2403     unsigned char digest[512 / 8];
2404     int digestLen;
2405     Bignum r = NULL, s = NULL;
2406     unsigned char *buf, *p;
2407     int rlen, slen, namelen;
2408     int i;
2409
2410     if (!ec->privateKey || !ec->publicKey.curve) {
2411         return NULL;
2412     }
2413
2414     if (ec->publicKey.curve->type == EC_EDWARDS) {
2415         struct ec_point *rp;
2416         int pointlen = ec->publicKey.curve->fieldBits / 8;
2417
2418         /* hash = H(sk) (where hash creates 2 * fieldBits)
2419          * b = fieldBits
2420          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2421          * r = H(h[b/8:b/4] + m)
2422          * R = rB
2423          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2424         {
2425             unsigned char hash[512/8];
2426             unsigned char b;
2427             Bignum a;
2428             SHA512_State hs;
2429             SHA512_Init(&hs);
2430
2431             for (i = 0; i < pointlen; ++i) {
2432                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2433                 SHA512_Bytes(&hs, &b, 1);
2434             }
2435
2436             SHA512_Final(&hs, hash);
2437
2438             /* The second part is simply turning the hash into a
2439              * Bignum, however the 2^(b-2) bit *must* be set, and the
2440              * bottom 3 bits *must* not be */
2441             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2442             hash[31] &= 0x7f; /* Unset above (b-2) */
2443             hash[31] |= 0x40; /* Set 2^(b-2) */
2444             /* Chop off the top part and convert to int */
2445             a = bignum_from_bytes_le(hash, 32);
2446
2447             SHA512_Init(&hs);
2448             SHA512_Bytes(&hs,
2449                          hash+(ec->publicKey.curve->fieldBits / 8),
2450                          (ec->publicKey.curve->fieldBits / 4)
2451                          - (ec->publicKey.curve->fieldBits / 8));
2452             SHA512_Bytes(&hs, data, datalen);
2453             SHA512_Final(&hs, hash);
2454
2455             r = bignum_from_bytes_le(hash, 512/8);
2456             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2457             if (!rp) {
2458                 freebn(r);
2459                 freebn(a);
2460                 return NULL;
2461             }
2462
2463             /* Now calculate s */
2464             SHA512_Init(&hs);
2465             /* Encode the point R */
2466             for (i = 0; i < pointlen - 1; ++i) {
2467                 b = bignum_byte(rp->y, i);
2468                 SHA512_Bytes(&hs, &b, 1);
2469             }
2470             /* Unset last bit of y and set first bit of x in its place */
2471             b = bignum_byte(rp->y, i) & 0x7f;
2472             b |= bignum_bit(rp->x, 0) << 7;
2473             SHA512_Bytes(&hs, &b, 1);
2474
2475             /* Encode the point pk */
2476             for (i = 0; i < pointlen - 1; ++i) {
2477                 b = bignum_byte(ec->publicKey.y, i);
2478                 SHA512_Bytes(&hs, &b, 1);
2479             }
2480             /* Unset last bit of y and set first bit of x in its place */
2481             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2482             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2483             SHA512_Bytes(&hs, &b, 1);
2484
2485             /* Add the message */
2486             SHA512_Bytes(&hs, data, datalen);
2487             SHA512_Final(&hs, hash);
2488
2489             {
2490                 Bignum tmp, tmp2;
2491
2492                 tmp = bignum_from_bytes_le(hash, 512/8);
2493                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2494                 freebn(a);
2495                 freebn(tmp);
2496                 tmp = bigadd(r, tmp2);
2497                 freebn(r);
2498                 freebn(tmp2);
2499                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2500                 freebn(tmp);
2501             }
2502         }
2503
2504         /* Format the output */
2505         namelen = strlen(ec->signalg->name);
2506         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2507         buf = snewn(*siglen, unsigned char);
2508         p = buf;
2509         PUT_32BIT(p, namelen);
2510         p += 4;
2511         memcpy(p, ec->signalg->name, namelen);
2512         p += namelen;
2513         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2514         p += 4;
2515
2516         /* Encode the point */
2517         pointlen = ec->publicKey.curve->fieldBits / 8;
2518         for (i = 0; i < pointlen - 1; ++i) {
2519             *p++ = bignum_byte(rp->y, i);
2520         }
2521         /* Unset last bit of y and set first bit of x in its place */
2522         *p = bignum_byte(rp->y, i) & 0x7f;
2523         *p++ |= bignum_bit(rp->x, 0) << 7;
2524         ec_point_free(rp);
2525
2526         /* Encode the int */
2527         for (i = 0; i < pointlen; ++i) {
2528             *p++ = bignum_byte(s, i);
2529         }
2530         freebn(s);
2531     } else {
2532         void *hashctx;
2533
2534         digestLen = extra->hash->hlen;
2535         assert(digestLen <= sizeof(digest));
2536         hashctx = extra->hash->init();
2537         extra->hash->bytes(hashctx, data, datalen);
2538         extra->hash->final(hashctx, digest);
2539
2540         /* Do the signature */
2541         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2542         if (!r || !s) {
2543             if (r) freebn(r);
2544             if (s) freebn(s);
2545             return NULL;
2546         }
2547
2548         rlen = (bignum_bitcount(r) + 8) / 8;
2549         slen = (bignum_bitcount(s) + 8) / 8;
2550
2551         namelen = strlen(ec->signalg->name);
2552
2553         /* Format the output */
2554         *siglen = 8+namelen+rlen+slen+8;
2555         buf = snewn(*siglen, unsigned char);
2556         p = buf;
2557         PUT_32BIT(p, namelen);
2558         p += 4;
2559         memcpy(p, ec->signalg->name, namelen);
2560         p += namelen;
2561         PUT_32BIT(p, rlen + slen + 8);
2562         p += 4;
2563         PUT_32BIT(p, rlen);
2564         p += 4;
2565         for (i = rlen; i--;)
2566             *p++ = bignum_byte(r, i);
2567         PUT_32BIT(p, slen);
2568         p += 4;
2569         for (i = slen; i--;)
2570             *p++ = bignum_byte(s, i);
2571
2572         freebn(r);
2573         freebn(s);
2574     }
2575
2576     return buf;
2577 }
2578
2579 const struct ecsign_extra sign_extra_ed25519 = {
2580     ec_ed25519, NULL,
2581     NULL, 0,
2582 };
2583 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2584     ecdsa_newkey,
2585     ecdsa_freekey,
2586     ecdsa_fmtkey,
2587     ecdsa_public_blob,
2588     ecdsa_private_blob,
2589     ecdsa_createkey,
2590     ed25519_openssh_createkey,
2591     ed25519_openssh_fmtkey,
2592     2 /* point, private exponent */,
2593     ecdsa_pubkey_bits,
2594     ecdsa_verifysig,
2595     ecdsa_sign,
2596     "ssh-ed25519",
2597     "ssh-ed25519",
2598     &sign_extra_ed25519,
2599 };
2600
2601 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2602 static const unsigned char nistp256_oid[] = {
2603     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2604 };
2605 const struct ecsign_extra sign_extra_nistp256 = {
2606     ec_p256, &ssh_sha256,
2607     nistp256_oid, lenof(nistp256_oid),
2608 };
2609 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2610     ecdsa_newkey,
2611     ecdsa_freekey,
2612     ecdsa_fmtkey,
2613     ecdsa_public_blob,
2614     ecdsa_private_blob,
2615     ecdsa_createkey,
2616     ecdsa_openssh_createkey,
2617     ecdsa_openssh_fmtkey,
2618     3 /* curve name, point, private exponent */,
2619     ecdsa_pubkey_bits,
2620     ecdsa_verifysig,
2621     ecdsa_sign,
2622     "ecdsa-sha2-nistp256",
2623     "ecdsa-sha2-nistp256",
2624     &sign_extra_nistp256,
2625 };
2626
2627 /* OID: 1.3.132.0.34 (secp384r1) */
2628 static const unsigned char nistp384_oid[] = {
2629     0x2b, 0x81, 0x04, 0x00, 0x22
2630 };
2631 const struct ecsign_extra sign_extra_nistp384 = {
2632     ec_p384, &ssh_sha384,
2633     nistp384_oid, lenof(nistp384_oid),
2634 };
2635 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2636     ecdsa_newkey,
2637     ecdsa_freekey,
2638     ecdsa_fmtkey,
2639     ecdsa_public_blob,
2640     ecdsa_private_blob,
2641     ecdsa_createkey,
2642     ecdsa_openssh_createkey,
2643     ecdsa_openssh_fmtkey,
2644     3 /* curve name, point, private exponent */,
2645     ecdsa_pubkey_bits,
2646     ecdsa_verifysig,
2647     ecdsa_sign,
2648     "ecdsa-sha2-nistp384",
2649     "ecdsa-sha2-nistp384",
2650     &sign_extra_nistp384,
2651 };
2652
2653 /* OID: 1.3.132.0.35 (secp521r1) */
2654 static const unsigned char nistp521_oid[] = {
2655     0x2b, 0x81, 0x04, 0x00, 0x23
2656 };
2657 const struct ecsign_extra sign_extra_nistp521 = {
2658     ec_p521, &ssh_sha512,
2659     nistp521_oid, lenof(nistp521_oid),
2660 };
2661 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2662     ecdsa_newkey,
2663     ecdsa_freekey,
2664     ecdsa_fmtkey,
2665     ecdsa_public_blob,
2666     ecdsa_private_blob,
2667     ecdsa_createkey,
2668     ecdsa_openssh_createkey,
2669     ecdsa_openssh_fmtkey,
2670     3 /* curve name, point, private exponent */,
2671     ecdsa_pubkey_bits,
2672     ecdsa_verifysig,
2673     ecdsa_sign,
2674     "ecdsa-sha2-nistp521",
2675     "ecdsa-sha2-nistp521",
2676     &sign_extra_nistp521,
2677 };
2678
2679 /* ----------------------------------------------------------------------
2680  * Exposed ECDH interface
2681  */
2682
2683 struct eckex_extra {
2684     struct ec_curve *(*curve)(void);
2685 };
2686
2687 static Bignum ecdh_calculate(const Bignum private,
2688                              const struct ec_point *public)
2689 {
2690     struct ec_point *p;
2691     Bignum ret;
2692     p = ecp_mul(public, private);
2693     if (!p) return NULL;
2694     ret = p->x;
2695     p->x = NULL;
2696
2697     if (p->curve->type == EC_MONTGOMERY) {
2698         /*
2699          * Endianness-swap. The Curve25519 algorithm definition
2700          * assumes you were doing your computation in arrays of 32
2701          * little-endian bytes, and now specifies that you take your
2702          * final one of those and convert it into a bignum in
2703          * _network_ byte order, i.e. big-endian.
2704          *
2705          * In particular, the spec says, you convert the _whole_ 32
2706          * bytes into a bignum. That is, on the rare occasions that
2707          * p->x has come out with the most significant 8 bits zero, we
2708          * have to imagine that being represented by a 32-byte string
2709          * with the last byte being zero, so that has to be converted
2710          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2711          * multiple of 256.
2712          */
2713         int i;
2714         int bytes = (p->curve->fieldBits+7) / 8;
2715         unsigned char *byteorder = snewn(bytes, unsigned char);
2716         for (i = 0; i < bytes; ++i) {
2717             byteorder[i] = bignum_byte(ret, i);
2718         }
2719         freebn(ret);
2720         ret = bignum_from_bytes(byteorder, bytes);
2721         smemclr(byteorder, bytes);
2722         sfree(byteorder);
2723     }
2724
2725     ec_point_free(p);
2726     return ret;
2727 }
2728
2729 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2730 {
2731     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2732     struct ec_curve *curve = extra->curve();
2733     return curve->textname;
2734 }
2735
2736 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2737 {
2738     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2739     struct ec_curve *curve;
2740     struct ec_key *key;
2741     struct ec_point *publicKey;
2742
2743     curve = extra->curve();
2744
2745     key = snew(struct ec_key);
2746
2747     key->signalg = NULL;
2748     key->publicKey.curve = curve;
2749
2750     if (curve->type == EC_MONTGOMERY) {
2751         unsigned char bytes[32] = {0};
2752         int i;
2753
2754         for (i = 0; i < sizeof(bytes); ++i)
2755         {
2756             bytes[i] = (unsigned char)random_byte();
2757         }
2758         bytes[0] &= 248;
2759         bytes[31] &= 127;
2760         bytes[31] |= 64;
2761         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2762         for (i = 0; i < sizeof(bytes); ++i)
2763         {
2764             ((volatile char*)bytes)[i] = 0;
2765         }
2766         if (!key->privateKey) {
2767             sfree(key);
2768             return NULL;
2769         }
2770         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2771         if (!publicKey) {
2772             freebn(key->privateKey);
2773             sfree(key);
2774             return NULL;
2775         }
2776         key->publicKey.x = publicKey->x;
2777         key->publicKey.y = publicKey->y;
2778         key->publicKey.z = NULL;
2779         sfree(publicKey);
2780     } else {
2781         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2782         if (!key->privateKey) {
2783             sfree(key);
2784             return NULL;
2785         }
2786         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2787         if (!publicKey) {
2788             freebn(key->privateKey);
2789             sfree(key);
2790             return NULL;
2791         }
2792         key->publicKey.x = publicKey->x;
2793         key->publicKey.y = publicKey->y;
2794         key->publicKey.z = NULL;
2795         sfree(publicKey);
2796     }
2797     return key;
2798 }
2799
2800 char *ssh_ecdhkex_getpublic(void *key, int *len)
2801 {
2802     struct ec_key *ec = (struct ec_key*)key;
2803     char *point, *p;
2804     int i;
2805     int pointlen;
2806
2807     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2808
2809     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2810         *len = 1 + pointlen * 2;
2811     } else {
2812         *len = pointlen;
2813     }
2814     point = (char*)snewn(*len, char);
2815
2816     p = point;
2817     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2818         *p++ = 0x04;
2819         for (i = pointlen; i--;) {
2820             *p++ = bignum_byte(ec->publicKey.x, i);
2821         }
2822         for (i = pointlen; i--;) {
2823             *p++ = bignum_byte(ec->publicKey.y, i);
2824         }
2825     } else {
2826         for (i = 0; i < pointlen; ++i) {
2827             *p++ = bignum_byte(ec->publicKey.x, i);
2828         }
2829     }
2830
2831     return point;
2832 }
2833
2834 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2835 {
2836     struct ec_key *ec = (struct ec_key*) key;
2837     struct ec_point remote;
2838     Bignum ret;
2839
2840     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2841         remote.curve = ec->publicKey.curve;
2842         remote.infinity = 0;
2843         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2844             return NULL;
2845         }
2846     } else {
2847         /* Point length has to be the same length */
2848         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2849             return NULL;
2850         }
2851
2852         remote.curve = ec->publicKey.curve;
2853         remote.infinity = 0;
2854         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2855         remote.y = NULL;
2856         remote.z = NULL;
2857     }
2858
2859     ret = ecdh_calculate(ec->privateKey, &remote);
2860     if (remote.x) freebn(remote.x);
2861     if (remote.y) freebn(remote.y);
2862     return ret;
2863 }
2864
2865 void ssh_ecdhkex_freekey(void *key)
2866 {
2867     ecdsa_freekey(key);
2868 }
2869
2870 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2871 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2872     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2873     &ssh_sha256, &kex_extra_curve25519,
2874 };
2875
2876 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2877 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2878     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2879     &ssh_sha256, &kex_extra_nistp256,
2880 };
2881
2882 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2883 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2884     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2885     &ssh_sha384, &kex_extra_nistp384,
2886 };
2887
2888 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2889 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2890     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2891     &ssh_sha512, &kex_extra_nistp521,
2892 };
2893
2894 static const struct ssh_kex *const ec_kex_list[] = {
2895     &ssh_ec_kex_curve25519,
2896     &ssh_ec_kex_nistp256,
2897     &ssh_ec_kex_nistp384,
2898     &ssh_ec_kex_nistp521,
2899 };
2900
2901 const struct ssh_kexes ssh_ecdh_kex = {
2902     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2903     ec_kex_list
2904 };
2905
2906 /* ----------------------------------------------------------------------
2907  * Helper functions for finding key algorithms and returning auxiliary
2908  * data.
2909  */
2910
2911 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2912                                         const struct ec_curve **curve)
2913 {
2914     static const struct ssh_signkey *algs_with_oid[] = {
2915         &ssh_ecdsa_nistp256,
2916         &ssh_ecdsa_nistp384,
2917         &ssh_ecdsa_nistp521,
2918     };
2919     int i;
2920
2921     for (i = 0; i < lenof(algs_with_oid); i++) {
2922         const struct ssh_signkey *alg = algs_with_oid[i];
2923         const struct ecsign_extra *extra =
2924             (const struct ecsign_extra *)alg->extra;
2925         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2926             *curve = extra->curve();
2927             return alg;
2928         }
2929     }
2930     return NULL;
2931 }
2932
2933 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2934                                 int *oidlen)
2935 {
2936     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2937     *oidlen = extra->oidlen;
2938     return extra->oid;
2939 }
2940
2941 const int ec_nist_alg_and_curve_by_bits(int bits,
2942                                         const struct ec_curve **curve,
2943                                         const struct ssh_signkey **alg)
2944 {
2945     switch (bits) {
2946       case 256: *alg = &ssh_ecdsa_nistp256; break;
2947       case 384: *alg = &ssh_ecdsa_nistp384; break;
2948       case 521: *alg = &ssh_ecdsa_nistp521; break;
2949       default: return FALSE;
2950     }
2951     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2952     return TRUE;
2953 }
2954
2955 const int ec_ed_alg_and_curve_by_bits(int bits,
2956                                       const struct ec_curve **curve,
2957                                       const struct ssh_signkey **alg)
2958 {
2959     switch (bits) {
2960       case 256: *alg = &ssh_ecdsa_ed25519; break;
2961       default: return FALSE;
2962     }
2963     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2964     return TRUE;
2965 }