]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Add a reference to a spec for Curve25519.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Curve25519 spec from libssh (with reference to other things in the
28  * libssh code):
29  *   https://git.libssh.org/users/aris/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
30  *
31  * Edwards DSA:
32  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
33  */
34
35 #include <stdlib.h>
36 #include <assert.h>
37
38 #include "ssh.h"
39
40 /* ----------------------------------------------------------------------
41  * Elliptic curve definitions
42  */
43
44 static void initialise_wcurve(struct ec_curve *curve, int bits,
45                               const unsigned char *p,
46                               const unsigned char *a, const unsigned char *b,
47                               const unsigned char *n, const unsigned char *Gx,
48                               const unsigned char *Gy)
49 {
50     int length = bits / 8;
51     if (bits % 8) ++length;
52
53     curve->type = EC_WEIERSTRASS;
54
55     curve->fieldBits = bits;
56     curve->p = bignum_from_bytes(p, length);
57
58     /* Curve co-efficients */
59     curve->w.a = bignum_from_bytes(a, length);
60     curve->w.b = bignum_from_bytes(b, length);
61
62     /* Group order and generator */
63     curve->w.n = bignum_from_bytes(n, length);
64     curve->w.G.x = bignum_from_bytes(Gx, length);
65     curve->w.G.y = bignum_from_bytes(Gy, length);
66     curve->w.G.curve = curve;
67     curve->w.G.infinity = 0;
68 }
69
70 static void initialise_mcurve(struct ec_curve *curve, int bits,
71                               const unsigned char *p,
72                               const unsigned char *a, const unsigned char *b,
73                               const unsigned char *Gx)
74 {
75     int length = bits / 8;
76     if (bits % 8) ++length;
77
78     curve->type = EC_MONTGOMERY;
79
80     curve->fieldBits = bits;
81     curve->p = bignum_from_bytes(p, length);
82
83     /* Curve co-efficients */
84     curve->m.a = bignum_from_bytes(a, length);
85     curve->m.b = bignum_from_bytes(b, length);
86
87     /* Generator */
88     curve->m.G.x = bignum_from_bytes(Gx, length);
89     curve->m.G.y = NULL;
90     curve->m.G.z = NULL;
91     curve->m.G.curve = curve;
92     curve->m.G.infinity = 0;
93 }
94
95 static void initialise_ecurve(struct ec_curve *curve, int bits,
96                               const unsigned char *p,
97                               const unsigned char *l, const unsigned char *d,
98                               const unsigned char *Bx, const unsigned char *By)
99 {
100     int length = bits / 8;
101     if (bits % 8) ++length;
102
103     curve->type = EC_EDWARDS;
104
105     curve->fieldBits = bits;
106     curve->p = bignum_from_bytes(p, length);
107
108     /* Curve co-efficients */
109     curve->e.l = bignum_from_bytes(l, length);
110     curve->e.d = bignum_from_bytes(d, length);
111
112     /* Group order and generator */
113     curve->e.B.x = bignum_from_bytes(Bx, length);
114     curve->e.B.y = bignum_from_bytes(By, length);
115     curve->e.B.curve = curve;
116     curve->e.B.infinity = 0;
117 }
118
119 static struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         static const unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         static const unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         static const unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         static const unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         static const unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         static const unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
164         curve.textname = curve.name = "nistp256";
165
166         /* Now initialised, no need to do it again */
167         initialised = 1;
168     }
169
170     return &curve;
171 }
172
173 static struct ec_curve *ec_p384(void)
174 {
175     static struct ec_curve curve = { 0 };
176     static unsigned char initialised = 0;
177
178     if (!initialised)
179     {
180         static const unsigned char p[] = {
181             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
185             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
186             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
187         };
188         static const unsigned char a[] = {
189             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
193             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
194             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
195         };
196         static const unsigned char b[] = {
197             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
198             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
199             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
200             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
201             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
202             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
203         };
204         static const unsigned char n[] = {
205             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
209             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
210             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
211         };
212         static const unsigned char Gx[] = {
213             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
214             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
215             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
216             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
217             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
218             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
219         };
220         static const unsigned char Gy[] = {
221             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
222             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
223             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
224             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
225             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
226             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
227         };
228
229         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
230         curve.textname = curve.name = "nistp384";
231
232         /* Now initialised, no need to do it again */
233         initialised = 1;
234     }
235
236     return &curve;
237 }
238
239 static struct ec_curve *ec_p521(void)
240 {
241     static struct ec_curve curve = { 0 };
242     static unsigned char initialised = 0;
243
244     if (!initialised)
245     {
246         static const unsigned char p[] = {
247             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff
256         };
257         static const unsigned char a[] = {
258             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xfc
267         };
268         static const unsigned char b[] = {
269             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
270             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
271             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
272             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
273             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
274             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
275             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
276             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
277             0x3f, 0x00
278         };
279         static const unsigned char n[] = {
280             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
281             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
282             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
285             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
286             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
287             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
288             0x64, 0x09
289         };
290         static const unsigned char Gx[] = {
291             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
292             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
293             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
294             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
295             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
296             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
297             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
298             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
299             0xbd, 0x66
300         };
301         static const unsigned char Gy[] = {
302             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
303             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
304             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
305             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
306             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
307             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
308             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
309             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
310             0x66, 0x50
311         };
312
313         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
314         curve.textname = curve.name = "nistp521";
315
316         /* Now initialised, no need to do it again */
317         initialised = 1;
318     }
319
320     return &curve;
321 }
322
323 static struct ec_curve *ec_curve25519(void)
324 {
325     static struct ec_curve curve = { 0 };
326     static unsigned char initialised = 0;
327
328     if (!initialised)
329     {
330         static const unsigned char p[] = {
331             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
332             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
333             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
334             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
335         };
336         static const unsigned char a[] = {
337             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
338             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
341         };
342         static const unsigned char b[] = {
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
344             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
347         };
348         static const unsigned char gx[32] = {
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
350             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
351             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
353         };
354
355         initialise_mcurve(&curve, 256, p, a, b, gx);
356         /* This curve doesn't need a name, because it's never used in
357          * any format that embeds the curve name */
358         curve.name = NULL;
359         curve.textname = "Curve25519";
360
361         /* Now initialised, no need to do it again */
362         initialised = 1;
363     }
364
365     return &curve;
366 }
367
368 static struct ec_curve *ec_ed25519(void)
369 {
370     static struct ec_curve curve = { 0 };
371     static unsigned char initialised = 0;
372
373     if (!initialised)
374     {
375         static const unsigned char q[] = {
376             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
377             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
378             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
379             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
380         };
381         static const unsigned char l[32] = {
382             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
383             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
384             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
385             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
386         };
387         static const unsigned char d[32] = {
388             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
389             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
390             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
391             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
392         };
393         static const unsigned char Bx[32] = {
394             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
395             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
396             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
397             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
398         };
399         static const unsigned char By[32] = {
400             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
401             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
402             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
403             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
404         };
405
406         /* This curve doesn't need a name, because it's never used in
407          * any format that embeds the curve name */
408         curve.name = NULL;
409
410         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
411         curve.textname = "Ed25519";
412
413         /* Now initialised, no need to do it again */
414         initialised = 1;
415     }
416
417     return &curve;
418 }
419
420 /* Return 1 if a is -3 % p, otherwise return 0
421  * This is used because there are some maths optimisations */
422 static int ec_aminus3(const struct ec_curve *curve)
423 {
424     int ret;
425     Bignum _p;
426
427     if (curve->type != EC_WEIERSTRASS) {
428         return 0;
429     }
430
431     _p = bignum_add_long(curve->w.a, 3);
432
433     ret = !bignum_cmp(curve->p, _p);
434     freebn(_p);
435     return ret;
436 }
437
438 /* ----------------------------------------------------------------------
439  * Elliptic curve field maths
440  */
441
442 static Bignum ecf_add(const Bignum a, const Bignum b,
443                       const struct ec_curve *curve)
444 {
445     Bignum a1, b1, ab, ret;
446
447     a1 = bigmod(a, curve->p);
448     b1 = bigmod(b, curve->p);
449
450     ab = bigadd(a1, b1);
451     freebn(a1);
452     freebn(b1);
453
454     ret = bigmod(ab, curve->p);
455     freebn(ab);
456
457     return ret;
458 }
459
460 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
461 {
462     return modmul(a, a, curve->p);
463 }
464
465 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
466 {
467     Bignum ret, tmp;
468
469     /* Double */
470     tmp = bignum_lshift(a, 1);
471
472     /* Add itself (i.e. treble) */
473     ret = bigadd(tmp, a);
474     freebn(tmp);
475
476     /* Normalise */
477     while (bignum_cmp(ret, curve->p) >= 0)
478     {
479         tmp = bigsub(ret, curve->p);
480         assert(tmp);
481         freebn(ret);
482         ret = tmp;
483     }
484
485     return ret;
486 }
487
488 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
489 {
490     Bignum ret = bignum_lshift(a, 1);
491     if (bignum_cmp(ret, curve->p) >= 0)
492     {
493         Bignum tmp = bigsub(ret, curve->p);
494         assert(tmp);
495         freebn(ret);
496         return tmp;
497     }
498     else
499     {
500         return ret;
501     }
502 }
503
504 /* ----------------------------------------------------------------------
505  * Memory functions
506  */
507
508 void ec_point_free(struct ec_point *point)
509 {
510     if (point == NULL) return;
511     point->curve = 0;
512     if (point->x) freebn(point->x);
513     if (point->y) freebn(point->y);
514     if (point->z) freebn(point->z);
515     point->infinity = 0;
516     sfree(point);
517 }
518
519 static struct ec_point *ec_point_new(const struct ec_curve *curve,
520                                      const Bignum x, const Bignum y, const Bignum z,
521                                      unsigned char infinity)
522 {
523     struct ec_point *point = snewn(1, struct ec_point);
524     point->curve = curve;
525     point->x = x;
526     point->y = y;
527     point->z = z;
528     point->infinity = infinity ? 1 : 0;
529     return point;
530 }
531
532 static struct ec_point *ec_point_copy(const struct ec_point *a)
533 {
534     if (a == NULL) return NULL;
535     return ec_point_new(a->curve,
536                         a->x ? copybn(a->x) : NULL,
537                         a->y ? copybn(a->y) : NULL,
538                         a->z ? copybn(a->z) : NULL,
539                         a->infinity);
540 }
541
542 static int ec_point_verify(const struct ec_point *a)
543 {
544     if (a->infinity) {
545         return 1;
546     } else if (a->curve->type == EC_EDWARDS) {
547         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
548         Bignum y2, x2, tmp, tmp2, tmp3;
549         int ret;
550
551         y2 = ecf_square(a->y, a->curve);
552         x2 = ecf_square(a->x, a->curve);
553         tmp = modmul(a->curve->e.d, x2, a->curve->p);
554         tmp2 = modmul(tmp, y2, a->curve->p);
555         freebn(tmp);
556         tmp = modsub(y2, x2, a->curve->p);
557         freebn(y2);
558         freebn(x2);
559         tmp3 = modsub(tmp, tmp2, a->curve->p);
560         freebn(tmp);
561         freebn(tmp2);
562         ret = !bignum_cmp(tmp3, One);
563         freebn(tmp3);
564         return ret;
565     } else if (a->curve->type == EC_WEIERSTRASS) {
566         /* Verify y^2 = x^3 + ax + b */
567         int ret = 0;
568
569         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
570
571         Bignum Three = bignum_from_long(3);
572
573         lhs = modmul(a->y, a->y, a->curve->p);
574
575         /* This uses montgomery multiplication to optimise */
576         x3 = modpow(a->x, Three, a->curve->p);
577         freebn(Three);
578         ax = modmul(a->curve->w.a, a->x, a->curve->p);
579         x3ax = bigadd(x3, ax);
580         freebn(x3); x3 = NULL;
581         freebn(ax); ax = NULL;
582         x3axm = bigmod(x3ax, a->curve->p);
583         freebn(x3ax); x3ax = NULL;
584         x3axb = bigadd(x3axm, a->curve->w.b);
585         freebn(x3axm); x3axm = NULL;
586         rhs = bigmod(x3axb, a->curve->p);
587         freebn(x3axb);
588
589         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
590         freebn(lhs);
591         freebn(rhs);
592
593         return ret;
594     } else {
595         return 0;
596     }
597 }
598
599 /* ----------------------------------------------------------------------
600  * Elliptic curve point maths
601  */
602
603 /* Returns 1 on success and 0 on memory error */
604 static int ecp_normalise(struct ec_point *a)
605 {
606     if (!a) {
607         /* No point */
608         return 0;
609     }
610
611     if (a->infinity) {
612         /* Point is at infinity - i.e. normalised */
613         return 1;
614     }
615
616     if (a->curve->type == EC_WEIERSTRASS) {
617         /* In Jacobian Coordinates the triple (X, Y, Z) represents
618            the affine point (X / Z^2, Y / Z^3) */
619
620         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
621
622         if (!a->x || !a->y) {
623             /* No point defined */
624             return 0;
625         } else if (!a->z) {
626             /* Already normalised */
627             return 1;
628         }
629
630         Z2 = ecf_square(a->z, a->curve);
631         Z2inv = modinv(Z2, a->curve->p);
632         if (!Z2inv) {
633             freebn(Z2);
634             return 0;
635         }
636         tx = modmul(a->x, Z2inv, a->curve->p);
637         freebn(Z2inv);
638
639         Z3 = modmul(Z2, a->z, a->curve->p);
640         freebn(Z2);
641         Z3inv = modinv(Z3, a->curve->p);
642         freebn(Z3);
643         if (!Z3inv) {
644             freebn(tx);
645             return 0;
646         }
647         ty = modmul(a->y, Z3inv, a->curve->p);
648         freebn(Z3inv);
649
650         freebn(a->x);
651         a->x = tx;
652         freebn(a->y);
653         a->y = ty;
654         freebn(a->z);
655         a->z = NULL;
656         return 1;
657     } else if (a->curve->type == EC_MONTGOMERY) {
658         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
659
660         Bignum tmp, tmp2;
661
662         if (!a->x) {
663             /* No point defined */
664             return 0;
665         } else if (!a->z) {
666             /* Already normalised */
667             return 1;
668         }
669
670         tmp = modinv(a->z, a->curve->p);
671         if (!tmp) {
672             return 0;
673         }
674         tmp2 = modmul(a->x, tmp, a->curve->p);
675         freebn(tmp);
676
677         freebn(a->z);
678         a->z = NULL;
679         freebn(a->x);
680         a->x = tmp2;
681         return 1;
682     } else if (a->curve->type == EC_EDWARDS) {
683         /* Always normalised */
684         return 1;
685     } else {
686         return 0;
687     }
688 }
689
690 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
691 {
692     Bignum S, M, outx, outy, outz;
693
694     if (bignum_cmp(a->y, Zero) == 0)
695     {
696         /* Identity */
697         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
698     }
699
700     /* S = 4*X*Y^2 */
701     {
702         Bignum Y2, XY2, _2XY2;
703
704         Y2 = ecf_square(a->y, a->curve);
705         XY2 = modmul(a->x, Y2, a->curve->p);
706         freebn(Y2);
707
708         _2XY2 = ecf_double(XY2, a->curve);
709         freebn(XY2);
710         S = ecf_double(_2XY2, a->curve);
711         freebn(_2XY2);
712     }
713
714     /* Faster calculation if a = -3 */
715     if (aminus3) {
716         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
717         Bignum Z2, XpZ2, XmZ2, second;
718
719         if (a->z == NULL) {
720             Z2 = copybn(One);
721         } else {
722             Z2 = ecf_square(a->z, a->curve);
723         }
724
725         XpZ2 = ecf_add(a->x, Z2, a->curve);
726         XmZ2 = modsub(a->x, Z2, a->curve->p);
727         freebn(Z2);
728
729         second = modmul(XpZ2, XmZ2, a->curve->p);
730         freebn(XpZ2);
731         freebn(XmZ2);
732
733         M = ecf_treble(second, a->curve);
734         freebn(second);
735     } else {
736         /* M = 3*X^2 + a*Z^4 */
737         Bignum _3X2, X2, aZ4;
738
739         if (a->z == NULL) {
740             aZ4 = copybn(a->curve->w.a);
741         } else {
742             Bignum Z2, Z4;
743
744             Z2 = ecf_square(a->z, a->curve);
745             Z4 = ecf_square(Z2, a->curve);
746             freebn(Z2);
747             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
748             freebn(Z4);
749         }
750
751         X2 = modmul(a->x, a->x, a->curve->p);
752         _3X2 = ecf_treble(X2, a->curve);
753         freebn(X2);
754         M = ecf_add(_3X2, aZ4, a->curve);
755         freebn(_3X2);
756         freebn(aZ4);
757     }
758
759     /* X' = M^2 - 2*S */
760     {
761         Bignum M2, _2S;
762
763         M2 = ecf_square(M, a->curve);
764         _2S = ecf_double(S, a->curve);
765         outx = modsub(M2, _2S, a->curve->p);
766         freebn(M2);
767         freebn(_2S);
768     }
769
770     /* Y' = M*(S - X') - 8*Y^4 */
771     {
772         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
773
774         SX = modsub(S, outx, a->curve->p);
775         freebn(S);
776         MSX = modmul(M, SX, a->curve->p);
777         freebn(SX);
778         freebn(M);
779         Y2 = ecf_square(a->y, a->curve);
780         Y4 = ecf_square(Y2, a->curve);
781         freebn(Y2);
782         Eight = bignum_from_long(8);
783         _8Y4 = modmul(Eight, Y4, a->curve->p);
784         freebn(Eight);
785         freebn(Y4);
786         outy = modsub(MSX, _8Y4, a->curve->p);
787         freebn(MSX);
788         freebn(_8Y4);
789     }
790
791     /* Z' = 2*Y*Z */
792     {
793         Bignum YZ;
794
795         if (a->z == NULL) {
796             YZ = copybn(a->y);
797         } else {
798             YZ = modmul(a->y, a->z, a->curve->p);
799         }
800
801         outz = ecf_double(YZ, a->curve);
802         freebn(YZ);
803     }
804
805     return ec_point_new(a->curve, outx, outy, outz, 0);
806 }
807
808 static struct ec_point *ecp_doublem(const struct ec_point *a)
809 {
810     Bignum z, outx, outz, xpz, xmz;
811
812     z = a->z;
813     if (!z) {
814         z = One;
815     }
816
817     /* 4xz = (x + z)^2 - (x - z)^2 */
818     {
819         Bignum tmp;
820
821         tmp = ecf_add(a->x, z, a->curve);
822         xpz = ecf_square(tmp, a->curve);
823         freebn(tmp);
824
825         tmp = modsub(a->x, z, a->curve->p);
826         xmz = ecf_square(tmp, a->curve);
827         freebn(tmp);
828     }
829
830     /* outx = (x + z)^2 * (x - z)^2 */
831     outx = modmul(xpz, xmz, a->curve->p);
832
833     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
834     {
835         Bignum _4xz, tmp, tmp2, tmp3;
836
837         tmp = bignum_from_long(2);
838         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
839         freebn(tmp);
840
841         _4xz = modsub(xpz, xmz, a->curve->p);
842         freebn(xpz);
843         tmp = modmul(tmp2, _4xz, a->curve->p);
844         freebn(tmp2);
845
846         tmp2 = bignum_from_long(4);
847         tmp3 = modinv(tmp2, a->curve->p);
848         freebn(tmp2);
849         if (!tmp3) {
850             freebn(tmp);
851             freebn(_4xz);
852             freebn(outx);
853             freebn(xmz);
854             return NULL;
855         }
856         tmp2 = modmul(tmp, tmp3, a->curve->p);
857         freebn(tmp);
858         freebn(tmp3);
859
860         tmp = ecf_add(xmz, tmp2, a->curve);
861         freebn(xmz);
862         freebn(tmp2);
863         outz = modmul(_4xz, tmp, a->curve->p);
864         freebn(_4xz);
865         freebn(tmp);
866     }
867
868     return ec_point_new(a->curve, outx, NULL, outz, 0);
869 }
870
871 /* Forward declaration for Edwards curve doubling */
872 static struct ec_point *ecp_add(const struct ec_point *a,
873                                 const struct ec_point *b,
874                                 const int aminus3);
875
876 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
877 {
878     if (a->infinity)
879     {
880         /* Identity */
881         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
882     }
883
884     if (a->curve->type == EC_EDWARDS)
885     {
886         return ecp_add(a, a, aminus3);
887     }
888     else if (a->curve->type == EC_WEIERSTRASS)
889     {
890         return ecp_doublew(a, aminus3);
891     }
892     else
893     {
894         return ecp_doublem(a);
895     }
896 }
897
898 static struct ec_point *ecp_addw(const struct ec_point *a,
899                                  const struct ec_point *b,
900                                  const int aminus3)
901 {
902     Bignum U1, U2, S1, S2, outx, outy, outz;
903
904     /* U1 = X1*Z2^2 */
905     /* S1 = Y1*Z2^3 */
906     if (b->z) {
907         Bignum Z2, Z3;
908
909         Z2 = ecf_square(b->z, a->curve);
910         U1 = modmul(a->x, Z2, a->curve->p);
911         Z3 = modmul(Z2, b->z, a->curve->p);
912         freebn(Z2);
913         S1 = modmul(a->y, Z3, a->curve->p);
914         freebn(Z3);
915     } else {
916         U1 = copybn(a->x);
917         S1 = copybn(a->y);
918     }
919
920     /* U2 = X2*Z1^2 */
921     /* S2 = Y2*Z1^3 */
922     if (a->z) {
923         Bignum Z2, Z3;
924
925         Z2 = ecf_square(a->z, b->curve);
926         U2 = modmul(b->x, Z2, b->curve->p);
927         Z3 = modmul(Z2, a->z, b->curve->p);
928         freebn(Z2);
929         S2 = modmul(b->y, Z3, b->curve->p);
930         freebn(Z3);
931     } else {
932         U2 = copybn(b->x);
933         S2 = copybn(b->y);
934     }
935
936     /* Check if multiplying by self */
937     if (bignum_cmp(U1, U2) == 0)
938     {
939         freebn(U1);
940         freebn(U2);
941         if (bignum_cmp(S1, S2) == 0)
942         {
943             freebn(S1);
944             freebn(S2);
945             return ecp_double(a, aminus3);
946         }
947         else
948         {
949             freebn(S1);
950             freebn(S2);
951             /* Infinity */
952             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
953         }
954     }
955
956     {
957         Bignum H, R, UH2, H3;
958
959         /* H = U2 - U1 */
960         H = modsub(U2, U1, a->curve->p);
961         freebn(U2);
962
963         /* R = S2 - S1 */
964         R = modsub(S2, S1, a->curve->p);
965         freebn(S2);
966
967         /* X3 = R^2 - H^3 - 2*U1*H^2 */
968         {
969             Bignum R2, H2, _2UH2, first;
970
971             H2 = ecf_square(H, a->curve);
972             UH2 = modmul(U1, H2, a->curve->p);
973             freebn(U1);
974             H3 = modmul(H2, H, a->curve->p);
975             freebn(H2);
976             R2 = ecf_square(R, a->curve);
977             _2UH2 = ecf_double(UH2, a->curve);
978             first = modsub(R2, H3, a->curve->p);
979             freebn(R2);
980             outx = modsub(first, _2UH2, a->curve->p);
981             freebn(first);
982             freebn(_2UH2);
983         }
984
985         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
986         {
987             Bignum RUH2mX, UH2mX, SH3;
988
989             UH2mX = modsub(UH2, outx, a->curve->p);
990             freebn(UH2);
991             RUH2mX = modmul(R, UH2mX, a->curve->p);
992             freebn(UH2mX);
993             freebn(R);
994             SH3 = modmul(S1, H3, a->curve->p);
995             freebn(S1);
996             freebn(H3);
997
998             outy = modsub(RUH2mX, SH3, a->curve->p);
999             freebn(RUH2mX);
1000             freebn(SH3);
1001         }
1002
1003         /* Z3 = H*Z1*Z2 */
1004         if (a->z && b->z) {
1005             Bignum ZZ;
1006
1007             ZZ = modmul(a->z, b->z, a->curve->p);
1008             outz = modmul(H, ZZ, a->curve->p);
1009             freebn(H);
1010             freebn(ZZ);
1011         } else if (a->z) {
1012             outz = modmul(H, a->z, a->curve->p);
1013             freebn(H);
1014         } else if (b->z) {
1015             outz = modmul(H, b->z, a->curve->p);
1016             freebn(H);
1017         } else {
1018             outz = H;
1019         }
1020     }
1021
1022     return ec_point_new(a->curve, outx, outy, outz, 0);
1023 }
1024
1025 static struct ec_point *ecp_addm(const struct ec_point *a,
1026                                  const struct ec_point *b,
1027                                  const struct ec_point *base)
1028 {
1029     Bignum outx, outz, az, bz;
1030
1031     az = a->z;
1032     if (!az) {
1033         az = One;
1034     }
1035     bz = b->z;
1036     if (!bz) {
1037         bz = One;
1038     }
1039
1040     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1041     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1042     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1043     {
1044         Bignum tmp, tmp2, tmp3, tmp4;
1045
1046         /* (Xa + Za) * (Xb - Zb) */
1047         tmp = ecf_add(a->x, az, a->curve);
1048         tmp2 = modsub(b->x, bz, a->curve->p);
1049         tmp3 = modmul(tmp, tmp2, a->curve->p);
1050         freebn(tmp);
1051         freebn(tmp2);
1052
1053         /* (Xa - Za) * (Xb + Zb) */
1054         tmp = modsub(a->x, az, a->curve->p);
1055         tmp2 = ecf_add(b->x, bz, a->curve);
1056         tmp4 = modmul(tmp, tmp2, a->curve->p);
1057         freebn(tmp);
1058         freebn(tmp2);
1059
1060         tmp = ecf_add(tmp3, tmp4, a->curve);
1061         outx = ecf_square(tmp, a->curve);
1062         freebn(tmp);
1063
1064         tmp = modsub(tmp3, tmp4, a->curve->p);
1065         freebn(tmp3);
1066         freebn(tmp4);
1067         tmp2 = ecf_square(tmp, a->curve);
1068         freebn(tmp);
1069         outz = modmul(base->x, tmp2, a->curve->p);
1070         freebn(tmp2);
1071     }
1072
1073     return ec_point_new(a->curve, outx, NULL, outz, 0);
1074 }
1075
1076 static struct ec_point *ecp_adde(const struct ec_point *a,
1077                                  const struct ec_point *b)
1078 {
1079     Bignum outx, outy, dmul;
1080
1081     /* outx = (a->x * b->y + b->x * a->y) /
1082      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1083     {
1084         Bignum tmp, tmp2, tmp3, tmp4;
1085
1086         tmp = modmul(a->x, b->y, a->curve->p);
1087         tmp2 = modmul(b->x, a->y, a->curve->p);
1088         tmp3 = ecf_add(tmp, tmp2, a->curve);
1089
1090         tmp4 = modmul(tmp, tmp2, a->curve->p);
1091         freebn(tmp);
1092         freebn(tmp2);
1093         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1094         freebn(tmp4);
1095
1096         tmp = ecf_add(One, dmul, a->curve);
1097         tmp2 = modinv(tmp, a->curve->p);
1098         freebn(tmp);
1099         if (!tmp2)
1100         {
1101             freebn(tmp3);
1102             freebn(dmul);
1103             return NULL;
1104         }
1105
1106         outx = modmul(tmp3, tmp2, a->curve->p);
1107         freebn(tmp3);
1108         freebn(tmp2);
1109     }
1110
1111     /* outy = (a->y * b->y + a->x * b->x) /
1112      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1113     {
1114         Bignum tmp, tmp2, tmp3, tmp4;
1115
1116         tmp = modsub(One, dmul, a->curve->p);
1117         freebn(dmul);
1118
1119         tmp2 = modinv(tmp, a->curve->p);
1120         freebn(tmp);
1121         if (!tmp2)
1122         {
1123             freebn(outx);
1124             return NULL;
1125         }
1126
1127         tmp = modmul(a->y, b->y, a->curve->p);
1128         tmp3 = modmul(a->x, b->x, a->curve->p);
1129         tmp4 = ecf_add(tmp, tmp3, a->curve);
1130         freebn(tmp);
1131         freebn(tmp3);
1132
1133         outy = modmul(tmp4, tmp2, a->curve->p);
1134         freebn(tmp4);
1135         freebn(tmp2);
1136     }
1137
1138     return ec_point_new(a->curve, outx, outy, NULL, 0);
1139 }
1140
1141 static struct ec_point *ecp_add(const struct ec_point *a,
1142                                 const struct ec_point *b,
1143                                 const int aminus3)
1144 {
1145     if (a->curve != b->curve) {
1146         return NULL;
1147     }
1148
1149     /* Check if multiplying by infinity */
1150     if (a->infinity) return ec_point_copy(b);
1151     if (b->infinity) return ec_point_copy(a);
1152
1153     if (a->curve->type == EC_EDWARDS)
1154     {
1155         return ecp_adde(a, b);
1156     }
1157
1158     if (a->curve->type == EC_WEIERSTRASS)
1159     {
1160         return ecp_addw(a, b, aminus3);
1161     }
1162
1163     return NULL;
1164 }
1165
1166 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1167 {
1168     struct ec_point *A, *ret;
1169     int bits, i;
1170
1171     A = ec_point_copy(a);
1172     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1173
1174     bits = bignum_bitcount(b);
1175     for (i = 0; i < bits; ++i)
1176     {
1177         if (bignum_bit(b, i))
1178         {
1179             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1180             ec_point_free(ret);
1181             ret = tmp;
1182         }
1183         if (i+1 != bits)
1184         {
1185             struct ec_point *tmp = ecp_double(A, aminus3);
1186             ec_point_free(A);
1187             A = tmp;
1188         }
1189     }
1190
1191     ec_point_free(A);
1192     return ret;
1193 }
1194
1195 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1196 {
1197     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1198
1199     if (!ecp_normalise(ret)) {
1200         ec_point_free(ret);
1201         return NULL;
1202     }
1203
1204     return ret;
1205 }
1206
1207 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1208 {
1209     int i;
1210     struct ec_point *ret;
1211
1212     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1213
1214     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1215     {
1216         {
1217             struct ec_point *tmp = ecp_double(ret, 0);
1218             ec_point_free(ret);
1219             ret = tmp;
1220         }
1221         if (ret && bignum_bit(b, i))
1222         {
1223             struct ec_point *tmp = ecp_add(ret, a, 0);
1224             ec_point_free(ret);
1225             ret = tmp;
1226         }
1227     }
1228
1229     return ret;
1230 }
1231
1232 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1233 {
1234     struct ec_point *P1, *P2;
1235     int bits, i;
1236
1237     /* P1 <- P and P2 <- [2]P */
1238     P2 = ecp_double(p, 0);
1239     P1 = ec_point_copy(p);
1240
1241     /* for i = bits âˆ’ 2 down to 0 */
1242     bits = bignum_bitcount(n);
1243     for (i = bits - 2; i >= 0; --i)
1244     {
1245         if (!bignum_bit(n, i))
1246         {
1247             /* P2 <- P1 + P2 */
1248             struct ec_point *tmp = ecp_addm(P1, P2, p);
1249             ec_point_free(P2);
1250             P2 = tmp;
1251
1252             /* P1 <- [2]P1 */
1253             tmp = ecp_double(P1, 0);
1254             ec_point_free(P1);
1255             P1 = tmp;
1256         }
1257         else
1258         {
1259             /* P1 <- P1 + P2 */
1260             struct ec_point *tmp = ecp_addm(P1, P2, p);
1261             ec_point_free(P1);
1262             P1 = tmp;
1263
1264             /* P2 <- [2]P2 */
1265             tmp = ecp_double(P2, 0);
1266             ec_point_free(P2);
1267             P2 = tmp;
1268         }
1269     }
1270
1271     ec_point_free(P2);
1272
1273     if (!ecp_normalise(P1)) {
1274         ec_point_free(P1);
1275         return NULL;
1276     }
1277
1278     return P1;
1279 }
1280
1281 /* Not static because it is used by sshecdsag.c to generate a new key */
1282 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1283 {
1284     if (a->curve->type == EC_WEIERSTRASS) {
1285         return ecp_mulw(a, b);
1286     } else if (a->curve->type == EC_EDWARDS) {
1287         return ecp_mule(a, b);
1288     } else {
1289         return ecp_mulm(a, b);
1290     }
1291 }
1292
1293 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1294                                    const struct ec_point *point)
1295 {
1296     struct ec_point *aG, *bP, *ret;
1297     int aminus3;
1298
1299     if (point->curve->type != EC_WEIERSTRASS) {
1300         return NULL;
1301     }
1302
1303     aminus3 = ec_aminus3(point->curve);
1304
1305     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1306     if (!aG) return NULL;
1307     bP = ecp_mul_(point, b, aminus3);
1308     if (!bP) {
1309         ec_point_free(aG);
1310         return NULL;
1311     }
1312
1313     ret = ecp_add(aG, bP, aminus3);
1314
1315     ec_point_free(aG);
1316     ec_point_free(bP);
1317
1318     if (!ecp_normalise(ret)) {
1319         ec_point_free(ret);
1320         return NULL;
1321     }
1322
1323     return ret;
1324 }
1325 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1326 {
1327     /* Get the x value on the given Edwards curve for a given y */
1328     Bignum x, xx;
1329
1330     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1331     {
1332         Bignum tmp, tmp2, tmp3;
1333
1334         tmp = ecf_square(y, curve);
1335         tmp2 = modmul(curve->e.d, tmp, curve->p);
1336         tmp3 = ecf_add(tmp2, One, curve);
1337         freebn(tmp2);
1338         tmp2 = modinv(tmp3, curve->p);
1339         freebn(tmp3);
1340         if (!tmp2) {
1341             freebn(tmp);
1342             return NULL;
1343         }
1344
1345         tmp3 = modsub(tmp, One, curve->p);
1346         freebn(tmp);
1347         xx = modmul(tmp3, tmp2, curve->p);
1348         freebn(tmp3);
1349         freebn(tmp2);
1350     }
1351
1352     /* x = xx^((p + 3) / 8) */
1353     {
1354         Bignum tmp, tmp2;
1355
1356         tmp = bignum_add_long(curve->p, 3);
1357         tmp2 = bignum_rshift(tmp, 3);
1358         freebn(tmp);
1359         x = modpow(xx, tmp2, curve->p);
1360         freebn(tmp2);
1361     }
1362
1363     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1364     {
1365         Bignum tmp, tmp2;
1366
1367         tmp = ecf_square(x, curve);
1368         tmp2 = modsub(tmp, xx, curve->p);
1369         freebn(tmp);
1370         freebn(xx);
1371         if (bignum_cmp(tmp2, Zero)) {
1372             Bignum tmp3;
1373
1374             freebn(tmp2);
1375
1376             tmp = modsub(curve->p, One, curve->p);
1377             tmp2 = bignum_rshift(tmp, 2);
1378             freebn(tmp);
1379             tmp = bignum_from_long(2);
1380             tmp3 = modpow(tmp, tmp2, curve->p);
1381             freebn(tmp);
1382             freebn(tmp2);
1383
1384             tmp = modmul(x, tmp3, curve->p);
1385             freebn(x);
1386             freebn(tmp3);
1387             x = tmp;
1388         } else {
1389             freebn(tmp2);
1390         }
1391     }
1392
1393     /* if x % 2 != 0 then x = p - x */
1394     if (bignum_bit(x, 0)) {
1395         Bignum tmp = modsub(curve->p, x, curve->p);
1396         freebn(x);
1397         x = tmp;
1398     }
1399
1400     return x;
1401 }
1402
1403 /* ----------------------------------------------------------------------
1404  * Public point from private
1405  */
1406
1407 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1408 {
1409     if (curve->type == EC_WEIERSTRASS) {
1410         return ecp_mul(&curve->w.G, privateKey);
1411     } else if (curve->type == EC_EDWARDS) {
1412         /* hash = H(sk) (where hash creates 2 * fieldBits)
1413          * b = fieldBits
1414          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1415          * publicKey = aB */
1416         struct ec_point *ret;
1417         unsigned char hash[512/8];
1418         Bignum a;
1419         int i, keylen;
1420         SHA512_State s;
1421         SHA512_Init(&s);
1422
1423         keylen = curve->fieldBits / 8;
1424         for (i = 0; i < keylen; ++i) {
1425             unsigned char b = bignum_byte(privateKey, i);
1426             SHA512_Bytes(&s, &b, 1);
1427         }
1428         SHA512_Final(&s, hash);
1429
1430         /* The second part is simply turning the hash into a Bignum,
1431          * however the 2^(b-2) bit *must* be set, and the bottom 3
1432          * bits *must* not be */
1433         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1434         hash[31] &= 0x7f; /* Unset above (b-2) */
1435         hash[31] |= 0x40; /* Set 2^(b-2) */
1436         /* Chop off the top part and convert to int */
1437         a = bignum_from_bytes_le(hash, 32);
1438
1439         ret = ecp_mul(&curve->e.B, a);
1440         freebn(a);
1441         return ret;
1442     } else {
1443         return NULL;
1444     }
1445 }
1446
1447 /* ----------------------------------------------------------------------
1448  * Basic sign and verify routines
1449  */
1450
1451 static int _ecdsa_verify(const struct ec_point *publicKey,
1452                          const unsigned char *data, const int dataLen,
1453                          const Bignum r, const Bignum s)
1454 {
1455     int z_bits, n_bits;
1456     Bignum z;
1457     int valid = 0;
1458
1459     if (publicKey->curve->type != EC_WEIERSTRASS) {
1460         return 0;
1461     }
1462
1463     /* Sanity checks */
1464     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1465         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1466     {
1467         return 0;
1468     }
1469
1470     /* z = left most bitlen(curve->n) of data */
1471     z = bignum_from_bytes(data, dataLen);
1472     n_bits = bignum_bitcount(publicKey->curve->w.n);
1473     z_bits = bignum_bitcount(z);
1474     if (z_bits > n_bits)
1475     {
1476         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1477         freebn(z);
1478         z = tmp;
1479     }
1480
1481     /* Ensure z in range of n */
1482     {
1483         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1484         freebn(z);
1485         z = tmp;
1486     }
1487
1488     /* Calculate signature */
1489     {
1490         Bignum w, x, u1, u2;
1491         struct ec_point *tmp;
1492
1493         w = modinv(s, publicKey->curve->w.n);
1494         if (!w) {
1495             freebn(z);
1496             return 0;
1497         }
1498         u1 = modmul(z, w, publicKey->curve->w.n);
1499         u2 = modmul(r, w, publicKey->curve->w.n);
1500         freebn(w);
1501
1502         tmp = ecp_summul(u1, u2, publicKey);
1503         freebn(u1);
1504         freebn(u2);
1505         if (!tmp) {
1506             freebn(z);
1507             return 0;
1508         }
1509
1510         x = bigmod(tmp->x, publicKey->curve->w.n);
1511         ec_point_free(tmp);
1512
1513         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1514         freebn(x);
1515     }
1516
1517     freebn(z);
1518
1519     return valid;
1520 }
1521
1522 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1523                         const unsigned char *data, const int dataLen,
1524                         Bignum *r, Bignum *s)
1525 {
1526     unsigned char digest[20];
1527     int z_bits, n_bits;
1528     Bignum z, k;
1529     struct ec_point *kG;
1530
1531     *r = NULL;
1532     *s = NULL;
1533
1534     if (curve->type != EC_WEIERSTRASS) {
1535         return;
1536     }
1537
1538     /* z = left most bitlen(curve->n) of data */
1539     z = bignum_from_bytes(data, dataLen);
1540     n_bits = bignum_bitcount(curve->w.n);
1541     z_bits = bignum_bitcount(z);
1542     if (z_bits > n_bits)
1543     {
1544         Bignum tmp;
1545         tmp = bignum_rshift(z, z_bits - n_bits);
1546         freebn(z);
1547         z = tmp;
1548     }
1549
1550     /* Generate k between 1 and curve->n, using the same deterministic
1551      * k generation system we use for conventional DSA. */
1552     SHA_Simple(data, dataLen, digest);
1553     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1554                   digest, sizeof(digest));
1555
1556     kG = ecp_mul(&curve->w.G, k);
1557     if (!kG) {
1558         freebn(z);
1559         freebn(k);
1560         return;
1561     }
1562
1563     /* r = kG.x mod n */
1564     *r = bigmod(kG->x, curve->w.n);
1565     ec_point_free(kG);
1566
1567     /* s = (z + r * priv)/k mod n */
1568     {
1569         Bignum rPriv, zMod, first, firstMod, kInv;
1570         rPriv = modmul(*r, privateKey, curve->w.n);
1571         zMod = bigmod(z, curve->w.n);
1572         freebn(z);
1573         first = bigadd(rPriv, zMod);
1574         freebn(rPriv);
1575         freebn(zMod);
1576         firstMod = bigmod(first, curve->w.n);
1577         freebn(first);
1578         kInv = modinv(k, curve->w.n);
1579         freebn(k);
1580         if (!kInv) {
1581             freebn(firstMod);
1582             freebn(*r);
1583             return;
1584         }
1585         *s = modmul(firstMod, kInv, curve->w.n);
1586         freebn(firstMod);
1587         freebn(kInv);
1588     }
1589 }
1590
1591 /* ----------------------------------------------------------------------
1592  * Misc functions
1593  */
1594
1595 static void getstring(const char **data, int *datalen,
1596                       const char **p, int *length)
1597 {
1598     *p = NULL;
1599     if (*datalen < 4)
1600         return;
1601     *length = toint(GET_32BIT(*data));
1602     if (*length < 0)
1603         return;
1604     *datalen -= 4;
1605     *data += 4;
1606     if (*datalen < *length)
1607         return;
1608     *p = *data;
1609     *data += *length;
1610     *datalen -= *length;
1611 }
1612
1613 static Bignum getmp(const char **data, int *datalen)
1614 {
1615     const char *p;
1616     int length;
1617
1618     getstring(data, datalen, &p, &length);
1619     if (!p)
1620         return NULL;
1621     if (p[0] & 0x80)
1622         return NULL;                   /* negative mp */
1623     return bignum_from_bytes((unsigned char *)p, length);
1624 }
1625
1626 static Bignum getmp_le(const char **data, int *datalen)
1627 {
1628     const char *p;
1629     int length;
1630
1631     getstring(data, datalen, &p, &length);
1632     if (!p)
1633         return NULL;
1634     return bignum_from_bytes_le((const unsigned char *)p, length);
1635 }
1636
1637 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1638 {
1639     /* Got some conversion to do, first read in the y co-ord */
1640     int negative;
1641
1642     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1643     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1644         freebn(point->y);
1645         point->y = NULL;
1646         return 0;
1647     }
1648     /* Read x bit and then reset it */
1649     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1650     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1651
1652     /* Get the x from the y */
1653     point->x = ecp_edx(point->curve, point->y);
1654     if (!point->x) {
1655         freebn(point->y);
1656         point->y = NULL;
1657         return 0;
1658     }
1659     if (negative) {
1660         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1661         freebn(point->x);
1662         point->x = tmp;
1663     }
1664
1665     /* Verify the point is on the curve */
1666     if (!ec_point_verify(point)) {
1667         freebn(point->x);
1668         point->x = NULL;
1669         freebn(point->y);
1670         point->y = NULL;
1671         return 0;
1672     }
1673
1674     return 1;
1675 }
1676
1677 static int decodepoint(const char *p, int length, struct ec_point *point)
1678 {
1679     if (point->curve->type == EC_EDWARDS) {
1680         return decodepoint_ed(p, length, point);
1681     }
1682
1683     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1684         return 0;
1685     /* Skip compression flag */
1686     ++p;
1687     --length;
1688     /* The two values must be equal length */
1689     if (length % 2 != 0) {
1690         point->x = NULL;
1691         point->y = NULL;
1692         point->z = NULL;
1693         return 0;
1694     }
1695     length = length / 2;
1696     point->x = bignum_from_bytes((const unsigned char *)p, length);
1697     p += length;
1698     point->y = bignum_from_bytes((const unsigned char *)p, length);
1699     point->z = NULL;
1700
1701     /* Verify the point is on the curve */
1702     if (!ec_point_verify(point)) {
1703         freebn(point->x);
1704         point->x = NULL;
1705         freebn(point->y);
1706         point->y = NULL;
1707         return 0;
1708     }
1709
1710     return 1;
1711 }
1712
1713 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1714 {
1715     const char *p;
1716     int length;
1717
1718     getstring(data, datalen, &p, &length);
1719     if (!p) return 0;
1720     return decodepoint(p, length, point);
1721 }
1722
1723 /* ----------------------------------------------------------------------
1724  * Exposed ECDSA interface
1725  */
1726
1727 struct ecsign_extra {
1728     struct ec_curve *(*curve)(void);
1729     const struct ssh_hash *hash;
1730
1731     /* These fields are used by the OpenSSH PEM format importer/exporter */
1732     const unsigned char *oid;
1733     int oidlen;
1734 };
1735
1736 static void ecdsa_freekey(void *key)
1737 {
1738     struct ec_key *ec = (struct ec_key *) key;
1739     if (!ec) return;
1740
1741     if (ec->publicKey.x)
1742         freebn(ec->publicKey.x);
1743     if (ec->publicKey.y)
1744         freebn(ec->publicKey.y);
1745     if (ec->publicKey.z)
1746         freebn(ec->publicKey.z);
1747     if (ec->privateKey)
1748         freebn(ec->privateKey);
1749     sfree(ec);
1750 }
1751
1752 static void *ecdsa_newkey(const struct ssh_signkey *self,
1753                           const char *data, int len)
1754 {
1755     const struct ecsign_extra *extra =
1756         (const struct ecsign_extra *)self->extra;
1757     const char *p;
1758     int slen;
1759     struct ec_key *ec;
1760     struct ec_curve *curve;
1761
1762     getstring(&data, &len, &p, &slen);
1763
1764     if (!p) {
1765         return NULL;
1766     }
1767     curve = extra->curve();
1768     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1769
1770     /* Curve name is duplicated for Weierstrass form */
1771     if (curve->type == EC_WEIERSTRASS) {
1772         getstring(&data, &len, &p, &slen);
1773         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1774     }
1775
1776     ec = snew(struct ec_key);
1777
1778     ec->signalg = self;
1779     ec->publicKey.curve = curve;
1780     ec->publicKey.infinity = 0;
1781     ec->publicKey.x = NULL;
1782     ec->publicKey.y = NULL;
1783     ec->publicKey.z = NULL;
1784     if (!getmppoint(&data, &len, &ec->publicKey)) {
1785         ecdsa_freekey(ec);
1786         return NULL;
1787     }
1788     ec->privateKey = NULL;
1789
1790     if (!ec->publicKey.x || !ec->publicKey.y ||
1791         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1792         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1793     {
1794         ecdsa_freekey(ec);
1795         ec = NULL;
1796     }
1797
1798     return ec;
1799 }
1800
1801 static char *ecdsa_fmtkey(void *key)
1802 {
1803     struct ec_key *ec = (struct ec_key *) key;
1804     char *p;
1805     int len, i, pos, nibbles;
1806     static const char hex[] = "0123456789abcdef";
1807     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1808         return NULL;
1809
1810     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1811     if (ec->publicKey.curve->name)
1812         len += strlen(ec->publicKey.curve->name); /* Curve name */
1813     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1814     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1815     p = snewn(len, char);
1816
1817     pos = 0;
1818     if (ec->publicKey.curve->name)
1819         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1820     pos += sprintf(p + pos, "0x");
1821     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1822     if (nibbles < 1)
1823         nibbles = 1;
1824     for (i = nibbles; i--;) {
1825         p[pos++] =
1826             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1827     }
1828     pos += sprintf(p + pos, ",0x");
1829     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1830     if (nibbles < 1)
1831         nibbles = 1;
1832     for (i = nibbles; i--;) {
1833         p[pos++] =
1834             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1835     }
1836     p[pos] = '\0';
1837     return p;
1838 }
1839
1840 static unsigned char *ecdsa_public_blob(void *key, int *len)
1841 {
1842     struct ec_key *ec = (struct ec_key *) key;
1843     int pointlen, bloblen, fullnamelen, namelen;
1844     int i;
1845     unsigned char *blob, *p;
1846
1847     fullnamelen = strlen(ec->signalg->name);
1848
1849     if (ec->publicKey.curve->type == EC_EDWARDS) {
1850         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1851
1852         pointlen = ec->publicKey.curve->fieldBits / 8;
1853
1854         /* Can't handle this in our loop */
1855         if (pointlen < 2) return NULL;
1856
1857         bloblen = 4 + fullnamelen + 4 + pointlen;
1858         blob = snewn(bloblen, unsigned char);
1859
1860         p = blob;
1861         PUT_32BIT(p, fullnamelen);
1862         p += 4;
1863         memcpy(p, ec->signalg->name, fullnamelen);
1864         p += fullnamelen;
1865         PUT_32BIT(p, pointlen);
1866         p += 4;
1867
1868         /* Unset last bit of y and set first bit of x in its place */
1869         for (i = 0; i < pointlen - 1; ++i) {
1870             *p++ = bignum_byte(ec->publicKey.y, i);
1871         }
1872         /* Unset last bit of y and set first bit of x in its place */
1873         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1874         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1875     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1876         assert(ec->publicKey.curve->name);
1877         namelen = strlen(ec->publicKey.curve->name);
1878
1879         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1880
1881         /*
1882          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1883          */
1884         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1885         blob = snewn(bloblen, unsigned char);
1886
1887         p = blob;
1888         PUT_32BIT(p, fullnamelen);
1889         p += 4;
1890         memcpy(p, ec->signalg->name, fullnamelen);
1891         p += fullnamelen;
1892         PUT_32BIT(p, namelen);
1893         p += 4;
1894         memcpy(p, ec->publicKey.curve->name, namelen);
1895         p += namelen;
1896         PUT_32BIT(p, (2 * pointlen) + 1);
1897         p += 4;
1898         *p++ = 0x04;
1899         for (i = pointlen; i--;) {
1900             *p++ = bignum_byte(ec->publicKey.x, i);
1901         }
1902         for (i = pointlen; i--;) {
1903             *p++ = bignum_byte(ec->publicKey.y, i);
1904         }
1905     } else {
1906         return NULL;
1907     }
1908
1909     assert(p == blob + bloblen);
1910     *len = bloblen;
1911
1912     return blob;
1913 }
1914
1915 static unsigned char *ecdsa_private_blob(void *key, int *len)
1916 {
1917     struct ec_key *ec = (struct ec_key *) key;
1918     int keylen, bloblen;
1919     int i;
1920     unsigned char *blob, *p;
1921
1922     if (!ec->privateKey) return NULL;
1923
1924     if (ec->publicKey.curve->type == EC_EDWARDS) {
1925         /* Unsigned */
1926         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1927     } else {
1928         /* Signed */
1929         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1930     }
1931
1932     /*
1933      * mpint privateKey. Total 4 + keylen.
1934      */
1935     bloblen = 4 + keylen;
1936     blob = snewn(bloblen, unsigned char);
1937
1938     p = blob;
1939     PUT_32BIT(p, keylen);
1940     p += 4;
1941     if (ec->publicKey.curve->type == EC_EDWARDS) {
1942         /* Little endian */
1943         for (i = 0; i < keylen; ++i)
1944             *p++ = bignum_byte(ec->privateKey, i);
1945     } else {
1946         for (i = keylen; i--;)
1947             *p++ = bignum_byte(ec->privateKey, i);
1948     }
1949
1950     assert(p == blob + bloblen);
1951     *len = bloblen;
1952     return blob;
1953 }
1954
1955 static void *ecdsa_createkey(const struct ssh_signkey *self,
1956                              const unsigned char *pub_blob, int pub_len,
1957                              const unsigned char *priv_blob, int priv_len)
1958 {
1959     struct ec_key *ec;
1960     struct ec_point *publicKey;
1961     const char *pb = (const char *) priv_blob;
1962
1963     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1964     if (!ec) {
1965         return NULL;
1966     }
1967
1968     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1969         && ec->publicKey.curve->type != EC_EDWARDS) {
1970         ecdsa_freekey(ec);
1971         return NULL;
1972     }
1973
1974     if (ec->publicKey.curve->type == EC_EDWARDS) {
1975         ec->privateKey = getmp_le(&pb, &priv_len);
1976     } else {
1977         ec->privateKey = getmp(&pb, &priv_len);
1978     }
1979     if (!ec->privateKey) {
1980         ecdsa_freekey(ec);
1981         return NULL;
1982     }
1983
1984     /* Check that private key generates public key */
1985     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1986
1987     if (!publicKey ||
1988         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1989         bignum_cmp(publicKey->y, ec->publicKey.y))
1990     {
1991         ecdsa_freekey(ec);
1992         ec = NULL;
1993     }
1994     ec_point_free(publicKey);
1995
1996     return ec;
1997 }
1998
1999 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
2000                                        const unsigned char **blob, int *len)
2001 {
2002     struct ec_key *ec;
2003     struct ec_point *publicKey;
2004     const char *p, *q;
2005     int plen, qlen;
2006
2007     getstring((const char**)blob, len, &p, &plen);
2008     if (!p)
2009     {
2010         return NULL;
2011     }
2012
2013     ec = snew(struct ec_key);
2014
2015     ec->signalg = self;
2016     ec->publicKey.curve = ec_ed25519();
2017     ec->publicKey.infinity = 0;
2018     ec->privateKey = NULL;
2019     ec->publicKey.x = NULL;
2020     ec->publicKey.z = NULL;
2021     ec->publicKey.y = NULL;
2022
2023     if (!decodepoint_ed(p, plen, &ec->publicKey))
2024     {
2025         ecdsa_freekey(ec);
2026         return NULL;
2027     }
2028
2029     getstring((const char**)blob, len, &q, &qlen);
2030     if (!q)
2031         return NULL;
2032     if (qlen != 64)
2033         return NULL;
2034
2035     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2036
2037     /* Check that private key generates public key */
2038     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2039
2040     if (!publicKey ||
2041         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2042         bignum_cmp(publicKey->y, ec->publicKey.y))
2043     {
2044         ecdsa_freekey(ec);
2045         ec = NULL;
2046     }
2047     ec_point_free(publicKey);
2048
2049     /* The OpenSSH format for ed25519 private keys also for some
2050      * reason encodes an extra copy of the public key in the second
2051      * half of the secret-key string. Check that that's present and
2052      * correct as well, otherwise the key we think we've imported
2053      * won't behave identically to the way OpenSSH would have treated
2054      * it. */
2055     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2056         ecdsa_freekey(ec);
2057         return NULL;
2058     }
2059
2060     return ec;
2061 }
2062
2063 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2064 {
2065     struct ec_key *ec = (struct ec_key *) key;
2066
2067     int pointlen;
2068     int keylen;
2069     int bloblen;
2070     int i;
2071
2072     if (ec->publicKey.curve->type != EC_EDWARDS) {
2073         return 0;
2074     }
2075
2076     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2077     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2078     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2079
2080     if (bloblen > len)
2081         return bloblen;
2082
2083     /* Encode the public point */
2084     PUT_32BIT(blob, pointlen);
2085     blob += 4;
2086
2087     for (i = 0; i < pointlen - 1; ++i) {
2088          *blob++ = bignum_byte(ec->publicKey.y, i);
2089     }
2090     /* Unset last bit of y and set first bit of x in its place */
2091     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2092     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2093
2094     PUT_32BIT(blob, keylen + pointlen);
2095     blob += 4;
2096     for (i = 0; i < keylen; ++i) {
2097          *blob++ = bignum_byte(ec->privateKey, i);
2098     }
2099     /* Now encode an extra copy of the public point as the second half
2100      * of the private key string, as the OpenSSH format for some
2101      * reason requires */
2102     for (i = 0; i < pointlen - 1; ++i) {
2103          *blob++ = bignum_byte(ec->publicKey.y, i);
2104     }
2105     /* Unset last bit of y and set first bit of x in its place */
2106     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2107     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2108
2109     return bloblen;
2110 }
2111
2112 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2113                                      const unsigned char **blob, int *len)
2114 {
2115     const struct ecsign_extra *extra =
2116         (const struct ecsign_extra *)self->extra;
2117     const char **b = (const char **) blob;
2118     const char *p;
2119     int slen;
2120     struct ec_key *ec;
2121     struct ec_curve *curve;
2122     struct ec_point *publicKey;
2123
2124     getstring(b, len, &p, &slen);
2125
2126     if (!p) {
2127         return NULL;
2128     }
2129     curve = extra->curve();
2130     assert(curve->type == EC_WEIERSTRASS);
2131
2132     ec = snew(struct ec_key);
2133
2134     ec->signalg = self;
2135     ec->publicKey.curve = curve;
2136     ec->publicKey.infinity = 0;
2137     ec->publicKey.x = NULL;
2138     ec->publicKey.y = NULL;
2139     ec->publicKey.z = NULL;
2140     if (!getmppoint(b, len, &ec->publicKey)) {
2141         ecdsa_freekey(ec);
2142         return NULL;
2143     }
2144     ec->privateKey = NULL;
2145
2146     if (!ec->publicKey.x || !ec->publicKey.y ||
2147         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2148         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2149     {
2150         ecdsa_freekey(ec);
2151         return NULL;
2152     }
2153
2154     ec->privateKey = getmp(b, len);
2155     if (ec->privateKey == NULL)
2156     {
2157         ecdsa_freekey(ec);
2158         return NULL;
2159     }
2160
2161     /* Now check that the private key makes the public key */
2162     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2163     if (!publicKey)
2164     {
2165         ecdsa_freekey(ec);
2166         return NULL;
2167     }
2168
2169     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2170         bignum_cmp(ec->publicKey.y, publicKey->y))
2171     {
2172         /* Private key doesn't make the public key on the given curve */
2173         ecdsa_freekey(ec);
2174         ec_point_free(publicKey);
2175         return NULL;
2176     }
2177
2178     ec_point_free(publicKey);
2179
2180     return ec;
2181 }
2182
2183 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2184 {
2185     struct ec_key *ec = (struct ec_key *) key;
2186
2187     int pointlen;
2188     int namelen;
2189     int bloblen;
2190     int i;
2191
2192     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2193         return 0;
2194     }
2195
2196     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2197     namelen = strlen(ec->publicKey.curve->name);
2198     bloblen =
2199         4 + namelen /* <LEN> nistpXXX */
2200         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2201         + ssh2_bignum_length(ec->privateKey);
2202
2203     if (bloblen > len)
2204         return bloblen;
2205
2206     bloblen = 0;
2207
2208     PUT_32BIT(blob+bloblen, namelen);
2209     bloblen += 4;
2210     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2211     bloblen += namelen;
2212
2213     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2214     bloblen += 4;
2215     blob[bloblen++] = 0x04;
2216     for (i = pointlen; i--; )
2217         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2218     for (i = pointlen; i--; )
2219         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2220
2221     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2222     PUT_32BIT(blob+bloblen, pointlen);
2223     bloblen += 4;
2224     for (i = pointlen; i--; )
2225         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2226
2227     return bloblen;
2228 }
2229
2230 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2231                              const void *blob, int len)
2232 {
2233     struct ec_key *ec;
2234     int ret;
2235
2236     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2237     if (!ec)
2238         return -1;
2239     ret = ec->publicKey.curve->fieldBits;
2240     ecdsa_freekey(ec);
2241
2242     return ret;
2243 }
2244
2245 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2246                            const char *data, int datalen)
2247 {
2248     struct ec_key *ec = (struct ec_key *) key;
2249     const struct ecsign_extra *extra =
2250         (const struct ecsign_extra *)ec->signalg->extra;
2251     const char *p;
2252     int slen;
2253     int digestLen;
2254     int ret;
2255
2256     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2257         return 0;
2258
2259     /* Check the signature starts with the algorithm name */
2260     getstring(&sig, &siglen, &p, &slen);
2261     if (!p) {
2262         return 0;
2263     }
2264     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2265         return 0;
2266     }
2267
2268     getstring(&sig, &siglen, &p, &slen);
2269     if (ec->publicKey.curve->type == EC_EDWARDS) {
2270         struct ec_point *r;
2271         Bignum s, h;
2272
2273         /* Check that the signature is two times the length of a point */
2274         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2275             return 0;
2276         }
2277
2278         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2279         if (ec->publicKey.curve->fieldBits != 256) {
2280             return 0;
2281         }
2282
2283         /* Get the signature */
2284         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2285         if (!r) {
2286             return 0;
2287         }
2288         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2289             ec_point_free(r);
2290             return 0;
2291         }
2292         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2293                                  ec->publicKey.curve->fieldBits / 8);
2294
2295         /* Get the hash of the encoded value of R + encoded value of pk + message */
2296         {
2297             int i, pointlen;
2298             unsigned char b;
2299             unsigned char digest[512 / 8];
2300             SHA512_State hs;
2301             SHA512_Init(&hs);
2302
2303             /* Add encoded r (no need to encode it again, it was in the signature) */
2304             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2305
2306             /* Encode pk and add it */
2307             pointlen = ec->publicKey.curve->fieldBits / 8;
2308             for (i = 0; i < pointlen - 1; ++i) {
2309                 b = bignum_byte(ec->publicKey.y, i);
2310                 SHA512_Bytes(&hs, &b, 1);
2311             }
2312             /* Unset last bit of y and set first bit of x in its place */
2313             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2314             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2315             SHA512_Bytes(&hs, &b, 1);
2316
2317             /* Add the message itself */
2318             SHA512_Bytes(&hs, data, datalen);
2319
2320             /* Get the hash */
2321             SHA512_Final(&hs, digest);
2322
2323             /* Convert to Bignum */
2324             h = bignum_from_bytes_le(digest, sizeof(digest));
2325         }
2326
2327         /* Verify sB == r + h*publicKey */
2328         {
2329             struct ec_point *lhs, *rhs, *tmp;
2330
2331             /* lhs = sB */
2332             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2333             freebn(s);
2334             if (!lhs) {
2335                 ec_point_free(r);
2336                 freebn(h);
2337                 return 0;
2338             }
2339
2340             /* rhs = r + h*publicKey */
2341             tmp = ecp_mul(&ec->publicKey, h);
2342             freebn(h);
2343             if (!tmp) {
2344                 ec_point_free(lhs);
2345                 ec_point_free(r);
2346                 return 0;
2347             }
2348             rhs = ecp_add(r, tmp, 0);
2349             ec_point_free(r);
2350             ec_point_free(tmp);
2351             if (!rhs) {
2352                 ec_point_free(lhs);
2353                 return 0;
2354             }
2355
2356             /* Check the point is the same */
2357             ret = !bignum_cmp(lhs->x, rhs->x);
2358             if (ret) {
2359                 ret = !bignum_cmp(lhs->y, rhs->y);
2360                 if (ret) {
2361                     ret = 1;
2362                 }
2363             }
2364             ec_point_free(lhs);
2365             ec_point_free(rhs);
2366         }
2367     } else {
2368         Bignum r, s;
2369         unsigned char digest[512 / 8];
2370         void *hashctx;
2371
2372         r = getmp(&p, &slen);
2373         if (!r) return 0;
2374         s = getmp(&p, &slen);
2375         if (!s) {
2376             freebn(r);
2377             return 0;
2378         }
2379
2380         digestLen = extra->hash->hlen;
2381         assert(digestLen <= sizeof(digest));
2382         hashctx = extra->hash->init();
2383         extra->hash->bytes(hashctx, data, datalen);
2384         extra->hash->final(hashctx, digest);
2385
2386         /* Verify the signature */
2387         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2388
2389         freebn(r);
2390         freebn(s);
2391     }
2392
2393     return ret;
2394 }
2395
2396 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2397                                  int *siglen)
2398 {
2399     struct ec_key *ec = (struct ec_key *) key;
2400     const struct ecsign_extra *extra =
2401         (const struct ecsign_extra *)ec->signalg->extra;
2402     unsigned char digest[512 / 8];
2403     int digestLen;
2404     Bignum r = NULL, s = NULL;
2405     unsigned char *buf, *p;
2406     int rlen, slen, namelen;
2407     int i;
2408
2409     if (!ec->privateKey || !ec->publicKey.curve) {
2410         return NULL;
2411     }
2412
2413     if (ec->publicKey.curve->type == EC_EDWARDS) {
2414         struct ec_point *rp;
2415         int pointlen = ec->publicKey.curve->fieldBits / 8;
2416
2417         /* hash = H(sk) (where hash creates 2 * fieldBits)
2418          * b = fieldBits
2419          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2420          * r = H(h[b/8:b/4] + m)
2421          * R = rB
2422          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2423         {
2424             unsigned char hash[512/8];
2425             unsigned char b;
2426             Bignum a;
2427             SHA512_State hs;
2428             SHA512_Init(&hs);
2429
2430             for (i = 0; i < pointlen; ++i) {
2431                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2432                 SHA512_Bytes(&hs, &b, 1);
2433             }
2434
2435             SHA512_Final(&hs, hash);
2436
2437             /* The second part is simply turning the hash into a
2438              * Bignum, however the 2^(b-2) bit *must* be set, and the
2439              * bottom 3 bits *must* not be */
2440             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2441             hash[31] &= 0x7f; /* Unset above (b-2) */
2442             hash[31] |= 0x40; /* Set 2^(b-2) */
2443             /* Chop off the top part and convert to int */
2444             a = bignum_from_bytes_le(hash, 32);
2445
2446             SHA512_Init(&hs);
2447             SHA512_Bytes(&hs,
2448                          hash+(ec->publicKey.curve->fieldBits / 8),
2449                          (ec->publicKey.curve->fieldBits / 4)
2450                          - (ec->publicKey.curve->fieldBits / 8));
2451             SHA512_Bytes(&hs, data, datalen);
2452             SHA512_Final(&hs, hash);
2453
2454             r = bignum_from_bytes_le(hash, 512/8);
2455             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2456             if (!rp) {
2457                 freebn(r);
2458                 freebn(a);
2459                 return NULL;
2460             }
2461
2462             /* Now calculate s */
2463             SHA512_Init(&hs);
2464             /* Encode the point R */
2465             for (i = 0; i < pointlen - 1; ++i) {
2466                 b = bignum_byte(rp->y, i);
2467                 SHA512_Bytes(&hs, &b, 1);
2468             }
2469             /* Unset last bit of y and set first bit of x in its place */
2470             b = bignum_byte(rp->y, i) & 0x7f;
2471             b |= bignum_bit(rp->x, 0) << 7;
2472             SHA512_Bytes(&hs, &b, 1);
2473
2474             /* Encode the point pk */
2475             for (i = 0; i < pointlen - 1; ++i) {
2476                 b = bignum_byte(ec->publicKey.y, i);
2477                 SHA512_Bytes(&hs, &b, 1);
2478             }
2479             /* Unset last bit of y and set first bit of x in its place */
2480             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2481             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2482             SHA512_Bytes(&hs, &b, 1);
2483
2484             /* Add the message */
2485             SHA512_Bytes(&hs, data, datalen);
2486             SHA512_Final(&hs, hash);
2487
2488             {
2489                 Bignum tmp, tmp2;
2490
2491                 tmp = bignum_from_bytes_le(hash, 512/8);
2492                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2493                 freebn(a);
2494                 freebn(tmp);
2495                 tmp = bigadd(r, tmp2);
2496                 freebn(r);
2497                 freebn(tmp2);
2498                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2499                 freebn(tmp);
2500             }
2501         }
2502
2503         /* Format the output */
2504         namelen = strlen(ec->signalg->name);
2505         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2506         buf = snewn(*siglen, unsigned char);
2507         p = buf;
2508         PUT_32BIT(p, namelen);
2509         p += 4;
2510         memcpy(p, ec->signalg->name, namelen);
2511         p += namelen;
2512         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2513         p += 4;
2514
2515         /* Encode the point */
2516         pointlen = ec->publicKey.curve->fieldBits / 8;
2517         for (i = 0; i < pointlen - 1; ++i) {
2518             *p++ = bignum_byte(rp->y, i);
2519         }
2520         /* Unset last bit of y and set first bit of x in its place */
2521         *p = bignum_byte(rp->y, i) & 0x7f;
2522         *p++ |= bignum_bit(rp->x, 0) << 7;
2523         ec_point_free(rp);
2524
2525         /* Encode the int */
2526         for (i = 0; i < pointlen; ++i) {
2527             *p++ = bignum_byte(s, i);
2528         }
2529         freebn(s);
2530     } else {
2531         void *hashctx;
2532
2533         digestLen = extra->hash->hlen;
2534         assert(digestLen <= sizeof(digest));
2535         hashctx = extra->hash->init();
2536         extra->hash->bytes(hashctx, data, datalen);
2537         extra->hash->final(hashctx, digest);
2538
2539         /* Do the signature */
2540         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2541         if (!r || !s) {
2542             if (r) freebn(r);
2543             if (s) freebn(s);
2544             return NULL;
2545         }
2546
2547         rlen = (bignum_bitcount(r) + 8) / 8;
2548         slen = (bignum_bitcount(s) + 8) / 8;
2549
2550         namelen = strlen(ec->signalg->name);
2551
2552         /* Format the output */
2553         *siglen = 8+namelen+rlen+slen+8;
2554         buf = snewn(*siglen, unsigned char);
2555         p = buf;
2556         PUT_32BIT(p, namelen);
2557         p += 4;
2558         memcpy(p, ec->signalg->name, namelen);
2559         p += namelen;
2560         PUT_32BIT(p, rlen + slen + 8);
2561         p += 4;
2562         PUT_32BIT(p, rlen);
2563         p += 4;
2564         for (i = rlen; i--;)
2565             *p++ = bignum_byte(r, i);
2566         PUT_32BIT(p, slen);
2567         p += 4;
2568         for (i = slen; i--;)
2569             *p++ = bignum_byte(s, i);
2570
2571         freebn(r);
2572         freebn(s);
2573     }
2574
2575     return buf;
2576 }
2577
2578 const struct ecsign_extra sign_extra_ed25519 = {
2579     ec_ed25519, NULL,
2580     NULL, 0,
2581 };
2582 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2583     ecdsa_newkey,
2584     ecdsa_freekey,
2585     ecdsa_fmtkey,
2586     ecdsa_public_blob,
2587     ecdsa_private_blob,
2588     ecdsa_createkey,
2589     ed25519_openssh_createkey,
2590     ed25519_openssh_fmtkey,
2591     2 /* point, private exponent */,
2592     ecdsa_pubkey_bits,
2593     ecdsa_verifysig,
2594     ecdsa_sign,
2595     "ssh-ed25519",
2596     "ssh-ed25519",
2597     &sign_extra_ed25519,
2598 };
2599
2600 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2601 static const unsigned char nistp256_oid[] = {
2602     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2603 };
2604 const struct ecsign_extra sign_extra_nistp256 = {
2605     ec_p256, &ssh_sha256,
2606     nistp256_oid, lenof(nistp256_oid),
2607 };
2608 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2609     ecdsa_newkey,
2610     ecdsa_freekey,
2611     ecdsa_fmtkey,
2612     ecdsa_public_blob,
2613     ecdsa_private_blob,
2614     ecdsa_createkey,
2615     ecdsa_openssh_createkey,
2616     ecdsa_openssh_fmtkey,
2617     3 /* curve name, point, private exponent */,
2618     ecdsa_pubkey_bits,
2619     ecdsa_verifysig,
2620     ecdsa_sign,
2621     "ecdsa-sha2-nistp256",
2622     "ecdsa-sha2-nistp256",
2623     &sign_extra_nistp256,
2624 };
2625
2626 /* OID: 1.3.132.0.34 (secp384r1) */
2627 static const unsigned char nistp384_oid[] = {
2628     0x2b, 0x81, 0x04, 0x00, 0x22
2629 };
2630 const struct ecsign_extra sign_extra_nistp384 = {
2631     ec_p384, &ssh_sha384,
2632     nistp384_oid, lenof(nistp384_oid),
2633 };
2634 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2635     ecdsa_newkey,
2636     ecdsa_freekey,
2637     ecdsa_fmtkey,
2638     ecdsa_public_blob,
2639     ecdsa_private_blob,
2640     ecdsa_createkey,
2641     ecdsa_openssh_createkey,
2642     ecdsa_openssh_fmtkey,
2643     3 /* curve name, point, private exponent */,
2644     ecdsa_pubkey_bits,
2645     ecdsa_verifysig,
2646     ecdsa_sign,
2647     "ecdsa-sha2-nistp384",
2648     "ecdsa-sha2-nistp384",
2649     &sign_extra_nistp384,
2650 };
2651
2652 /* OID: 1.3.132.0.35 (secp521r1) */
2653 static const unsigned char nistp521_oid[] = {
2654     0x2b, 0x81, 0x04, 0x00, 0x23
2655 };
2656 const struct ecsign_extra sign_extra_nistp521 = {
2657     ec_p521, &ssh_sha512,
2658     nistp521_oid, lenof(nistp521_oid),
2659 };
2660 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2661     ecdsa_newkey,
2662     ecdsa_freekey,
2663     ecdsa_fmtkey,
2664     ecdsa_public_blob,
2665     ecdsa_private_blob,
2666     ecdsa_createkey,
2667     ecdsa_openssh_createkey,
2668     ecdsa_openssh_fmtkey,
2669     3 /* curve name, point, private exponent */,
2670     ecdsa_pubkey_bits,
2671     ecdsa_verifysig,
2672     ecdsa_sign,
2673     "ecdsa-sha2-nistp521",
2674     "ecdsa-sha2-nistp521",
2675     &sign_extra_nistp521,
2676 };
2677
2678 /* ----------------------------------------------------------------------
2679  * Exposed ECDH interface
2680  */
2681
2682 struct eckex_extra {
2683     struct ec_curve *(*curve)(void);
2684 };
2685
2686 static Bignum ecdh_calculate(const Bignum private,
2687                              const struct ec_point *public)
2688 {
2689     struct ec_point *p;
2690     Bignum ret;
2691     p = ecp_mul(public, private);
2692     if (!p) return NULL;
2693     ret = p->x;
2694     p->x = NULL;
2695
2696     if (p->curve->type == EC_MONTGOMERY) {
2697         /*
2698          * Endianness-swap. The Curve25519 algorithm definition
2699          * assumes you were doing your computation in arrays of 32
2700          * little-endian bytes, and now specifies that you take your
2701          * final one of those and convert it into a bignum in
2702          * _network_ byte order, i.e. big-endian.
2703          *
2704          * In particular, the spec says, you convert the _whole_ 32
2705          * bytes into a bignum. That is, on the rare occasions that
2706          * p->x has come out with the most significant 8 bits zero, we
2707          * have to imagine that being represented by a 32-byte string
2708          * with the last byte being zero, so that has to be converted
2709          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2710          * multiple of 256.
2711          */
2712         int i;
2713         int bytes = (p->curve->fieldBits+7) / 8;
2714         unsigned char *byteorder = snewn(bytes, unsigned char);
2715         for (i = 0; i < bytes; ++i) {
2716             byteorder[i] = bignum_byte(ret, i);
2717         }
2718         freebn(ret);
2719         ret = bignum_from_bytes(byteorder, bytes);
2720         smemclr(byteorder, bytes);
2721         sfree(byteorder);
2722     }
2723
2724     ec_point_free(p);
2725     return ret;
2726 }
2727
2728 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2729 {
2730     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2731     struct ec_curve *curve = extra->curve();
2732     return curve->textname;
2733 }
2734
2735 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2736 {
2737     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2738     struct ec_curve *curve;
2739     struct ec_key *key;
2740     struct ec_point *publicKey;
2741
2742     curve = extra->curve();
2743
2744     key = snew(struct ec_key);
2745
2746     key->signalg = NULL;
2747     key->publicKey.curve = curve;
2748
2749     if (curve->type == EC_MONTGOMERY) {
2750         unsigned char bytes[32] = {0};
2751         int i;
2752
2753         for (i = 0; i < sizeof(bytes); ++i)
2754         {
2755             bytes[i] = (unsigned char)random_byte();
2756         }
2757         bytes[0] &= 248;
2758         bytes[31] &= 127;
2759         bytes[31] |= 64;
2760         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2761         for (i = 0; i < sizeof(bytes); ++i)
2762         {
2763             ((volatile char*)bytes)[i] = 0;
2764         }
2765         if (!key->privateKey) {
2766             sfree(key);
2767             return NULL;
2768         }
2769         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2770         if (!publicKey) {
2771             freebn(key->privateKey);
2772             sfree(key);
2773             return NULL;
2774         }
2775         key->publicKey.x = publicKey->x;
2776         key->publicKey.y = publicKey->y;
2777         key->publicKey.z = NULL;
2778         sfree(publicKey);
2779     } else {
2780         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2781         if (!key->privateKey) {
2782             sfree(key);
2783             return NULL;
2784         }
2785         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2786         if (!publicKey) {
2787             freebn(key->privateKey);
2788             sfree(key);
2789             return NULL;
2790         }
2791         key->publicKey.x = publicKey->x;
2792         key->publicKey.y = publicKey->y;
2793         key->publicKey.z = NULL;
2794         sfree(publicKey);
2795     }
2796     return key;
2797 }
2798
2799 char *ssh_ecdhkex_getpublic(void *key, int *len)
2800 {
2801     struct ec_key *ec = (struct ec_key*)key;
2802     char *point, *p;
2803     int i;
2804     int pointlen;
2805
2806     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2807
2808     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2809         *len = 1 + pointlen * 2;
2810     } else {
2811         *len = pointlen;
2812     }
2813     point = (char*)snewn(*len, char);
2814
2815     p = point;
2816     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2817         *p++ = 0x04;
2818         for (i = pointlen; i--;) {
2819             *p++ = bignum_byte(ec->publicKey.x, i);
2820         }
2821         for (i = pointlen; i--;) {
2822             *p++ = bignum_byte(ec->publicKey.y, i);
2823         }
2824     } else {
2825         for (i = 0; i < pointlen; ++i) {
2826             *p++ = bignum_byte(ec->publicKey.x, i);
2827         }
2828     }
2829
2830     return point;
2831 }
2832
2833 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2834 {
2835     struct ec_key *ec = (struct ec_key*) key;
2836     struct ec_point remote;
2837     Bignum ret;
2838
2839     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2840         remote.curve = ec->publicKey.curve;
2841         remote.infinity = 0;
2842         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2843             return NULL;
2844         }
2845     } else {
2846         /* Point length has to be the same length */
2847         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2848             return NULL;
2849         }
2850
2851         remote.curve = ec->publicKey.curve;
2852         remote.infinity = 0;
2853         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2854         remote.y = NULL;
2855         remote.z = NULL;
2856     }
2857
2858     ret = ecdh_calculate(ec->privateKey, &remote);
2859     if (remote.x) freebn(remote.x);
2860     if (remote.y) freebn(remote.y);
2861     return ret;
2862 }
2863
2864 void ssh_ecdhkex_freekey(void *key)
2865 {
2866     ecdsa_freekey(key);
2867 }
2868
2869 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2870 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2871     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2872     &ssh_sha256, &kex_extra_curve25519,
2873 };
2874
2875 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2876 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2877     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2878     &ssh_sha256, &kex_extra_nistp256,
2879 };
2880
2881 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2882 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2883     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2884     &ssh_sha384, &kex_extra_nistp384,
2885 };
2886
2887 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2888 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2889     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2890     &ssh_sha512, &kex_extra_nistp521,
2891 };
2892
2893 static const struct ssh_kex *const ec_kex_list[] = {
2894     &ssh_ec_kex_curve25519,
2895     &ssh_ec_kex_nistp256,
2896     &ssh_ec_kex_nistp384,
2897     &ssh_ec_kex_nistp521,
2898 };
2899
2900 const struct ssh_kexes ssh_ecdh_kex = {
2901     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2902     ec_kex_list
2903 };
2904
2905 /* ----------------------------------------------------------------------
2906  * Helper functions for finding key algorithms and returning auxiliary
2907  * data.
2908  */
2909
2910 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2911                                         const struct ec_curve **curve)
2912 {
2913     static const struct ssh_signkey *algs_with_oid[] = {
2914         &ssh_ecdsa_nistp256,
2915         &ssh_ecdsa_nistp384,
2916         &ssh_ecdsa_nistp521,
2917     };
2918     int i;
2919
2920     for (i = 0; i < lenof(algs_with_oid); i++) {
2921         const struct ssh_signkey *alg = algs_with_oid[i];
2922         const struct ecsign_extra *extra =
2923             (const struct ecsign_extra *)alg->extra;
2924         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2925             *curve = extra->curve();
2926             return alg;
2927         }
2928     }
2929     return NULL;
2930 }
2931
2932 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2933                                 int *oidlen)
2934 {
2935     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2936     *oidlen = extra->oidlen;
2937     return extra->oid;
2938 }
2939
2940 const int ec_nist_alg_and_curve_by_bits(int bits,
2941                                         const struct ec_curve **curve,
2942                                         const struct ssh_signkey **alg)
2943 {
2944     switch (bits) {
2945       case 256: *alg = &ssh_ecdsa_nistp256; break;
2946       case 384: *alg = &ssh_ecdsa_nistp384; break;
2947       case 521: *alg = &ssh_ecdsa_nistp521; break;
2948       default: return FALSE;
2949     }
2950     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2951     return TRUE;
2952 }
2953
2954 const int ec_ed_alg_and_curve_by_bits(int bits,
2955                                       const struct ec_curve **curve,
2956                                       const struct ssh_signkey **alg)
2957 {
2958     switch (bits) {
2959       case 256: *alg = &ssh_ecdsa_ed25519; break;
2960       default: return FALSE;
2961     }
2962     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2963     return TRUE;
2964 }