]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
first pass
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Curve25519 spec from libssh (with reference to other things in the
28  * libssh code):
29  *   https://git.libssh.org/users/aris/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
30  *
31  * Edwards DSA:
32  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
33  */
34
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "ssh.h"
39
40 /* ----------------------------------------------------------------------
41  * Elliptic curve definitions
42  */
43
44 static void initialise_wcurve(struct ec_curve *curve, int bits,
45                               const unsigned char *p,
46                               const unsigned char *a, const unsigned char *b,
47                               const unsigned char *n, const unsigned char *Gx,
48                               const unsigned char *Gy)
49 {
50     int length = bits / 8;
51     if (bits % 8) ++length;
52
53     curve->type = EC_WEIERSTRASS;
54
55     curve->fieldBits = bits;
56     curve->p = bignum_from_bytes(p, length);
57
58     /* Curve co-efficients */
59     curve->w.a = bignum_from_bytes(a, length);
60     curve->w.b = bignum_from_bytes(b, length);
61
62     /* Group order and generator */
63     curve->w.n = bignum_from_bytes(n, length);
64     curve->w.G.x = bignum_from_bytes(Gx, length);
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     curve->w.G.curve = curve;
67     curve->w.G.infinity = 0;
68 }
69
70 static void initialise_mcurve(struct ec_curve *curve, int bits,
71                               const unsigned char *p,
72                               const unsigned char *a, const unsigned char *b,
73                               const unsigned char *Gx)
74 {
75     int length = bits / 8;
76     if (bits % 8) ++length;
77
78     curve->type = EC_MONTGOMERY;
79
80     curve->fieldBits = bits;
81     curve->p = bignum_from_bytes(p, length);
82
83     /* Curve co-efficients */
84     curve->m.a = bignum_from_bytes(a, length);
85     curve->m.b = bignum_from_bytes(b, length);
86
87     /* Generator */
88     curve->m.G.x = bignum_from_bytes(Gx, length);
89     curve->m.G.y = NULL;
90     curve->m.G.z = NULL;
91     curve->m.G.curve = curve;
92     curve->m.G.infinity = 0;
93 }
94
95 static void initialise_ecurve(struct ec_curve *curve, int bits,
96                               const unsigned char *p,
97                               const unsigned char *l, const unsigned char *d,
98                               const unsigned char *Bx, const unsigned char *By)
99 {
100     int length = bits / 8;
101     if (bits % 8) ++length;
102
103     curve->type = EC_EDWARDS;
104
105     curve->fieldBits = bits;
106     curve->p = bignum_from_bytes(p, length);
107
108     /* Curve co-efficients */
109     curve->e.l = bignum_from_bytes(l, length);
110     curve->e.d = bignum_from_bytes(d, length);
111
112     /* Group order and generator */
113     curve->e.B.x = bignum_from_bytes(Bx, length);
114     curve->e.B.y = bignum_from_bytes(By, length);
115     curve->e.B.curve = curve;
116     curve->e.B.infinity = 0;
117 }
118
119 static struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         static const unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         static const unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         static const unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         static const unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         static const unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         static const unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
164         curve.textname = curve.name = "nistp256";
165
166         /* Now initialised, no need to do it again */
167         initialised = 1;
168     }
169
170     return &curve;
171 }
172
173 static struct ec_curve *ec_p384(void)
174 {
175     static struct ec_curve curve = { 0 };
176     static unsigned char initialised = 0;
177
178     if (!initialised)
179     {
180         static const unsigned char p[] = {
181             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
185             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
186             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
187         };
188         static const unsigned char a[] = {
189             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
193             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
194             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
195         };
196         static const unsigned char b[] = {
197             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
198             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
199             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
200             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
201             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
202             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
203         };
204         static const unsigned char n[] = {
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
209             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
210             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
211         };
212         static const unsigned char Gx[] = {
213             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
214             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
215             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
216             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
217             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
218             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
219         };
220         static const unsigned char Gy[] = {
221             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
222             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
223             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
224             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
225             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
226             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
227         };
228
229         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
230         curve.textname = curve.name = "nistp384";
231
232         /* Now initialised, no need to do it again */
233         initialised = 1;
234     }
235
236     return &curve;
237 }
238
239 static struct ec_curve *ec_p521(void)
240 {
241     static struct ec_curve curve = { 0 };
242     static unsigned char initialised = 0;
243
244     if (!initialised)
245     {
246         static const unsigned char p[] = {
247             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff
256         };
257         static const unsigned char a[] = {
258             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xfc
267         };
268         static const unsigned char b[] = {
269             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
270             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
271             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
272             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
273             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
274             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
275             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
276             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
277             0x3f, 0x00
278         };
279         static const unsigned char n[] = {
280             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
281             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
282             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
285             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
286             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
287             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
288             0x64, 0x09
289         };
290         static const unsigned char Gx[] = {
291             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
292             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
293             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
294             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
295             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
296             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
297             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
298             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
299             0xbd, 0x66
300         };
301         static const unsigned char Gy[] = {
302             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
303             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
304             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
305             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
306             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
307             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
308             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
309             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
310             0x66, 0x50
311         };
312
313         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
314         curve.textname = curve.name = "nistp521";
315
316         /* Now initialised, no need to do it again */
317         initialised = 1;
318     }
319
320     return &curve;
321 }
322
323 static struct ec_curve *ec_curve25519(void)
324 {
325     static struct ec_curve curve = { 0 };
326     static unsigned char initialised = 0;
327
328     if (!initialised)
329     {
330         static const unsigned char p[] = {
331             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
332             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
333             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
334             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
335         };
336         static const unsigned char a[] = {
337             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
338             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
341         };
342         static const unsigned char b[] = {
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
344             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
347         };
348         static const unsigned char gx[32] = {
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
350             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
351             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
353         };
354
355         initialise_mcurve(&curve, 256, p, a, b, gx);
356         /* This curve doesn't need a name, because it's never used in
357          * any format that embeds the curve name */
358         curve.name = NULL;
359         curve.textname = "Curve25519";
360
361         /* Now initialised, no need to do it again */
362         initialised = 1;
363     }
364
365     return &curve;
366 }
367
368 static struct ec_curve *ec_ed25519(void)
369 {
370     static struct ec_curve curve = { 0 };
371     static unsigned char initialised = 0;
372
373     if (!initialised)
374     {
375         static const unsigned char q[] = {
376             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
379             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
380         };
381         static const unsigned char l[32] = {
382             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
385             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
386         };
387         static const unsigned char d[32] = {
388             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
389             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
390             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
391             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
392         };
393         static const unsigned char Bx[32] = {
394             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
395             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
396             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
397             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
398         };
399         static const unsigned char By[32] = {
400             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
401             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
402             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
403             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
404         };
405
406         /* This curve doesn't need a name, because it's never used in
407          * any format that embeds the curve name */
408         curve.name = NULL;
409
410         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
411         curve.textname = "Ed25519";
412
413         /* Now initialised, no need to do it again */
414         initialised = 1;
415     }
416
417     return &curve;
418 }
419
420 /* Return 1 if a is -3 % p, otherwise return 0
421  * This is used because there are some maths optimisations */
422 static int ec_aminus3(const struct ec_curve *curve)
423 {
424     int ret;
425     Bignum _p;
426
427     if (curve->type != EC_WEIERSTRASS) {
428         return 0;
429     }
430
431     _p = bignum_add_long(curve->w.a, 3);
432
433     ret = !bignum_cmp(curve->p, _p);
434     freebn(_p);
435     return ret;
436 }
437
438 /* ----------------------------------------------------------------------
439  * Elliptic curve field maths
440  */
441
442 static Bignum ecf_add(const Bignum a, const Bignum b,
443                       const struct ec_curve *curve)
444 {
445     Bignum a1, b1, ab, ret;
446
447     a1 = bigmod(a, curve->p);
448     b1 = bigmod(b, curve->p);
449
450     ab = bigadd(a1, b1);
451     freebn(a1);
452     freebn(b1);
453
454     ret = bigmod(ab, curve->p);
455     freebn(ab);
456
457     return ret;
458 }
459
460 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
461 {
462     return modmul(a, a, curve->p);
463 }
464
465 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
466 {
467     Bignum ret, tmp;
468
469     /* Double */
470     tmp = bignum_lshift(a, 1);
471
472     /* Add itself (i.e. treble) */
473     ret = bigadd(tmp, a);
474     freebn(tmp);
475
476     /* Normalise */
477     while (bignum_cmp(ret, curve->p) >= 0)
478     {
479         tmp = bigsub(ret, curve->p);
480         assert(tmp);
481         freebn(ret);
482         ret = tmp;
483     }
484
485     return ret;
486 }
487
488 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
489 {
490     Bignum ret = bignum_lshift(a, 1);
491     if (bignum_cmp(ret, curve->p) >= 0)
492     {
493         Bignum tmp = bigsub(ret, curve->p);
494         assert(tmp);
495         freebn(ret);
496         return tmp;
497     }
498     else
499     {
500         return ret;
501     }
502 }
503
504 /* ----------------------------------------------------------------------
505  * Memory functions
506  */
507
508 void ec_point_free(struct ec_point *point)
509 {
510     if (point == NULL) return;
511     point->curve = 0;
512     if (point->x) freebn(point->x);
513     if (point->y) freebn(point->y);
514     if (point->z) freebn(point->z);
515     point->infinity = 0;
516     sfree(point);
517 }
518
519 static struct ec_point *ec_point_new(const struct ec_curve *curve,
520                                      const Bignum x, const Bignum y, const Bignum z,
521                                      unsigned char infinity)
522 {
523     struct ec_point *point = snewn(1, struct ec_point);
524     point->curve = curve;
525     point->x = x;
526     point->y = y;
527     point->z = z;
528     point->infinity = infinity ? 1 : 0;
529     return point;
530 }
531
532 static struct ec_point *ec_point_copy(const struct ec_point *a)
533 {
534     if (a == NULL) return NULL;
535     return ec_point_new(a->curve,
536                         a->x ? copybn(a->x) : NULL,
537                         a->y ? copybn(a->y) : NULL,
538                         a->z ? copybn(a->z) : NULL,
539                         a->infinity);
540 }
541
542 static int ec_point_verify(const struct ec_point *a)
543 {
544     if (a->infinity) {
545         return 1;
546     } else if (a->curve->type == EC_EDWARDS) {
547         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
548         Bignum y2, x2, tmp, tmp2, tmp3;
549         int ret;
550
551         y2 = ecf_square(a->y, a->curve);
552         x2 = ecf_square(a->x, a->curve);
553         tmp = modmul(a->curve->e.d, x2, a->curve->p);
554         tmp2 = modmul(tmp, y2, a->curve->p);
555         freebn(tmp);
556         tmp = modsub(y2, x2, a->curve->p);
557         freebn(y2);
558         freebn(x2);
559         tmp3 = modsub(tmp, tmp2, a->curve->p);
560         freebn(tmp);
561         freebn(tmp2);
562         ret = !bignum_cmp(tmp3, One);
563         freebn(tmp3);
564         return ret;
565     } else if (a->curve->type == EC_WEIERSTRASS) {
566         /* Verify y^2 = x^3 + ax + b */
567         int ret = 0;
568
569         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
570
571         Bignum Three = bignum_from_long(3);
572
573         lhs = modmul(a->y, a->y, a->curve->p);
574
575         /* This uses montgomery multiplication to optimise */
576         x3 = modpow(a->x, Three, a->curve->p);
577         freebn(Three);
578         ax = modmul(a->curve->w.a, a->x, a->curve->p);
579         x3ax = bigadd(x3, ax);
580         freebn(x3); x3 = NULL;
581         freebn(ax); ax = NULL;
582         x3axm = bigmod(x3ax, a->curve->p);
583         freebn(x3ax); x3ax = NULL;
584         x3axb = bigadd(x3axm, a->curve->w.b);
585         freebn(x3axm); x3axm = NULL;
586         rhs = bigmod(x3axb, a->curve->p);
587         freebn(x3axb);
588
589         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
590         freebn(lhs);
591         freebn(rhs);
592
593         return ret;
594     } else {
595         return 0;
596     }
597 }
598
599 /* ----------------------------------------------------------------------
600  * Elliptic curve point maths
601  */
602
603 /* Returns 1 on success and 0 on memory error */
604 static int ecp_normalise(struct ec_point *a)
605 {
606     if (!a) {
607         /* No point */
608         return 0;
609     }
610
611     if (a->infinity) {
612         /* Point is at infinity - i.e. normalised */
613         return 1;
614     }
615
616     if (a->curve->type == EC_WEIERSTRASS) {
617         /* In Jacobian Coordinates the triple (X, Y, Z) represents
618            the affine point (X / Z^2, Y / Z^3) */
619
620         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
621
622         if (!a->x || !a->y) {
623             /* No point defined */
624             return 0;
625         } else if (!a->z) {
626             /* Already normalised */
627             return 1;
628         }
629
630         Z2 = ecf_square(a->z, a->curve);
631         Z2inv = modinv(Z2, a->curve->p);
632         if (!Z2inv) {
633             freebn(Z2);
634             return 0;
635         }
636         tx = modmul(a->x, Z2inv, a->curve->p);
637         freebn(Z2inv);
638
639         Z3 = modmul(Z2, a->z, a->curve->p);
640         freebn(Z2);
641         Z3inv = modinv(Z3, a->curve->p);
642         freebn(Z3);
643         if (!Z3inv) {
644             freebn(tx);
645             return 0;
646         }
647         ty = modmul(a->y, Z3inv, a->curve->p);
648         freebn(Z3inv);
649
650         freebn(a->x);
651         a->x = tx;
652         freebn(a->y);
653         a->y = ty;
654         freebn(a->z);
655         a->z = NULL;
656         return 1;
657     } else if (a->curve->type == EC_MONTGOMERY) {
658         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
659
660         Bignum tmp, tmp2;
661
662         if (!a->x) {
663             /* No point defined */
664             return 0;
665         } else if (!a->z) {
666             /* Already normalised */
667             return 1;
668         }
669
670         tmp = modinv(a->z, a->curve->p);
671         if (!tmp) {
672             return 0;
673         }
674         tmp2 = modmul(a->x, tmp, a->curve->p);
675         freebn(tmp);
676
677         freebn(a->z);
678         a->z = NULL;
679         freebn(a->x);
680         a->x = tmp2;
681         return 1;
682     } else if (a->curve->type == EC_EDWARDS) {
683         /* Always normalised */
684         return 1;
685     } else {
686         return 0;
687     }
688 }
689
690 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
691 {
692     Bignum S, M, outx, outy, outz;
693
694     if (bignum_cmp(a->y, Zero) == 0)
695     {
696         /* Identity */
697         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
698     }
699
700     /* S = 4*X*Y^2 */
701     {
702         Bignum Y2, XY2, _2XY2;
703
704         Y2 = ecf_square(a->y, a->curve);
705         XY2 = modmul(a->x, Y2, a->curve->p);
706         freebn(Y2);
707
708         _2XY2 = ecf_double(XY2, a->curve);
709         freebn(XY2);
710         S = ecf_double(_2XY2, a->curve);
711         freebn(_2XY2);
712     }
713
714     /* Faster calculation if a = -3 */
715     if (aminus3) {
716         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
717         Bignum Z2, XpZ2, XmZ2, second;
718
719         if (a->z == NULL) {
720             Z2 = copybn(One);
721         } else {
722             Z2 = ecf_square(a->z, a->curve);
723         }
724
725         XpZ2 = ecf_add(a->x, Z2, a->curve);
726         XmZ2 = modsub(a->x, Z2, a->curve->p);
727         freebn(Z2);
728
729         second = modmul(XpZ2, XmZ2, a->curve->p);
730         freebn(XpZ2);
731         freebn(XmZ2);
732
733         M = ecf_treble(second, a->curve);
734         freebn(second);
735     } else {
736         /* M = 3*X^2 + a*Z^4 */
737         Bignum _3X2, X2, aZ4;
738
739         if (a->z == NULL) {
740             aZ4 = copybn(a->curve->w.a);
741         } else {
742             Bignum Z2, Z4;
743
744             Z2 = ecf_square(a->z, a->curve);
745             Z4 = ecf_square(Z2, a->curve);
746             freebn(Z2);
747             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
748             freebn(Z4);
749         }
750
751         X2 = modmul(a->x, a->x, a->curve->p);
752         _3X2 = ecf_treble(X2, a->curve);
753         freebn(X2);
754         M = ecf_add(_3X2, aZ4, a->curve);
755         freebn(_3X2);
756         freebn(aZ4);
757     }
758
759     /* X' = M^2 - 2*S */
760     {
761         Bignum M2, _2S;
762
763         M2 = ecf_square(M, a->curve);
764         _2S = ecf_double(S, a->curve);
765         outx = modsub(M2, _2S, a->curve->p);
766         freebn(M2);
767         freebn(_2S);
768     }
769
770     /* Y' = M*(S - X') - 8*Y^4 */
771     {
772         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
773
774         SX = modsub(S, outx, a->curve->p);
775         freebn(S);
776         MSX = modmul(M, SX, a->curve->p);
777         freebn(SX);
778         freebn(M);
779         Y2 = ecf_square(a->y, a->curve);
780         Y4 = ecf_square(Y2, a->curve);
781         freebn(Y2);
782         Eight = bignum_from_long(8);
783         _8Y4 = modmul(Eight, Y4, a->curve->p);
784         freebn(Eight);
785         freebn(Y4);
786         outy = modsub(MSX, _8Y4, a->curve->p);
787         freebn(MSX);
788         freebn(_8Y4);
789     }
790
791     /* Z' = 2*Y*Z */
792     {
793         Bignum YZ;
794
795         if (a->z == NULL) {
796             YZ = copybn(a->y);
797         } else {
798             YZ = modmul(a->y, a->z, a->curve->p);
799         }
800
801         outz = ecf_double(YZ, a->curve);
802         freebn(YZ);
803     }
804
805     return ec_point_new(a->curve, outx, outy, outz, 0);
806 }
807
808 static struct ec_point *ecp_doublem(const struct ec_point *a)
809 {
810     Bignum z, outx, outz, xpz, xmz;
811
812     z = a->z;
813     if (!z) {
814         z = One;
815     }
816
817     /* 4xz = (x + z)^2 - (x - z)^2 */
818     {
819         Bignum tmp;
820
821         tmp = ecf_add(a->x, z, a->curve);
822         xpz = ecf_square(tmp, a->curve);
823         freebn(tmp);
824
825         tmp = modsub(a->x, z, a->curve->p);
826         xmz = ecf_square(tmp, a->curve);
827         freebn(tmp);
828     }
829
830     /* outx = (x + z)^2 * (x - z)^2 */
831     outx = modmul(xpz, xmz, a->curve->p);
832
833     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
834     {
835         Bignum _4xz, tmp, tmp2, tmp3;
836
837         tmp = bignum_from_long(2);
838         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
839         freebn(tmp);
840
841         _4xz = modsub(xpz, xmz, a->curve->p);
842         freebn(xpz);
843         tmp = modmul(tmp2, _4xz, a->curve->p);
844         freebn(tmp2);
845
846         tmp2 = bignum_from_long(4);
847         tmp3 = modinv(tmp2, a->curve->p);
848         freebn(tmp2);
849         if (!tmp3) {
850             freebn(tmp);
851             freebn(_4xz);
852             freebn(outx);
853             freebn(xmz);
854             return NULL;
855         }
856         tmp2 = modmul(tmp, tmp3, a->curve->p);
857         freebn(tmp);
858         freebn(tmp3);
859
860         tmp = ecf_add(xmz, tmp2, a->curve);
861         freebn(xmz);
862         freebn(tmp2);
863         outz = modmul(_4xz, tmp, a->curve->p);
864         freebn(_4xz);
865         freebn(tmp);
866     }
867
868     return ec_point_new(a->curve, outx, NULL, outz, 0);
869 }
870
871 /* Forward declaration for Edwards curve doubling */
872 static struct ec_point *ecp_add(const struct ec_point *a,
873                                 const struct ec_point *b,
874                                 const int aminus3);
875
876 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
877 {
878     if (a->infinity)
879     {
880         /* Identity */
881         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
882     }
883
884     if (a->curve->type == EC_EDWARDS)
885     {
886         return ecp_add(a, a, aminus3);
887     }
888     else if (a->curve->type == EC_WEIERSTRASS)
889     {
890         return ecp_doublew(a, aminus3);
891     }
892     else
893     {
894         return ecp_doublem(a);
895     }
896 }
897
898 static struct ec_point *ecp_addw(const struct ec_point *a,
899                                  const struct ec_point *b,
900                                  const int aminus3)
901 {
902     Bignum U1, U2, S1, S2, outx, outy, outz;
903
904     /* U1 = X1*Z2^2 */
905     /* S1 = Y1*Z2^3 */
906     if (b->z) {
907         Bignum Z2, Z3;
908
909         Z2 = ecf_square(b->z, a->curve);
910         U1 = modmul(a->x, Z2, a->curve->p);
911         Z3 = modmul(Z2, b->z, a->curve->p);
912         freebn(Z2);
913         S1 = modmul(a->y, Z3, a->curve->p);
914         freebn(Z3);
915     } else {
916         U1 = copybn(a->x);
917         S1 = copybn(a->y);
918     }
919
920     /* U2 = X2*Z1^2 */
921     /* S2 = Y2*Z1^3 */
922     if (a->z) {
923         Bignum Z2, Z3;
924
925         Z2 = ecf_square(a->z, b->curve);
926         U2 = modmul(b->x, Z2, b->curve->p);
927         Z3 = modmul(Z2, a->z, b->curve->p);
928         freebn(Z2);
929         S2 = modmul(b->y, Z3, b->curve->p);
930         freebn(Z3);
931     } else {
932         U2 = copybn(b->x);
933         S2 = copybn(b->y);
934     }
935
936     /* Check if multiplying by self */
937     if (bignum_cmp(U1, U2) == 0)
938     {
939         freebn(U1);
940         freebn(U2);
941         if (bignum_cmp(S1, S2) == 0)
942         {
943             freebn(S1);
944             freebn(S2);
945             return ecp_double(a, aminus3);
946         }
947         else
948         {
949             freebn(S1);
950             freebn(S2);
951             /* Infinity */
952             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
953         }
954     }
955
956     {
957         Bignum H, R, UH2, H3;
958
959         /* H = U2 - U1 */
960         H = modsub(U2, U1, a->curve->p);
961         freebn(U2);
962
963         /* R = S2 - S1 */
964         R = modsub(S2, S1, a->curve->p);
965         freebn(S2);
966
967         /* X3 = R^2 - H^3 - 2*U1*H^2 */
968         {
969             Bignum R2, H2, _2UH2, first;
970
971             H2 = ecf_square(H, a->curve);
972             UH2 = modmul(U1, H2, a->curve->p);
973             freebn(U1);
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             R2 = ecf_square(R, a->curve);
977             _2UH2 = ecf_double(UH2, a->curve);
978             first = modsub(R2, H3, a->curve->p);
979             freebn(R2);
980             outx = modsub(first, _2UH2, a->curve->p);
981             freebn(first);
982             freebn(_2UH2);
983         }
984
985         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
986         {
987             Bignum RUH2mX, UH2mX, SH3;
988
989             UH2mX = modsub(UH2, outx, a->curve->p);
990             freebn(UH2);
991             RUH2mX = modmul(R, UH2mX, a->curve->p);
992             freebn(UH2mX);
993             freebn(R);
994             SH3 = modmul(S1, H3, a->curve->p);
995             freebn(S1);
996             freebn(H3);
997
998             outy = modsub(RUH2mX, SH3, a->curve->p);
999             freebn(RUH2mX);
1000             freebn(SH3);
1001         }
1002
1003         /* Z3 = H*Z1*Z2 */
1004         if (a->z && b->z) {
1005             Bignum ZZ;
1006
1007             ZZ = modmul(a->z, b->z, a->curve->p);
1008             outz = modmul(H, ZZ, a->curve->p);
1009             freebn(H);
1010             freebn(ZZ);
1011         } else if (a->z) {
1012             outz = modmul(H, a->z, a->curve->p);
1013             freebn(H);
1014         } else if (b->z) {
1015             outz = modmul(H, b->z, a->curve->p);
1016             freebn(H);
1017         } else {
1018             outz = H;
1019         }
1020     }
1021
1022     return ec_point_new(a->curve, outx, outy, outz, 0);
1023 }
1024
1025 static struct ec_point *ecp_addm(const struct ec_point *a,
1026                                  const struct ec_point *b,
1027                                  const struct ec_point *base)
1028 {
1029     Bignum outx, outz, az, bz;
1030
1031     az = a->z;
1032     if (!az) {
1033         az = One;
1034     }
1035     bz = b->z;
1036     if (!bz) {
1037         bz = One;
1038     }
1039
1040     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1041     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1042     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1043     {
1044         Bignum tmp, tmp2, tmp3, tmp4;
1045
1046         /* (Xa + Za) * (Xb - Zb) */
1047         tmp = ecf_add(a->x, az, a->curve);
1048         tmp2 = modsub(b->x, bz, a->curve->p);
1049         tmp3 = modmul(tmp, tmp2, a->curve->p);
1050         freebn(tmp);
1051         freebn(tmp2);
1052
1053         /* (Xa - Za) * (Xb + Zb) */
1054         tmp = modsub(a->x, az, a->curve->p);
1055         tmp2 = ecf_add(b->x, bz, a->curve);
1056         tmp4 = modmul(tmp, tmp2, a->curve->p);
1057         freebn(tmp);
1058         freebn(tmp2);
1059
1060         tmp = ecf_add(tmp3, tmp4, a->curve);
1061         outx = ecf_square(tmp, a->curve);
1062         freebn(tmp);
1063
1064         tmp = modsub(tmp3, tmp4, a->curve->p);
1065         freebn(tmp3);
1066         freebn(tmp4);
1067         tmp2 = ecf_square(tmp, a->curve);
1068         freebn(tmp);
1069         outz = modmul(base->x, tmp2, a->curve->p);
1070         freebn(tmp2);
1071     }
1072
1073     return ec_point_new(a->curve, outx, NULL, outz, 0);
1074 }
1075
1076 static struct ec_point *ecp_adde(const struct ec_point *a,
1077                                  const struct ec_point *b)
1078 {
1079     Bignum outx, outy, dmul;
1080
1081     /* outx = (a->x * b->y + b->x * a->y) /
1082      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1083     {
1084         Bignum tmp, tmp2, tmp3, tmp4;
1085
1086         tmp = modmul(a->x, b->y, a->curve->p);
1087         tmp2 = modmul(b->x, a->y, a->curve->p);
1088         tmp3 = ecf_add(tmp, tmp2, a->curve);
1089
1090         tmp4 = modmul(tmp, tmp2, a->curve->p);
1091         freebn(tmp);
1092         freebn(tmp2);
1093         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1094         freebn(tmp4);
1095
1096         tmp = ecf_add(One, dmul, a->curve);
1097         tmp2 = modinv(tmp, a->curve->p);
1098         freebn(tmp);
1099         if (!tmp2)
1100         {
1101             freebn(tmp3);
1102             freebn(dmul);
1103             return NULL;
1104         }
1105
1106         outx = modmul(tmp3, tmp2, a->curve->p);
1107         freebn(tmp3);
1108         freebn(tmp2);
1109     }
1110
1111     /* outy = (a->y * b->y + a->x * b->x) /
1112      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1113     {
1114         Bignum tmp, tmp2, tmp3, tmp4;
1115
1116         tmp = modsub(One, dmul, a->curve->p);
1117         freebn(dmul);
1118
1119         tmp2 = modinv(tmp, a->curve->p);
1120         freebn(tmp);
1121         if (!tmp2)
1122         {
1123             freebn(outx);
1124             return NULL;
1125         }
1126
1127         tmp = modmul(a->y, b->y, a->curve->p);
1128         tmp3 = modmul(a->x, b->x, a->curve->p);
1129         tmp4 = ecf_add(tmp, tmp3, a->curve);
1130         freebn(tmp);
1131         freebn(tmp3);
1132
1133         outy = modmul(tmp4, tmp2, a->curve->p);
1134         freebn(tmp4);
1135         freebn(tmp2);
1136     }
1137
1138     return ec_point_new(a->curve, outx, outy, NULL, 0);
1139 }
1140
1141 static struct ec_point *ecp_add(const struct ec_point *a,
1142                                 const struct ec_point *b,
1143                                 const int aminus3)
1144 {
1145     if (a->curve != b->curve) {
1146         return NULL;
1147     }
1148
1149     /* Check if multiplying by infinity */
1150     if (a->infinity) return ec_point_copy(b);
1151     if (b->infinity) return ec_point_copy(a);
1152
1153     if (a->curve->type == EC_EDWARDS)
1154     {
1155         return ecp_adde(a, b);
1156     }
1157
1158     if (a->curve->type == EC_WEIERSTRASS)
1159     {
1160         return ecp_addw(a, b, aminus3);
1161     }
1162
1163     return NULL;
1164 }
1165
1166 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1167 {
1168     struct ec_point *A, *ret;
1169     int bits, i;
1170
1171     A = ec_point_copy(a);
1172     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1173
1174     bits = bignum_bitcount(b);
1175     for (i = 0; i < bits; ++i)
1176     {
1177         if (bignum_bit(b, i))
1178         {
1179             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1180             ec_point_free(ret);
1181             ret = tmp;
1182         }
1183         if (i+1 != bits)
1184         {
1185             struct ec_point *tmp = ecp_double(A, aminus3);
1186             ec_point_free(A);
1187             A = tmp;
1188         }
1189     }
1190
1191     ec_point_free(A);
1192     return ret;
1193 }
1194
1195 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1196 {
1197     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1198
1199     if (!ecp_normalise(ret)) {
1200         ec_point_free(ret);
1201         return NULL;
1202     }
1203
1204     return ret;
1205 }
1206
1207 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1208 {
1209     int i;
1210     struct ec_point *ret;
1211
1212     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1213
1214     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1215     {
1216         {
1217             struct ec_point *tmp = ecp_double(ret, 0);
1218             ec_point_free(ret);
1219             ret = tmp;
1220         }
1221         if (ret && bignum_bit(b, i))
1222         {
1223             struct ec_point *tmp = ecp_add(ret, a, 0);
1224             ec_point_free(ret);
1225             ret = tmp;
1226         }
1227     }
1228
1229     return ret;
1230 }
1231
1232 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1233 {
1234     struct ec_point *P1, *P2;
1235     int bits, i;
1236
1237     /* P1 <- P and P2 <- [2]P */
1238     P2 = ecp_double(p, 0);
1239     P1 = ec_point_copy(p);
1240
1241     /* for i = bits âˆ’ 2 down to 0 */
1242     bits = bignum_bitcount(n);
1243     for (i = bits - 2; i >= 0; --i)
1244     {
1245         if (!bignum_bit(n, i))
1246         {
1247             /* P2 <- P1 + P2 */
1248             struct ec_point *tmp = ecp_addm(P1, P2, p);
1249             ec_point_free(P2);
1250             P2 = tmp;
1251
1252             /* P1 <- [2]P1 */
1253             tmp = ecp_double(P1, 0);
1254             ec_point_free(P1);
1255             P1 = tmp;
1256         }
1257         else
1258         {
1259             /* P1 <- P1 + P2 */
1260             struct ec_point *tmp = ecp_addm(P1, P2, p);
1261             ec_point_free(P1);
1262             P1 = tmp;
1263
1264             /* P2 <- [2]P2 */
1265             tmp = ecp_double(P2, 0);
1266             ec_point_free(P2);
1267             P2 = tmp;
1268         }
1269     }
1270
1271     ec_point_free(P2);
1272
1273     if (!ecp_normalise(P1)) {
1274         ec_point_free(P1);
1275         return NULL;
1276     }
1277
1278     return P1;
1279 }
1280
1281 /* Not static because it is used by sshecdsag.c to generate a new key */
1282 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1283 {
1284     if (a->curve->type == EC_WEIERSTRASS) {
1285         return ecp_mulw(a, b);
1286     } else if (a->curve->type == EC_EDWARDS) {
1287         return ecp_mule(a, b);
1288     } else {
1289         return ecp_mulm(a, b);
1290     }
1291 }
1292
1293 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1294                                    const struct ec_point *point)
1295 {
1296     struct ec_point *aG, *bP, *ret;
1297     int aminus3;
1298
1299     if (point->curve->type != EC_WEIERSTRASS) {
1300         return NULL;
1301     }
1302
1303     aminus3 = ec_aminus3(point->curve);
1304
1305     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1306     if (!aG) return NULL;
1307     bP = ecp_mul_(point, b, aminus3);
1308     if (!bP) {
1309         ec_point_free(aG);
1310         return NULL;
1311     }
1312
1313     ret = ecp_add(aG, bP, aminus3);
1314
1315     ec_point_free(aG);
1316     ec_point_free(bP);
1317
1318     if (!ecp_normalise(ret)) {
1319         ec_point_free(ret);
1320         return NULL;
1321     }
1322
1323     return ret;
1324 }
1325 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1326 {
1327     /* Get the x value on the given Edwards curve for a given y */
1328     Bignum x, xx;
1329
1330     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1331     {
1332         Bignum tmp, tmp2, tmp3;
1333
1334         tmp = ecf_square(y, curve);
1335         tmp2 = modmul(curve->e.d, tmp, curve->p);
1336         tmp3 = ecf_add(tmp2, One, curve);
1337         freebn(tmp2);
1338         tmp2 = modinv(tmp3, curve->p);
1339         freebn(tmp3);
1340         if (!tmp2) {
1341             freebn(tmp);
1342             return NULL;
1343         }
1344
1345         tmp3 = modsub(tmp, One, curve->p);
1346         freebn(tmp);
1347         xx = modmul(tmp3, tmp2, curve->p);
1348         freebn(tmp3);
1349         freebn(tmp2);
1350     }
1351
1352     /* x = xx^((p + 3) / 8) */
1353     {
1354         Bignum tmp, tmp2;
1355
1356         tmp = bignum_add_long(curve->p, 3);
1357         tmp2 = bignum_rshift(tmp, 3);
1358         freebn(tmp);
1359         x = modpow(xx, tmp2, curve->p);
1360         freebn(tmp2);
1361     }
1362
1363     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1364     {
1365         Bignum tmp, tmp2;
1366
1367         tmp = ecf_square(x, curve);
1368         tmp2 = modsub(tmp, xx, curve->p);
1369         freebn(tmp);
1370         freebn(xx);
1371         if (bignum_cmp(tmp2, Zero)) {
1372             Bignum tmp3;
1373
1374             freebn(tmp2);
1375
1376             tmp = modsub(curve->p, One, curve->p);
1377             tmp2 = bignum_rshift(tmp, 2);
1378             freebn(tmp);
1379             tmp = bignum_from_long(2);
1380             tmp3 = modpow(tmp, tmp2, curve->p);
1381             freebn(tmp);
1382             freebn(tmp2);
1383
1384             tmp = modmul(x, tmp3, curve->p);
1385             freebn(x);
1386             freebn(tmp3);
1387             x = tmp;
1388         } else {
1389             freebn(tmp2);
1390         }
1391     }
1392
1393     /* if x % 2 != 0 then x = p - x */
1394     if (bignum_bit(x, 0)) {
1395         Bignum tmp = modsub(curve->p, x, curve->p);
1396         freebn(x);
1397         x = tmp;
1398     }
1399
1400     return x;
1401 }
1402
1403 /* ----------------------------------------------------------------------
1404  * Public point from private
1405  */
1406
1407 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1408 {
1409     if (curve->type == EC_WEIERSTRASS) {
1410         return ecp_mul(&curve->w.G, privateKey);
1411     } else if (curve->type == EC_EDWARDS) {
1412         /* hash = H(sk) (where hash creates 2 * fieldBits)
1413          * b = fieldBits
1414          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1415          * publicKey = aB */
1416         struct ec_point *ret;
1417         unsigned char hash[512/8];
1418         Bignum a;
1419         int i, keylen;
1420         SHA512_State s;
1421         SHA512_Init(&s);
1422
1423         keylen = curve->fieldBits / 8;
1424         for (i = 0; i < keylen; ++i) {
1425             unsigned char b = bignum_byte(privateKey, i);
1426             SHA512_Bytes(&s, &b, 1);
1427         }
1428         SHA512_Final(&s, hash);
1429
1430         /* The second part is simply turning the hash into a Bignum,
1431          * however the 2^(b-2) bit *must* be set, and the bottom 3
1432          * bits *must* not be */
1433         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1434         hash[31] &= 0x7f; /* Unset above (b-2) */
1435         hash[31] |= 0x40; /* Set 2^(b-2) */
1436         /* Chop off the top part and convert to int */
1437         a = bignum_from_bytes_le(hash, 32);
1438
1439         ret = ecp_mul(&curve->e.B, a);
1440         freebn(a);
1441         return ret;
1442     } else {
1443         return NULL;
1444     }
1445 }
1446
1447 /* ----------------------------------------------------------------------
1448  * Basic sign and verify routines
1449  */
1450
1451 static int _ecdsa_verify(const struct ec_point *publicKey,
1452                          const unsigned char *data, const int dataLen,
1453                          const Bignum r, const Bignum s)
1454 {
1455     int z_bits, n_bits;
1456     Bignum z;
1457     int valid = 0;
1458
1459     if (publicKey->curve->type != EC_WEIERSTRASS) {
1460         return 0;
1461     }
1462
1463     /* Sanity checks */
1464     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1465         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1466     {
1467         return 0;
1468     }
1469
1470     /* z = left most bitlen(curve->n) of data */
1471     z = bignum_from_bytes(data, dataLen);
1472     n_bits = bignum_bitcount(publicKey->curve->w.n);
1473     z_bits = bignum_bitcount(z);
1474     if (z_bits > n_bits)
1475     {
1476         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1477         freebn(z);
1478         z = tmp;
1479     }
1480
1481     /* Ensure z in range of n */
1482     {
1483         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1484         freebn(z);
1485         z = tmp;
1486     }
1487
1488     /* Calculate signature */
1489     {
1490         Bignum w, x, u1, u2;
1491         struct ec_point *tmp;
1492
1493         w = modinv(s, publicKey->curve->w.n);
1494         if (!w) {
1495             freebn(z);
1496             return 0;
1497         }
1498         u1 = modmul(z, w, publicKey->curve->w.n);
1499         u2 = modmul(r, w, publicKey->curve->w.n);
1500         freebn(w);
1501
1502         tmp = ecp_summul(u1, u2, publicKey);
1503         freebn(u1);
1504         freebn(u2);
1505         if (!tmp) {
1506             freebn(z);
1507             return 0;
1508         }
1509
1510         x = bigmod(tmp->x, publicKey->curve->w.n);
1511         ec_point_free(tmp);
1512
1513         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1514         freebn(x);
1515     }
1516
1517     freebn(z);
1518
1519     return valid;
1520 }
1521
1522 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1523                         const unsigned char *data, const int dataLen,
1524                         Bignum *r, Bignum *s)
1525 {
1526     unsigned char digest[20];
1527     int z_bits, n_bits;
1528     Bignum z, k;
1529     struct ec_point *kG;
1530
1531     *r = NULL;
1532     *s = NULL;
1533
1534     if (curve->type != EC_WEIERSTRASS) {
1535         return;
1536     }
1537
1538     /* z = left most bitlen(curve->n) of data */
1539     z = bignum_from_bytes(data, dataLen);
1540     n_bits = bignum_bitcount(curve->w.n);
1541     z_bits = bignum_bitcount(z);
1542     if (z_bits > n_bits)
1543     {
1544         Bignum tmp;
1545         tmp = bignum_rshift(z, z_bits - n_bits);
1546         freebn(z);
1547         z = tmp;
1548     }
1549
1550     /* Generate k between 1 and curve->n, using the same deterministic
1551      * k generation system we use for conventional DSA. */
1552     SHA_Simple(data, dataLen, digest);
1553     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1554                   digest, sizeof(digest));
1555
1556     kG = ecp_mul(&curve->w.G, k);
1557     if (!kG) {
1558         freebn(z);
1559         freebn(k);
1560         return;
1561     }
1562
1563     /* r = kG.x mod n */
1564     *r = bigmod(kG->x, curve->w.n);
1565     ec_point_free(kG);
1566
1567     /* s = (z + r * priv)/k mod n */
1568     {
1569         Bignum rPriv, zMod, first, firstMod, kInv;
1570         rPriv = modmul(*r, privateKey, curve->w.n);
1571         zMod = bigmod(z, curve->w.n);
1572         freebn(z);
1573         first = bigadd(rPriv, zMod);
1574         freebn(rPriv);
1575         freebn(zMod);
1576         firstMod = bigmod(first, curve->w.n);
1577         freebn(first);
1578         kInv = modinv(k, curve->w.n);
1579         freebn(k);
1580         if (!kInv) {
1581             freebn(firstMod);
1582             freebn(*r);
1583             return;
1584         }
1585         *s = modmul(firstMod, kInv, curve->w.n);
1586         freebn(firstMod);
1587         freebn(kInv);
1588     }
1589 }
1590
1591 /* ----------------------------------------------------------------------
1592  * Misc functions
1593  */
1594
1595 static void getstring(const char **data, int *datalen,
1596                       const char **p, int *length)
1597 {
1598     *p = NULL;
1599     if (*datalen < 4)
1600         return;
1601     *length = toint(GET_32BIT(*data));
1602     if (*length < 0)
1603         return;
1604     *datalen -= 4;
1605     *data += 4;
1606     if (*datalen < *length)
1607         return;
1608     *p = *data;
1609     *data += *length;
1610     *datalen -= *length;
1611 }
1612
1613 static Bignum getmp(const char **data, int *datalen)
1614 {
1615     const char *p;
1616     int length;
1617
1618     getstring(data, datalen, &p, &length);
1619     if (!p)
1620         return NULL;
1621     if (p[0] & 0x80)
1622         return NULL;                   /* negative mp */
1623     return bignum_from_bytes((unsigned char *)p, length);
1624 }
1625
1626 static Bignum getmp_le(const char **data, int *datalen)
1627 {
1628     const char *p;
1629     int length;
1630
1631     getstring(data, datalen, &p, &length);
1632     if (!p)
1633         return NULL;
1634     return bignum_from_bytes_le((const unsigned char *)p, length);
1635 }
1636
1637 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1638 {
1639     /* Got some conversion to do, first read in the y co-ord */
1640     int negative;
1641
1642     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1643     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1644         freebn(point->y);
1645         point->y = NULL;
1646         return 0;
1647     }
1648     /* Read x bit and then reset it */
1649     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1650     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1651     bn_restore_invariant(point->y);
1652
1653     /* Get the x from the y */
1654     point->x = ecp_edx(point->curve, point->y);
1655     if (!point->x) {
1656         freebn(point->y);
1657         point->y = NULL;
1658         return 0;
1659     }
1660     if (negative) {
1661         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1662         freebn(point->x);
1663         point->x = tmp;
1664     }
1665
1666     /* Verify the point is on the curve */
1667     if (!ec_point_verify(point)) {
1668         freebn(point->x);
1669         point->x = NULL;
1670         freebn(point->y);
1671         point->y = NULL;
1672         return 0;
1673     }
1674
1675     return 1;
1676 }
1677
1678 static int decodepoint(const char *p, int length, struct ec_point *point)
1679 {
1680     if (point->curve->type == EC_EDWARDS) {
1681         return decodepoint_ed(p, length, point);
1682     }
1683
1684     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1685         return 0;
1686     /* Skip compression flag */
1687     ++p;
1688     --length;
1689     /* The two values must be equal length */
1690     if (length % 2 != 0) {
1691         point->x = NULL;
1692         point->y = NULL;
1693         point->z = NULL;
1694         return 0;
1695     }
1696     length = length / 2;
1697     point->x = bignum_from_bytes((const unsigned char *)p, length);
1698     p += length;
1699     point->y = bignum_from_bytes((const unsigned char *)p, length);
1700     point->z = NULL;
1701
1702     /* Verify the point is on the curve */
1703     if (!ec_point_verify(point)) {
1704         freebn(point->x);
1705         point->x = NULL;
1706         freebn(point->y);
1707         point->y = NULL;
1708         return 0;
1709     }
1710
1711     return 1;
1712 }
1713
1714 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1715 {
1716     const char *p;
1717     int length;
1718
1719     getstring(data, datalen, &p, &length);
1720     if (!p) return 0;
1721     return decodepoint(p, length, point);
1722 }
1723
1724 /* ----------------------------------------------------------------------
1725  * Exposed ECDSA interface
1726  */
1727
1728 struct ecsign_extra {
1729     struct ec_curve *(*curve)(void);
1730     const struct ssh_hash *hash;
1731
1732     /* These fields are used by the OpenSSH PEM format importer/exporter */
1733     const unsigned char *oid;
1734     int oidlen;
1735 };
1736
1737 static void ecdsa_freekey(void *key)
1738 {
1739     struct ec_key *ec = (struct ec_key *) key;
1740     if (!ec) return;
1741
1742     if (ec->publicKey.x)
1743         freebn(ec->publicKey.x);
1744     if (ec->publicKey.y)
1745         freebn(ec->publicKey.y);
1746     if (ec->publicKey.z)
1747         freebn(ec->publicKey.z);
1748     if (ec->privateKey)
1749         freebn(ec->privateKey);
1750     sfree(ec);
1751 }
1752
1753 static void *ecdsa_newkey(const struct ssh_signkey *self,
1754                           const char *data, int len)
1755 {
1756     const struct ecsign_extra *extra =
1757         (const struct ecsign_extra *)self->extra;
1758     const char *p;
1759     int slen;
1760     struct ec_key *ec;
1761     struct ec_curve *curve;
1762
1763     getstring(&data, &len, &p, &slen);
1764
1765     if (!p) {
1766         return NULL;
1767     }
1768     curve = extra->curve();
1769     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1770
1771     /* Curve name is duplicated for Weierstrass form */
1772     if (curve->type == EC_WEIERSTRASS) {
1773         getstring(&data, &len, &p, &slen);
1774         if (!p) return NULL;
1775         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1776     }
1777
1778     ec = snew(struct ec_key);
1779
1780     ec->signalg = self;
1781     ec->publicKey.curve = curve;
1782     ec->publicKey.infinity = 0;
1783     ec->publicKey.x = NULL;
1784     ec->publicKey.y = NULL;
1785     ec->publicKey.z = NULL;
1786     ec->privateKey = NULL;
1787     if (!getmppoint(&data, &len, &ec->publicKey)) {
1788         ecdsa_freekey(ec);
1789         return NULL;
1790     }
1791
1792     if (!ec->publicKey.x || !ec->publicKey.y ||
1793         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1794         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1795     {
1796         ecdsa_freekey(ec);
1797         ec = NULL;
1798     }
1799
1800     return ec;
1801 }
1802
1803 static char *ecdsa_fmtkey(void *key)
1804 {
1805     struct ec_key *ec = (struct ec_key *) key;
1806     char *p;
1807     int len, i, pos, nibbles;
1808     static const char hex[] = "0123456789abcdef";
1809     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1810         return NULL;
1811
1812     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1813     if (ec->publicKey.curve->name)
1814         len += strlen(ec->publicKey.curve->name); /* Curve name */
1815     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1816     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1817     p = snewn(len, char);
1818
1819     pos = 0;
1820     if (ec->publicKey.curve->name)
1821         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1822     pos += sprintf(p + pos, "0x");
1823     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1824     if (nibbles < 1)
1825         nibbles = 1;
1826     for (i = nibbles; i--;) {
1827         p[pos++] =
1828             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1829     }
1830     pos += sprintf(p + pos, ",0x");
1831     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1832     if (nibbles < 1)
1833         nibbles = 1;
1834     for (i = nibbles; i--;) {
1835         p[pos++] =
1836             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1837     }
1838     p[pos] = '\0';
1839     return p;
1840 }
1841
1842 static unsigned char *ecdsa_public_blob(void *key, int *len)
1843 {
1844     struct ec_key *ec = (struct ec_key *) key;
1845     int pointlen, bloblen, fullnamelen, namelen;
1846     int i;
1847     unsigned char *blob, *p;
1848
1849     fullnamelen = strlen(ec->signalg->name);
1850
1851     if (ec->publicKey.curve->type == EC_EDWARDS) {
1852         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1853
1854         pointlen = ec->publicKey.curve->fieldBits / 8;
1855
1856         /* Can't handle this in our loop */
1857         if (pointlen < 2) return NULL;
1858
1859         bloblen = 4 + fullnamelen + 4 + pointlen;
1860         blob = snewn(bloblen, unsigned char);
1861
1862         p = blob;
1863         PUT_32BIT(p, fullnamelen);
1864         p += 4;
1865         memcpy(p, ec->signalg->name, fullnamelen);
1866         p += fullnamelen;
1867         PUT_32BIT(p, pointlen);
1868         p += 4;
1869
1870         /* Unset last bit of y and set first bit of x in its place */
1871         for (i = 0; i < pointlen - 1; ++i) {
1872             *p++ = bignum_byte(ec->publicKey.y, i);
1873         }
1874         /* Unset last bit of y and set first bit of x in its place */
1875         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1876         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1877     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1878         assert(ec->publicKey.curve->name);
1879         namelen = strlen(ec->publicKey.curve->name);
1880
1881         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1882
1883         /*
1884          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1885          */
1886         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1887         blob = snewn(bloblen, unsigned char);
1888
1889         p = blob;
1890         PUT_32BIT(p, fullnamelen);
1891         p += 4;
1892         memcpy(p, ec->signalg->name, fullnamelen);
1893         p += fullnamelen;
1894         PUT_32BIT(p, namelen);
1895         p += 4;
1896         memcpy(p, ec->publicKey.curve->name, namelen);
1897         p += namelen;
1898         PUT_32BIT(p, (2 * pointlen) + 1);
1899         p += 4;
1900         *p++ = 0x04;
1901         for (i = pointlen; i--;) {
1902             *p++ = bignum_byte(ec->publicKey.x, i);
1903         }
1904         for (i = pointlen; i--;) {
1905             *p++ = bignum_byte(ec->publicKey.y, i);
1906         }
1907     } else {
1908         return NULL;
1909     }
1910
1911     assert(p == blob + bloblen);
1912     *len = bloblen;
1913
1914     return blob;
1915 }
1916
1917 static unsigned char *ecdsa_private_blob(void *key, int *len)
1918 {
1919     struct ec_key *ec = (struct ec_key *) key;
1920     int keylen, bloblen;
1921     int i;
1922     unsigned char *blob, *p;
1923
1924     if (!ec->privateKey) return NULL;
1925
1926     if (ec->publicKey.curve->type == EC_EDWARDS) {
1927         /* Unsigned */
1928         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1929     } else {
1930         /* Signed */
1931         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1932     }
1933
1934     /*
1935      * mpint privateKey. Total 4 + keylen.
1936      */
1937     bloblen = 4 + keylen;
1938     blob = snewn(bloblen, unsigned char);
1939
1940     p = blob;
1941     PUT_32BIT(p, keylen);
1942     p += 4;
1943     if (ec->publicKey.curve->type == EC_EDWARDS) {
1944         /* Little endian */
1945         for (i = 0; i < keylen; ++i)
1946             *p++ = bignum_byte(ec->privateKey, i);
1947     } else {
1948         for (i = keylen; i--;)
1949             *p++ = bignum_byte(ec->privateKey, i);
1950     }
1951
1952     assert(p == blob + bloblen);
1953     *len = bloblen;
1954     return blob;
1955 }
1956
1957 static void *ecdsa_createkey(const struct ssh_signkey *self,
1958                              const unsigned char *pub_blob, int pub_len,
1959                              const unsigned char *priv_blob, int priv_len)
1960 {
1961     struct ec_key *ec;
1962     struct ec_point *publicKey;
1963     const char *pb = (const char *) priv_blob;
1964
1965     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1966     if (!ec) {
1967         return NULL;
1968     }
1969
1970     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1971         && ec->publicKey.curve->type != EC_EDWARDS) {
1972         ecdsa_freekey(ec);
1973         return NULL;
1974     }
1975
1976     if (ec->publicKey.curve->type == EC_EDWARDS) {
1977         ec->privateKey = getmp_le(&pb, &priv_len);
1978     } else {
1979         ec->privateKey = getmp(&pb, &priv_len);
1980     }
1981     if (!ec->privateKey) {
1982         ecdsa_freekey(ec);
1983         return NULL;
1984     }
1985
1986     /* Check that private key generates public key */
1987     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1988
1989     if (!publicKey ||
1990         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1991         bignum_cmp(publicKey->y, ec->publicKey.y))
1992     {
1993         ecdsa_freekey(ec);
1994         ec = NULL;
1995     }
1996     ec_point_free(publicKey);
1997
1998     return ec;
1999 }
2000
2001 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2002                                        const unsigned char **blob, int *len)
2003 {
2004     struct ec_key *ec;
2005     struct ec_point *publicKey;
2006     const char *p, *q;
2007     int plen, qlen;
2008
2009     getstring((const char**)blob, len, &p, &plen);
2010     if (!p)
2011     {
2012         return NULL;
2013     }
2014
2015     ec = snew(struct ec_key);
2016
2017     ec->signalg = self;
2018     ec->publicKey.curve = ec_ed25519();
2019     ec->publicKey.infinity = 0;
2020     ec->privateKey = NULL;
2021     ec->publicKey.x = NULL;
2022     ec->publicKey.z = NULL;
2023     ec->publicKey.y = NULL;
2024
2025     if (!decodepoint_ed(p, plen, &ec->publicKey))
2026     {
2027         ecdsa_freekey(ec);
2028         return NULL;
2029     }
2030
2031     getstring((const char**)blob, len, &q, &qlen);
2032     if (!q || qlen != 64) {
2033         ecdsa_freekey(ec);
2034         return NULL;
2035     }
2036
2037     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2038
2039     /* Check that private key generates public key */
2040     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2041
2042     if (!publicKey ||
2043         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2044         bignum_cmp(publicKey->y, ec->publicKey.y))
2045     {
2046         ecdsa_freekey(ec);
2047         ec = NULL;
2048     }
2049     ec_point_free(publicKey);
2050
2051     /* The OpenSSH format for ed25519 private keys also for some
2052      * reason encodes an extra copy of the public key in the second
2053      * half of the secret-key string. Check that that's present and
2054      * correct as well, otherwise the key we think we've imported
2055      * won't behave identically to the way OpenSSH would have treated
2056      * it. */
2057     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2058         ecdsa_freekey(ec);
2059         return NULL;
2060     }
2061
2062     return ec;
2063 }
2064
2065 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2066 {
2067     struct ec_key *ec = (struct ec_key *) key;
2068
2069     int pointlen;
2070     int keylen;
2071     int bloblen;
2072     int i;
2073
2074     if (ec->publicKey.curve->type != EC_EDWARDS) {
2075         return 0;
2076     }
2077
2078     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2079     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2080     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2081
2082     if (bloblen > len)
2083         return bloblen;
2084
2085     /* Encode the public point */
2086     PUT_32BIT(blob, pointlen);
2087     blob += 4;
2088
2089     for (i = 0; i < pointlen - 1; ++i) {
2090          *blob++ = bignum_byte(ec->publicKey.y, i);
2091     }
2092     /* Unset last bit of y and set first bit of x in its place */
2093     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2094     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2095
2096     PUT_32BIT(blob, keylen + pointlen);
2097     blob += 4;
2098     for (i = 0; i < keylen; ++i) {
2099          *blob++ = bignum_byte(ec->privateKey, i);
2100     }
2101     /* Now encode an extra copy of the public point as the second half
2102      * of the private key string, as the OpenSSH format for some
2103      * reason requires */
2104     for (i = 0; i < pointlen - 1; ++i) {
2105          *blob++ = bignum_byte(ec->publicKey.y, i);
2106     }
2107     /* Unset last bit of y and set first bit of x in its place */
2108     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2109     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2110
2111     return bloblen;
2112 }
2113
2114 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2115                                      const unsigned char **blob, int *len)
2116 {
2117     const struct ecsign_extra *extra =
2118         (const struct ecsign_extra *)self->extra;
2119     const char **b = (const char **) blob;
2120     const char *p;
2121     int slen;
2122     struct ec_key *ec;
2123     struct ec_curve *curve;
2124     struct ec_point *publicKey;
2125
2126     getstring(b, len, &p, &slen);
2127
2128     if (!p) {
2129         return NULL;
2130     }
2131     curve = extra->curve();
2132     assert(curve->type == EC_WEIERSTRASS);
2133
2134     ec = snew(struct ec_key);
2135
2136     ec->signalg = self;
2137     ec->publicKey.curve = curve;
2138     ec->publicKey.infinity = 0;
2139     ec->publicKey.x = NULL;
2140     ec->publicKey.y = NULL;
2141     ec->publicKey.z = NULL;
2142     if (!getmppoint(b, len, &ec->publicKey)) {
2143         ecdsa_freekey(ec);
2144         return NULL;
2145     }
2146     ec->privateKey = NULL;
2147
2148     if (!ec->publicKey.x || !ec->publicKey.y ||
2149         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2150         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2151     {
2152         ecdsa_freekey(ec);
2153         return NULL;
2154     }
2155
2156     ec->privateKey = getmp(b, len);
2157     if (ec->privateKey == NULL)
2158     {
2159         ecdsa_freekey(ec);
2160         return NULL;
2161     }
2162
2163     /* Now check that the private key makes the public key */
2164     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2165     if (!publicKey)
2166     {
2167         ecdsa_freekey(ec);
2168         return NULL;
2169     }
2170
2171     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2172         bignum_cmp(ec->publicKey.y, publicKey->y))
2173     {
2174         /* Private key doesn't make the public key on the given curve */
2175         ecdsa_freekey(ec);
2176         ec_point_free(publicKey);
2177         return NULL;
2178     }
2179
2180     ec_point_free(publicKey);
2181
2182     return ec;
2183 }
2184
2185 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2186 {
2187     struct ec_key *ec = (struct ec_key *) key;
2188
2189     int pointlen;
2190     int namelen;
2191     int bloblen;
2192     int i;
2193
2194     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2195         return 0;
2196     }
2197
2198     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2199     namelen = strlen(ec->publicKey.curve->name);
2200     bloblen =
2201         4 + namelen /* <LEN> nistpXXX */
2202         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2203         + ssh2_bignum_length(ec->privateKey);
2204
2205     if (bloblen > len)
2206         return bloblen;
2207
2208     bloblen = 0;
2209
2210     PUT_32BIT(blob+bloblen, namelen);
2211     bloblen += 4;
2212     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2213     bloblen += namelen;
2214
2215     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2216     bloblen += 4;
2217     blob[bloblen++] = 0x04;
2218     for (i = pointlen; i--; )
2219         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2220     for (i = pointlen; i--; )
2221         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2222
2223     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2224     PUT_32BIT(blob+bloblen, pointlen);
2225     bloblen += 4;
2226     for (i = pointlen; i--; )
2227         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2228
2229     return bloblen;
2230 }
2231
2232 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2233                              const void *blob, int len)
2234 {
2235     struct ec_key *ec;
2236     int ret;
2237
2238     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2239     if (!ec)
2240         return -1;
2241     ret = ec->publicKey.curve->fieldBits;
2242     ecdsa_freekey(ec);
2243
2244     return ret;
2245 }
2246
2247 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2248                            const char *data, int datalen)
2249 {
2250     struct ec_key *ec = (struct ec_key *) key;
2251     const struct ecsign_extra *extra =
2252         (const struct ecsign_extra *)ec->signalg->extra;
2253     const char *p;
2254     int slen;
2255     int digestLen;
2256     int ret;
2257
2258     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2259         return 0;
2260
2261     /* Check the signature starts with the algorithm name */
2262     getstring(&sig, &siglen, &p, &slen);
2263     if (!p) {
2264         return 0;
2265     }
2266     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2267         return 0;
2268     }
2269
2270     getstring(&sig, &siglen, &p, &slen);
2271     if (!p) return 0;
2272     if (ec->publicKey.curve->type == EC_EDWARDS) {
2273         struct ec_point *r;
2274         Bignum s, h;
2275
2276         /* Check that the signature is two times the length of a point */
2277         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2278             return 0;
2279         }
2280
2281         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2282         if (ec->publicKey.curve->fieldBits != 256) {
2283             return 0;
2284         }
2285
2286         /* Get the signature */
2287         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2288         if (!r) {
2289             return 0;
2290         }
2291         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2292             ec_point_free(r);
2293             return 0;
2294         }
2295         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2296                                  ec->publicKey.curve->fieldBits / 8);
2297
2298         /* Get the hash of the encoded value of R + encoded value of pk + message */
2299         {
2300             int i, pointlen;
2301             unsigned char b;
2302             unsigned char digest[512 / 8];
2303             SHA512_State hs;
2304             SHA512_Init(&hs);
2305
2306             /* Add encoded r (no need to encode it again, it was in the signature) */
2307             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2308
2309             /* Encode pk and add it */
2310             pointlen = ec->publicKey.curve->fieldBits / 8;
2311             for (i = 0; i < pointlen - 1; ++i) {
2312                 b = bignum_byte(ec->publicKey.y, i);
2313                 SHA512_Bytes(&hs, &b, 1);
2314             }
2315             /* Unset last bit of y and set first bit of x in its place */
2316             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2317             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2318             SHA512_Bytes(&hs, &b, 1);
2319
2320             /* Add the message itself */
2321             SHA512_Bytes(&hs, data, datalen);
2322
2323             /* Get the hash */
2324             SHA512_Final(&hs, digest);
2325
2326             /* Convert to Bignum */
2327             h = bignum_from_bytes_le(digest, sizeof(digest));
2328         }
2329
2330         /* Verify sB == r + h*publicKey */
2331         {
2332             struct ec_point *lhs, *rhs, *tmp;
2333
2334             /* lhs = sB */
2335             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2336             freebn(s);
2337             if (!lhs) {
2338                 ec_point_free(r);
2339                 freebn(h);
2340                 return 0;
2341             }
2342
2343             /* rhs = r + h*publicKey */
2344             tmp = ecp_mul(&ec->publicKey, h);
2345             freebn(h);
2346             if (!tmp) {
2347                 ec_point_free(lhs);
2348                 ec_point_free(r);
2349                 return 0;
2350             }
2351             rhs = ecp_add(r, tmp, 0);
2352             ec_point_free(r);
2353             ec_point_free(tmp);
2354             if (!rhs) {
2355                 ec_point_free(lhs);
2356                 return 0;
2357             }
2358
2359             /* Check the point is the same */
2360             ret = !bignum_cmp(lhs->x, rhs->x);
2361             if (ret) {
2362                 ret = !bignum_cmp(lhs->y, rhs->y);
2363                 if (ret) {
2364                     ret = 1;
2365                 }
2366             }
2367             ec_point_free(lhs);
2368             ec_point_free(rhs);
2369         }
2370     } else {
2371         Bignum r, s;
2372         unsigned char digest[512 / 8];
2373         void *hashctx;
2374
2375         r = getmp(&p, &slen);
2376         if (!r) return 0;
2377         s = getmp(&p, &slen);
2378         if (!s) {
2379             freebn(r);
2380             return 0;
2381         }
2382
2383         digestLen = extra->hash->hlen;
2384         assert(digestLen <= sizeof(digest));
2385         hashctx = extra->hash->init();
2386         extra->hash->bytes(hashctx, data, datalen);
2387         extra->hash->final(hashctx, digest);
2388
2389         /* Verify the signature */
2390         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2391
2392         freebn(r);
2393         freebn(s);
2394     }
2395
2396     return ret;
2397 }
2398
2399 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2400                                  int *siglen)
2401 {
2402     struct ec_key *ec = (struct ec_key *) key;
2403     const struct ecsign_extra *extra =
2404         (const struct ecsign_extra *)ec->signalg->extra;
2405     unsigned char digest[512 / 8];
2406     int digestLen;
2407     Bignum r = NULL, s = NULL;
2408     unsigned char *buf, *p;
2409     int rlen, slen, namelen;
2410     int i;
2411
2412     if (!ec->privateKey || !ec->publicKey.curve) {
2413         return NULL;
2414     }
2415
2416     if (ec->publicKey.curve->type == EC_EDWARDS) {
2417         struct ec_point *rp;
2418         int pointlen = ec->publicKey.curve->fieldBits / 8;
2419
2420         /* hash = H(sk) (where hash creates 2 * fieldBits)
2421          * b = fieldBits
2422          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2423          * r = H(h[b/8:b/4] + m)
2424          * R = rB
2425          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2426         {
2427             unsigned char hash[512/8];
2428             unsigned char b;
2429             Bignum a;
2430             SHA512_State hs;
2431             SHA512_Init(&hs);
2432
2433             for (i = 0; i < pointlen; ++i) {
2434                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2435                 SHA512_Bytes(&hs, &b, 1);
2436             }
2437
2438             SHA512_Final(&hs, hash);
2439
2440             /* The second part is simply turning the hash into a
2441              * Bignum, however the 2^(b-2) bit *must* be set, and the
2442              * bottom 3 bits *must* not be */
2443             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2444             hash[31] &= 0x7f; /* Unset above (b-2) */
2445             hash[31] |= 0x40; /* Set 2^(b-2) */
2446             /* Chop off the top part and convert to int */
2447             a = bignum_from_bytes_le(hash, 32);
2448
2449             SHA512_Init(&hs);
2450             SHA512_Bytes(&hs,
2451                          hash+(ec->publicKey.curve->fieldBits / 8),
2452                          (ec->publicKey.curve->fieldBits / 4)
2453                          - (ec->publicKey.curve->fieldBits / 8));
2454             SHA512_Bytes(&hs, data, datalen);
2455             SHA512_Final(&hs, hash);
2456
2457             r = bignum_from_bytes_le(hash, 512/8);
2458             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2459             if (!rp) {
2460                 freebn(r);
2461                 freebn(a);
2462                 return NULL;
2463             }
2464
2465             /* Now calculate s */
2466             SHA512_Init(&hs);
2467             /* Encode the point R */
2468             for (i = 0; i < pointlen - 1; ++i) {
2469                 b = bignum_byte(rp->y, i);
2470                 SHA512_Bytes(&hs, &b, 1);
2471             }
2472             /* Unset last bit of y and set first bit of x in its place */
2473             b = bignum_byte(rp->y, i) & 0x7f;
2474             b |= bignum_bit(rp->x, 0) << 7;
2475             SHA512_Bytes(&hs, &b, 1);
2476
2477             /* Encode the point pk */
2478             for (i = 0; i < pointlen - 1; ++i) {
2479                 b = bignum_byte(ec->publicKey.y, i);
2480                 SHA512_Bytes(&hs, &b, 1);
2481             }
2482             /* Unset last bit of y and set first bit of x in its place */
2483             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2484             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2485             SHA512_Bytes(&hs, &b, 1);
2486
2487             /* Add the message */
2488             SHA512_Bytes(&hs, data, datalen);
2489             SHA512_Final(&hs, hash);
2490
2491             {
2492                 Bignum tmp, tmp2;
2493
2494                 tmp = bignum_from_bytes_le(hash, 512/8);
2495                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2496                 freebn(a);
2497                 freebn(tmp);
2498                 tmp = bigadd(r, tmp2);
2499                 freebn(r);
2500                 freebn(tmp2);
2501                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2502                 freebn(tmp);
2503             }
2504         }
2505
2506         /* Format the output */
2507         namelen = strlen(ec->signalg->name);
2508         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2509         buf = snewn(*siglen, unsigned char);
2510         p = buf;
2511         PUT_32BIT(p, namelen);
2512         p += 4;
2513         memcpy(p, ec->signalg->name, namelen);
2514         p += namelen;
2515         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2516         p += 4;
2517
2518         /* Encode the point */
2519         pointlen = ec->publicKey.curve->fieldBits / 8;
2520         for (i = 0; i < pointlen - 1; ++i) {
2521             *p++ = bignum_byte(rp->y, i);
2522         }
2523         /* Unset last bit of y and set first bit of x in its place */
2524         *p = bignum_byte(rp->y, i) & 0x7f;
2525         *p++ |= bignum_bit(rp->x, 0) << 7;
2526         ec_point_free(rp);
2527
2528         /* Encode the int */
2529         for (i = 0; i < pointlen; ++i) {
2530             *p++ = bignum_byte(s, i);
2531         }
2532         freebn(s);
2533     } else {
2534         void *hashctx;
2535
2536         digestLen = extra->hash->hlen;
2537         assert(digestLen <= sizeof(digest));
2538         hashctx = extra->hash->init();
2539         extra->hash->bytes(hashctx, data, datalen);
2540         extra->hash->final(hashctx, digest);
2541
2542         /* Do the signature */
2543         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2544         if (!r || !s) {
2545             if (r) freebn(r);
2546             if (s) freebn(s);
2547             return NULL;
2548         }
2549
2550         rlen = (bignum_bitcount(r) + 8) / 8;
2551         slen = (bignum_bitcount(s) + 8) / 8;
2552
2553         namelen = strlen(ec->signalg->name);
2554
2555         /* Format the output */
2556         *siglen = 8+namelen+rlen+slen+8;
2557         buf = snewn(*siglen, unsigned char);
2558         p = buf;
2559         PUT_32BIT(p, namelen);
2560         p += 4;
2561         memcpy(p, ec->signalg->name, namelen);
2562         p += namelen;
2563         PUT_32BIT(p, rlen + slen + 8);
2564         p += 4;
2565         PUT_32BIT(p, rlen);
2566         p += 4;
2567         for (i = rlen; i--;)
2568             *p++ = bignum_byte(r, i);
2569         PUT_32BIT(p, slen);
2570         p += 4;
2571         for (i = slen; i--;)
2572             *p++ = bignum_byte(s, i);
2573
2574         freebn(r);
2575         freebn(s);
2576     }
2577
2578     return buf;
2579 }
2580
2581 const struct ecsign_extra sign_extra_ed25519 = {
2582     ec_ed25519, NULL,
2583     NULL, 0,
2584 };
2585 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2586     ecdsa_newkey,
2587     ecdsa_freekey,
2588     ecdsa_fmtkey,
2589     ecdsa_public_blob,
2590     ecdsa_private_blob,
2591     ecdsa_createkey,
2592     ed25519_openssh_createkey,
2593     ed25519_openssh_fmtkey,
2594     2 /* point, private exponent */,
2595     ecdsa_pubkey_bits,
2596     ecdsa_verifysig,
2597     ecdsa_sign,
2598     "ssh-ed25519",
2599     "ssh-ed25519",
2600     &sign_extra_ed25519,
2601 };
2602
2603 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2604 static const unsigned char nistp256_oid[] = {
2605     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2606 };
2607 const struct ecsign_extra sign_extra_nistp256 = {
2608     ec_p256, &ssh_sha256,
2609     nistp256_oid, lenof(nistp256_oid),
2610 };
2611 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2612     ecdsa_newkey,
2613     ecdsa_freekey,
2614     ecdsa_fmtkey,
2615     ecdsa_public_blob,
2616     ecdsa_private_blob,
2617     ecdsa_createkey,
2618     ecdsa_openssh_createkey,
2619     ecdsa_openssh_fmtkey,
2620     3 /* curve name, point, private exponent */,
2621     ecdsa_pubkey_bits,
2622     ecdsa_verifysig,
2623     ecdsa_sign,
2624     "ecdsa-sha2-nistp256",
2625     "ecdsa-sha2-nistp256",
2626     &sign_extra_nistp256,
2627 };
2628
2629 /* OID: 1.3.132.0.34 (secp384r1) */
2630 static const unsigned char nistp384_oid[] = {
2631     0x2b, 0x81, 0x04, 0x00, 0x22
2632 };
2633 const struct ecsign_extra sign_extra_nistp384 = {
2634     ec_p384, &ssh_sha384,
2635     nistp384_oid, lenof(nistp384_oid),
2636 };
2637 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2638     ecdsa_newkey,
2639     ecdsa_freekey,
2640     ecdsa_fmtkey,
2641     ecdsa_public_blob,
2642     ecdsa_private_blob,
2643     ecdsa_createkey,
2644     ecdsa_openssh_createkey,
2645     ecdsa_openssh_fmtkey,
2646     3 /* curve name, point, private exponent */,
2647     ecdsa_pubkey_bits,
2648     ecdsa_verifysig,
2649     ecdsa_sign,
2650     "ecdsa-sha2-nistp384",
2651     "ecdsa-sha2-nistp384",
2652     &sign_extra_nistp384,
2653 };
2654
2655 /* OID: 1.3.132.0.35 (secp521r1) */
2656 static const unsigned char nistp521_oid[] = {
2657     0x2b, 0x81, 0x04, 0x00, 0x23
2658 };
2659 const struct ecsign_extra sign_extra_nistp521 = {
2660     ec_p521, &ssh_sha512,
2661     nistp521_oid, lenof(nistp521_oid),
2662 };
2663 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2664     ecdsa_newkey,
2665     ecdsa_freekey,
2666     ecdsa_fmtkey,
2667     ecdsa_public_blob,
2668     ecdsa_private_blob,
2669     ecdsa_createkey,
2670     ecdsa_openssh_createkey,
2671     ecdsa_openssh_fmtkey,
2672     3 /* curve name, point, private exponent */,
2673     ecdsa_pubkey_bits,
2674     ecdsa_verifysig,
2675     ecdsa_sign,
2676     "ecdsa-sha2-nistp521",
2677     "ecdsa-sha2-nistp521",
2678     &sign_extra_nistp521,
2679 };
2680
2681 /* ----------------------------------------------------------------------
2682  * Exposed ECDH interface
2683  */
2684
2685 struct eckex_extra {
2686     struct ec_curve *(*curve)(void);
2687 };
2688
2689 static Bignum ecdh_calculate(const Bignum private,
2690                              const struct ec_point *public)
2691 {
2692     struct ec_point *p;
2693     Bignum ret;
2694     p = ecp_mul(public, private);
2695     if (!p) return NULL;
2696     ret = p->x;
2697     p->x = NULL;
2698
2699     if (p->curve->type == EC_MONTGOMERY) {
2700         /*
2701          * Endianness-swap. The Curve25519 algorithm definition
2702          * assumes you were doing your computation in arrays of 32
2703          * little-endian bytes, and now specifies that you take your
2704          * final one of those and convert it into a bignum in
2705          * _network_ byte order, i.e. big-endian.
2706          *
2707          * In particular, the spec says, you convert the _whole_ 32
2708          * bytes into a bignum. That is, on the rare occasions that
2709          * p->x has come out with the most significant 8 bits zero, we
2710          * have to imagine that being represented by a 32-byte string
2711          * with the last byte being zero, so that has to be converted
2712          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2713          * multiple of 256.
2714          */
2715         int i;
2716         int bytes = (p->curve->fieldBits+7) / 8;
2717         unsigned char *byteorder = snewn(bytes, unsigned char);
2718         for (i = 0; i < bytes; ++i) {
2719             byteorder[i] = bignum_byte(ret, i);
2720         }
2721         freebn(ret);
2722         ret = bignum_from_bytes(byteorder, bytes);
2723         smemclr(byteorder, bytes);
2724         sfree(byteorder);
2725     }
2726
2727     ec_point_free(p);
2728     return ret;
2729 }
2730
2731 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2732 {
2733     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2734     struct ec_curve *curve = extra->curve();
2735     return curve->textname;
2736 }
2737
2738 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2739 {
2740     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2741     struct ec_curve *curve;
2742     struct ec_key *key;
2743     struct ec_point *publicKey;
2744
2745     curve = extra->curve();
2746
2747     key = snew(struct ec_key);
2748
2749     key->signalg = NULL;
2750     key->publicKey.curve = curve;
2751
2752     if (curve->type == EC_MONTGOMERY) {
2753         unsigned char bytes[32] = {0};
2754         int i;
2755
2756         for (i = 0; i < sizeof(bytes); ++i)
2757         {
2758             bytes[i] = (unsigned char)random_byte();
2759         }
2760         bytes[0] &= 248;
2761         bytes[31] &= 127;
2762         bytes[31] |= 64;
2763         key->privateKey = bignum_from_bytes_le(bytes, sizeof(bytes));
2764         smemclr(bytes, sizeof(bytes));
2765         if (!key->privateKey) {
2766             sfree(key);
2767             return NULL;
2768         }
2769         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2770         if (!publicKey) {
2771             freebn(key->privateKey);
2772             sfree(key);
2773             return NULL;
2774         }
2775         key->publicKey.x = publicKey->x;
2776         key->publicKey.y = publicKey->y;
2777         key->publicKey.z = NULL;
2778         sfree(publicKey);
2779     } else {
2780         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2781         if (!key->privateKey) {
2782             sfree(key);
2783             return NULL;
2784         }
2785         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2786         if (!publicKey) {
2787             freebn(key->privateKey);
2788             sfree(key);
2789             return NULL;
2790         }
2791         key->publicKey.x = publicKey->x;
2792         key->publicKey.y = publicKey->y;
2793         key->publicKey.z = NULL;
2794         sfree(publicKey);
2795     }
2796     return key;
2797 }
2798
2799 char *ssh_ecdhkex_getpublic(void *key, int *len)
2800 {
2801     struct ec_key *ec = (struct ec_key*)key;
2802     char *point, *p;
2803     int i;
2804     int pointlen;
2805
2806     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2807
2808     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2809         *len = 1 + pointlen * 2;
2810     } else {
2811         *len = pointlen;
2812     }
2813     point = (char*)snewn(*len, char);
2814
2815     p = point;
2816     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2817         *p++ = 0x04;
2818         for (i = pointlen; i--;) {
2819             *p++ = bignum_byte(ec->publicKey.x, i);
2820         }
2821         for (i = pointlen; i--;) {
2822             *p++ = bignum_byte(ec->publicKey.y, i);
2823         }
2824     } else {
2825         for (i = 0; i < pointlen; ++i) {
2826             *p++ = bignum_byte(ec->publicKey.x, i);
2827         }
2828     }
2829
2830     return point;
2831 }
2832
2833 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2834 {
2835     struct ec_key *ec = (struct ec_key*) key;
2836     struct ec_point remote;
2837     Bignum ret;
2838
2839     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2840         remote.curve = ec->publicKey.curve;
2841         remote.infinity = 0;
2842         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2843             return NULL;
2844         }
2845     } else {
2846         /* Point length has to be the same length */
2847         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2848             return NULL;
2849         }
2850
2851         remote.curve = ec->publicKey.curve;
2852         remote.infinity = 0;
2853         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2854         remote.y = NULL;
2855         remote.z = NULL;
2856     }
2857
2858     ret = ecdh_calculate(ec->privateKey, &remote);
2859     if (remote.x) freebn(remote.x);
2860     if (remote.y) freebn(remote.y);
2861     return ret;
2862 }
2863
2864 void ssh_ecdhkex_freekey(void *key)
2865 {
2866     ecdsa_freekey(key);
2867 }
2868
2869 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2870 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2871     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2872     &ssh_sha256, &kex_extra_curve25519,
2873 };
2874
2875 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2876 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2877     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2878     &ssh_sha256, &kex_extra_nistp256,
2879 };
2880
2881 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2882 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2883     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2884     &ssh_sha384, &kex_extra_nistp384,
2885 };
2886
2887 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2888 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2889     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2890     &ssh_sha512, &kex_extra_nistp521,
2891 };
2892
2893 static const struct ssh_kex *const ec_kex_list[] = {
2894     &ssh_ec_kex_curve25519,
2895     &ssh_ec_kex_nistp256,
2896     &ssh_ec_kex_nistp384,
2897     &ssh_ec_kex_nistp521,
2898 };
2899
2900 const struct ssh_kexes ssh_ecdh_kex = {
2901     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2902     ec_kex_list
2903 };
2904
2905 /* ----------------------------------------------------------------------
2906  * Helper functions for finding key algorithms and returning auxiliary
2907  * data.
2908  */
2909
2910 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2911                                         const struct ec_curve **curve)
2912 {
2913     static const struct ssh_signkey *algs_with_oid[] = {
2914         &ssh_ecdsa_nistp256,
2915         &ssh_ecdsa_nistp384,
2916         &ssh_ecdsa_nistp521,
2917     };
2918     int i;
2919
2920     for (i = 0; i < lenof(algs_with_oid); i++) {
2921         const struct ssh_signkey *alg = algs_with_oid[i];
2922         const struct ecsign_extra *extra =
2923             (const struct ecsign_extra *)alg->extra;
2924         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2925             *curve = extra->curve();
2926             return alg;
2927         }
2928     }
2929     return NULL;
2930 }
2931
2932 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2933                                 int *oidlen)
2934 {
2935     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2936     *oidlen = extra->oidlen;
2937     return extra->oid;
2938 }
2939
2940 const int ec_nist_curve_lengths[] = { 256, 384, 521 };
2941 const int n_ec_nist_curve_lengths = lenof(ec_nist_curve_lengths);
2942
2943 int ec_nist_alg_and_curve_by_bits(int bits,
2944                                   const struct ec_curve **curve,
2945                                   const struct ssh_signkey **alg)
2946 {
2947     switch (bits) {
2948       case 256: *alg = &ssh_ecdsa_nistp256; break;
2949       case 384: *alg = &ssh_ecdsa_nistp384; break;
2950       case 521: *alg = &ssh_ecdsa_nistp521; break;
2951       default: return FALSE;
2952     }
2953     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2954     return TRUE;
2955 }
2956
2957 int ec_ed_alg_and_curve_by_bits(int bits,
2958                                 const struct ec_curve **curve,
2959                                 const struct ssh_signkey **alg)
2960 {
2961     switch (bits) {
2962       case 256: *alg = &ssh_ecdsa_ed25519; break;
2963       default: return FALSE;
2964     }
2965     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2966     return TRUE;
2967 }