]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Fix construction of the output bignum in Curve25519 kex.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  *
10  *       Edwards form curves are supported for DSA. (Ed25519)
11  */
12
13 /*
14  * References:
15  *
16  * Elliptic curves in SSH are specified in RFC 5656:
17  *   http://tools.ietf.org/html/rfc5656
18  *
19  * That specification delegates details of public key formatting and a
20  * lot of underlying mechanism to SEC 1:
21  *   http://www.secg.org/sec1-v2.pdf
22  *
23  * Montgomery maths from:
24  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
25  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
26  *
27  * Edwards DSA:
28  *   http://ed25519.cr.yp.to/ed25519-20110926.pdf
29  */
30
31 #include <stdlib.h>
32 #include <assert.h>
33
34 #include "ssh.h"
35
36 /* ----------------------------------------------------------------------
37  * Elliptic curve definitions
38  */
39
40 static void initialise_wcurve(struct ec_curve *curve, int bits,
41                               const unsigned char *p,
42                               const unsigned char *a, const unsigned char *b,
43                               const unsigned char *n, const unsigned char *Gx,
44                               const unsigned char *Gy)
45 {
46     int length = bits / 8;
47     if (bits % 8) ++length;
48
49     curve->type = EC_WEIERSTRASS;
50
51     curve->fieldBits = bits;
52     curve->p = bignum_from_bytes(p, length);
53
54     /* Curve co-efficients */
55     curve->w.a = bignum_from_bytes(a, length);
56     curve->w.b = bignum_from_bytes(b, length);
57
58     /* Group order and generator */
59     curve->w.n = bignum_from_bytes(n, length);
60     curve->w.G.x = bignum_from_bytes(Gx, length);
61     curve->w.G.y = bignum_from_bytes(Gy, length);
62     curve->w.G.curve = curve;
63     curve->w.G.infinity = 0;
64 }
65
66 static void initialise_mcurve(struct ec_curve *curve, int bits,
67                               const unsigned char *p,
68                               const unsigned char *a, const unsigned char *b,
69                               const unsigned char *Gx)
70 {
71     int length = bits / 8;
72     if (bits % 8) ++length;
73
74     curve->type = EC_MONTGOMERY;
75
76     curve->fieldBits = bits;
77     curve->p = bignum_from_bytes(p, length);
78
79     /* Curve co-efficients */
80     curve->m.a = bignum_from_bytes(a, length);
81     curve->m.b = bignum_from_bytes(b, length);
82
83     /* Generator */
84     curve->m.G.x = bignum_from_bytes(Gx, length);
85     curve->m.G.y = NULL;
86     curve->m.G.z = NULL;
87     curve->m.G.curve = curve;
88     curve->m.G.infinity = 0;
89 }
90
91 static void initialise_ecurve(struct ec_curve *curve, int bits,
92                               const unsigned char *p,
93                               const unsigned char *l, const unsigned char *d,
94                               const unsigned char *Bx, const unsigned char *By)
95 {
96     int length = bits / 8;
97     if (bits % 8) ++length;
98
99     curve->type = EC_EDWARDS;
100
101     curve->fieldBits = bits;
102     curve->p = bignum_from_bytes(p, length);
103
104     /* Curve co-efficients */
105     curve->e.l = bignum_from_bytes(l, length);
106     curve->e.d = bignum_from_bytes(d, length);
107
108     /* Group order and generator */
109     curve->e.B.x = bignum_from_bytes(Bx, length);
110     curve->e.B.y = bignum_from_bytes(By, length);
111     curve->e.B.curve = curve;
112     curve->e.B.infinity = 0;
113 }
114
115 static struct ec_curve *ec_p256(void)
116 {
117     static struct ec_curve curve = { 0 };
118     static unsigned char initialised = 0;
119
120     if (!initialised)
121     {
122         static const unsigned char p[] = {
123             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
124             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
125             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
126             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
127         };
128         static const unsigned char a[] = {
129             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
130             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
131             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
132             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
133         };
134         static const unsigned char b[] = {
135             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
136             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
137             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
138             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
139         };
140         static const unsigned char n[] = {
141             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
142             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
143             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
144             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
145         };
146         static const unsigned char Gx[] = {
147             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
148             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
149             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
150             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
151         };
152         static const unsigned char Gy[] = {
153             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
154             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
155             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
156             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
157         };
158
159         initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy);
160         curve.textname = curve.name = "nistp256";
161
162         /* Now initialised, no need to do it again */
163         initialised = 1;
164     }
165
166     return &curve;
167 }
168
169 static struct ec_curve *ec_p384(void)
170 {
171     static struct ec_curve curve = { 0 };
172     static unsigned char initialised = 0;
173
174     if (!initialised)
175     {
176         static const unsigned char p[] = {
177             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
178             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
179             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
180             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
181             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
182             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
183         };
184         static const unsigned char a[] = {
185             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
186             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
187             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
188             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
189             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
190             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
191         };
192         static const unsigned char b[] = {
193             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
194             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
195             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
196             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
197             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
198             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
199         };
200         static const unsigned char n[] = {
201             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
202             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
203             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
204             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
205             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
206             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
207         };
208         static const unsigned char Gx[] = {
209             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
210             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
211             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
212             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
213             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
214             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
215         };
216         static const unsigned char Gy[] = {
217             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
218             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
219             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
220             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
221             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
222             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
223         };
224
225         initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy);
226         curve.textname = curve.name = "nistp384";
227
228         /* Now initialised, no need to do it again */
229         initialised = 1;
230     }
231
232     return &curve;
233 }
234
235 static struct ec_curve *ec_p521(void)
236 {
237     static struct ec_curve curve = { 0 };
238     static unsigned char initialised = 0;
239
240     if (!initialised)
241     {
242         static const unsigned char p[] = {
243             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
244             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
245             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
246             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
247             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
248             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
249             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff
252         };
253         static const unsigned char a[] = {
254             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
256             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
257             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
258             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
259             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
260             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xfc
263         };
264         static const unsigned char b[] = {
265             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
266             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
267             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
268             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
269             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
270             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
271             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
272             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
273             0x3f, 0x00
274         };
275         static const unsigned char n[] = {
276             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
277             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
278             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
279             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
280             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
281             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
282             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
283             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
284             0x64, 0x09
285         };
286         static const unsigned char Gx[] = {
287             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
288             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
289             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
290             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
291             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
292             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
293             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
294             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
295             0xbd, 0x66
296         };
297         static const unsigned char Gy[] = {
298             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
299             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
300             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
301             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
302             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
303             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
304             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
305             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
306             0x66, 0x50
307         };
308
309         initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy);
310         curve.textname = curve.name = "nistp521";
311
312         /* Now initialised, no need to do it again */
313         initialised = 1;
314     }
315
316     return &curve;
317 }
318
319 static struct ec_curve *ec_curve25519(void)
320 {
321     static struct ec_curve curve = { 0 };
322     static unsigned char initialised = 0;
323
324     if (!initialised)
325     {
326         static const unsigned char p[] = {
327             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
328             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
329             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
330             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
331         };
332         static const unsigned char a[] = {
333             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
334             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
335             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
336             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
337         };
338         static const unsigned char b[] = {
339             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
341             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
342             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
343         };
344         static const unsigned char gx[32] = {
345             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
347             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
348             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
349         };
350
351         initialise_mcurve(&curve, 256, p, a, b, gx);
352         /* This curve doesn't need a name, because it's never used in
353          * any format that embeds the curve name */
354         curve.name = NULL;
355         curve.textname = "Curve25519";
356
357         /* Now initialised, no need to do it again */
358         initialised = 1;
359     }
360
361     return &curve;
362 }
363
364 static struct ec_curve *ec_ed25519(void)
365 {
366     static struct ec_curve curve = { 0 };
367     static unsigned char initialised = 0;
368
369     if (!initialised)
370     {
371         static const unsigned char q[] = {
372             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
373             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
374             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
375             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
376         };
377         static const unsigned char l[32] = {
378             0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
379             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
380             0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6,
381             0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed
382         };
383         static const unsigned char d[32] = {
384             0x52, 0x03, 0x6c, 0xee, 0x2b, 0x6f, 0xfe, 0x73,
385             0x8c, 0xc7, 0x40, 0x79, 0x77, 0x79, 0xe8, 0x98,
386             0x00, 0x70, 0x0a, 0x4d, 0x41, 0x41, 0xd8, 0xab,
387             0x75, 0xeb, 0x4d, 0xca, 0x13, 0x59, 0x78, 0xa3
388         };
389         static const unsigned char Bx[32] = {
390             0x21, 0x69, 0x36, 0xd3, 0xcd, 0x6e, 0x53, 0xfe,
391             0xc0, 0xa4, 0xe2, 0x31, 0xfd, 0xd6, 0xdc, 0x5c,
392             0x69, 0x2c, 0xc7, 0x60, 0x95, 0x25, 0xa7, 0xb2,
393             0xc9, 0x56, 0x2d, 0x60, 0x8f, 0x25, 0xd5, 0x1a
394         };
395         static const unsigned char By[32] = {
396             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
397             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
398             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
399             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
400         };
401
402         /* This curve doesn't need a name, because it's never used in
403          * any format that embeds the curve name */
404         curve.name = NULL;
405
406         initialise_ecurve(&curve, 256, q, l, d, Bx, By);
407         curve.textname = "Ed25519";
408
409         /* Now initialised, no need to do it again */
410         initialised = 1;
411     }
412
413     return &curve;
414 }
415
416 /* Return 1 if a is -3 % p, otherwise return 0
417  * This is used because there are some maths optimisations */
418 static int ec_aminus3(const struct ec_curve *curve)
419 {
420     int ret;
421     Bignum _p;
422
423     if (curve->type != EC_WEIERSTRASS) {
424         return 0;
425     }
426
427     _p = bignum_add_long(curve->w.a, 3);
428
429     ret = !bignum_cmp(curve->p, _p);
430     freebn(_p);
431     return ret;
432 }
433
434 /* ----------------------------------------------------------------------
435  * Elliptic curve field maths
436  */
437
438 static Bignum ecf_add(const Bignum a, const Bignum b,
439                       const struct ec_curve *curve)
440 {
441     Bignum a1, b1, ab, ret;
442
443     a1 = bigmod(a, curve->p);
444     b1 = bigmod(b, curve->p);
445
446     ab = bigadd(a1, b1);
447     freebn(a1);
448     freebn(b1);
449
450     ret = bigmod(ab, curve->p);
451     freebn(ab);
452
453     return ret;
454 }
455
456 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
457 {
458     return modmul(a, a, curve->p);
459 }
460
461 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
462 {
463     Bignum ret, tmp;
464
465     /* Double */
466     tmp = bignum_lshift(a, 1);
467
468     /* Add itself (i.e. treble) */
469     ret = bigadd(tmp, a);
470     freebn(tmp);
471
472     /* Normalise */
473     while (bignum_cmp(ret, curve->p) >= 0)
474     {
475         tmp = bigsub(ret, curve->p);
476         assert(tmp);
477         freebn(ret);
478         ret = tmp;
479     }
480
481     return ret;
482 }
483
484 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
485 {
486     Bignum ret = bignum_lshift(a, 1);
487     if (bignum_cmp(ret, curve->p) >= 0)
488     {
489         Bignum tmp = bigsub(ret, curve->p);
490         assert(tmp);
491         freebn(ret);
492         return tmp;
493     }
494     else
495     {
496         return ret;
497     }
498 }
499
500 /* ----------------------------------------------------------------------
501  * Memory functions
502  */
503
504 void ec_point_free(struct ec_point *point)
505 {
506     if (point == NULL) return;
507     point->curve = 0;
508     if (point->x) freebn(point->x);
509     if (point->y) freebn(point->y);
510     if (point->z) freebn(point->z);
511     point->infinity = 0;
512     sfree(point);
513 }
514
515 static struct ec_point *ec_point_new(const struct ec_curve *curve,
516                                      const Bignum x, const Bignum y, const Bignum z,
517                                      unsigned char infinity)
518 {
519     struct ec_point *point = snewn(1, struct ec_point);
520     point->curve = curve;
521     point->x = x;
522     point->y = y;
523     point->z = z;
524     point->infinity = infinity ? 1 : 0;
525     return point;
526 }
527
528 static struct ec_point *ec_point_copy(const struct ec_point *a)
529 {
530     if (a == NULL) return NULL;
531     return ec_point_new(a->curve,
532                         a->x ? copybn(a->x) : NULL,
533                         a->y ? copybn(a->y) : NULL,
534                         a->z ? copybn(a->z) : NULL,
535                         a->infinity);
536 }
537
538 static int ec_point_verify(const struct ec_point *a)
539 {
540     if (a->infinity) {
541         return 1;
542     } else if (a->curve->type == EC_EDWARDS) {
543         /* Check y^2 - x^2 - 1 - d * x^2 * y^2 == 0 */
544         Bignum y2, x2, tmp, tmp2, tmp3;
545         int ret;
546
547         y2 = ecf_square(a->y, a->curve);
548         x2 = ecf_square(a->x, a->curve);
549         tmp = modmul(a->curve->e.d, x2, a->curve->p);
550         tmp2 = modmul(tmp, y2, a->curve->p);
551         freebn(tmp);
552         tmp = modsub(y2, x2, a->curve->p);
553         freebn(y2);
554         freebn(x2);
555         tmp3 = modsub(tmp, tmp2, a->curve->p);
556         freebn(tmp);
557         freebn(tmp2);
558         ret = !bignum_cmp(tmp3, One);
559         freebn(tmp3);
560         return ret;
561     } else if (a->curve->type == EC_WEIERSTRASS) {
562         /* Verify y^2 = x^3 + ax + b */
563         int ret = 0;
564
565         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
566
567         Bignum Three = bignum_from_long(3);
568
569         lhs = modmul(a->y, a->y, a->curve->p);
570
571         /* This uses montgomery multiplication to optimise */
572         x3 = modpow(a->x, Three, a->curve->p);
573         freebn(Three);
574         ax = modmul(a->curve->w.a, a->x, a->curve->p);
575         x3ax = bigadd(x3, ax);
576         freebn(x3); x3 = NULL;
577         freebn(ax); ax = NULL;
578         x3axm = bigmod(x3ax, a->curve->p);
579         freebn(x3ax); x3ax = NULL;
580         x3axb = bigadd(x3axm, a->curve->w.b);
581         freebn(x3axm); x3axm = NULL;
582         rhs = bigmod(x3axb, a->curve->p);
583         freebn(x3axb);
584
585         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
586         freebn(lhs);
587         freebn(rhs);
588
589         return ret;
590     } else {
591         return 0;
592     }
593 }
594
595 /* ----------------------------------------------------------------------
596  * Elliptic curve point maths
597  */
598
599 /* Returns 1 on success and 0 on memory error */
600 static int ecp_normalise(struct ec_point *a)
601 {
602     if (!a) {
603         /* No point */
604         return 0;
605     }
606
607     if (a->infinity) {
608         /* Point is at infinity - i.e. normalised */
609         return 1;
610     }
611
612     if (a->curve->type == EC_WEIERSTRASS) {
613         /* In Jacobian Coordinates the triple (X, Y, Z) represents
614            the affine point (X / Z^2, Y / Z^3) */
615
616         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
617
618         if (!a->x || !a->y) {
619             /* No point defined */
620             return 0;
621         } else if (!a->z) {
622             /* Already normalised */
623             return 1;
624         }
625
626         Z2 = ecf_square(a->z, a->curve);
627         Z2inv = modinv(Z2, a->curve->p);
628         if (!Z2inv) {
629             freebn(Z2);
630             return 0;
631         }
632         tx = modmul(a->x, Z2inv, a->curve->p);
633         freebn(Z2inv);
634
635         Z3 = modmul(Z2, a->z, a->curve->p);
636         freebn(Z2);
637         Z3inv = modinv(Z3, a->curve->p);
638         freebn(Z3);
639         if (!Z3inv) {
640             freebn(tx);
641             return 0;
642         }
643         ty = modmul(a->y, Z3inv, a->curve->p);
644         freebn(Z3inv);
645
646         freebn(a->x);
647         a->x = tx;
648         freebn(a->y);
649         a->y = ty;
650         freebn(a->z);
651         a->z = NULL;
652         return 1;
653     } else if (a->curve->type == EC_MONTGOMERY) {
654         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
655
656         Bignum tmp, tmp2;
657
658         if (!a->x) {
659             /* No point defined */
660             return 0;
661         } else if (!a->z) {
662             /* Already normalised */
663             return 1;
664         }
665
666         tmp = modinv(a->z, a->curve->p);
667         if (!tmp) {
668             return 0;
669         }
670         tmp2 = modmul(a->x, tmp, a->curve->p);
671         freebn(tmp);
672
673         freebn(a->z);
674         a->z = NULL;
675         freebn(a->x);
676         a->x = tmp2;
677         return 1;
678     } else if (a->curve->type == EC_EDWARDS) {
679         /* Always normalised */
680         return 1;
681     } else {
682         return 0;
683     }
684 }
685
686 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
687 {
688     Bignum S, M, outx, outy, outz;
689
690     if (bignum_cmp(a->y, Zero) == 0)
691     {
692         /* Identity */
693         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
694     }
695
696     /* S = 4*X*Y^2 */
697     {
698         Bignum Y2, XY2, _2XY2;
699
700         Y2 = ecf_square(a->y, a->curve);
701         XY2 = modmul(a->x, Y2, a->curve->p);
702         freebn(Y2);
703
704         _2XY2 = ecf_double(XY2, a->curve);
705         freebn(XY2);
706         S = ecf_double(_2XY2, a->curve);
707         freebn(_2XY2);
708     }
709
710     /* Faster calculation if a = -3 */
711     if (aminus3) {
712         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
713         Bignum Z2, XpZ2, XmZ2, second;
714
715         if (a->z == NULL) {
716             Z2 = copybn(One);
717         } else {
718             Z2 = ecf_square(a->z, a->curve);
719         }
720
721         XpZ2 = ecf_add(a->x, Z2, a->curve);
722         XmZ2 = modsub(a->x, Z2, a->curve->p);
723         freebn(Z2);
724
725         second = modmul(XpZ2, XmZ2, a->curve->p);
726         freebn(XpZ2);
727         freebn(XmZ2);
728
729         M = ecf_treble(second, a->curve);
730         freebn(second);
731     } else {
732         /* M = 3*X^2 + a*Z^4 */
733         Bignum _3X2, X2, aZ4;
734
735         if (a->z == NULL) {
736             aZ4 = copybn(a->curve->w.a);
737         } else {
738             Bignum Z2, Z4;
739
740             Z2 = ecf_square(a->z, a->curve);
741             Z4 = ecf_square(Z2, a->curve);
742             freebn(Z2);
743             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
744             freebn(Z4);
745         }
746
747         X2 = modmul(a->x, a->x, a->curve->p);
748         _3X2 = ecf_treble(X2, a->curve);
749         freebn(X2);
750         M = ecf_add(_3X2, aZ4, a->curve);
751         freebn(_3X2);
752         freebn(aZ4);
753     }
754
755     /* X' = M^2 - 2*S */
756     {
757         Bignum M2, _2S;
758
759         M2 = ecf_square(M, a->curve);
760         _2S = ecf_double(S, a->curve);
761         outx = modsub(M2, _2S, a->curve->p);
762         freebn(M2);
763         freebn(_2S);
764     }
765
766     /* Y' = M*(S - X') - 8*Y^4 */
767     {
768         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
769
770         SX = modsub(S, outx, a->curve->p);
771         freebn(S);
772         MSX = modmul(M, SX, a->curve->p);
773         freebn(SX);
774         freebn(M);
775         Y2 = ecf_square(a->y, a->curve);
776         Y4 = ecf_square(Y2, a->curve);
777         freebn(Y2);
778         Eight = bignum_from_long(8);
779         _8Y4 = modmul(Eight, Y4, a->curve->p);
780         freebn(Eight);
781         freebn(Y4);
782         outy = modsub(MSX, _8Y4, a->curve->p);
783         freebn(MSX);
784         freebn(_8Y4);
785     }
786
787     /* Z' = 2*Y*Z */
788     {
789         Bignum YZ;
790
791         if (a->z == NULL) {
792             YZ = copybn(a->y);
793         } else {
794             YZ = modmul(a->y, a->z, a->curve->p);
795         }
796
797         outz = ecf_double(YZ, a->curve);
798         freebn(YZ);
799     }
800
801     return ec_point_new(a->curve, outx, outy, outz, 0);
802 }
803
804 static struct ec_point *ecp_doublem(const struct ec_point *a)
805 {
806     Bignum z, outx, outz, xpz, xmz;
807
808     z = a->z;
809     if (!z) {
810         z = One;
811     }
812
813     /* 4xz = (x + z)^2 - (x - z)^2 */
814     {
815         Bignum tmp;
816
817         tmp = ecf_add(a->x, z, a->curve);
818         xpz = ecf_square(tmp, a->curve);
819         freebn(tmp);
820
821         tmp = modsub(a->x, z, a->curve->p);
822         xmz = ecf_square(tmp, a->curve);
823         freebn(tmp);
824     }
825
826     /* outx = (x + z)^2 * (x - z)^2 */
827     outx = modmul(xpz, xmz, a->curve->p);
828
829     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
830     {
831         Bignum _4xz, tmp, tmp2, tmp3;
832
833         tmp = bignum_from_long(2);
834         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
835         freebn(tmp);
836
837         _4xz = modsub(xpz, xmz, a->curve->p);
838         freebn(xpz);
839         tmp = modmul(tmp2, _4xz, a->curve->p);
840         freebn(tmp2);
841
842         tmp2 = bignum_from_long(4);
843         tmp3 = modinv(tmp2, a->curve->p);
844         freebn(tmp2);
845         if (!tmp3) {
846             freebn(tmp);
847             freebn(_4xz);
848             freebn(outx);
849             freebn(xmz);
850             return NULL;
851         }
852         tmp2 = modmul(tmp, tmp3, a->curve->p);
853         freebn(tmp);
854         freebn(tmp3);
855
856         tmp = ecf_add(xmz, tmp2, a->curve);
857         freebn(xmz);
858         freebn(tmp2);
859         outz = modmul(_4xz, tmp, a->curve->p);
860         freebn(_4xz);
861         freebn(tmp);
862     }
863
864     return ec_point_new(a->curve, outx, NULL, outz, 0);
865 }
866
867 /* Forward declaration for Edwards curve doubling */
868 static struct ec_point *ecp_add(const struct ec_point *a,
869                                 const struct ec_point *b,
870                                 const int aminus3);
871
872 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
873 {
874     if (a->infinity)
875     {
876         /* Identity */
877         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
878     }
879
880     if (a->curve->type == EC_EDWARDS)
881     {
882         return ecp_add(a, a, aminus3);
883     }
884     else if (a->curve->type == EC_WEIERSTRASS)
885     {
886         return ecp_doublew(a, aminus3);
887     }
888     else
889     {
890         return ecp_doublem(a);
891     }
892 }
893
894 static struct ec_point *ecp_addw(const struct ec_point *a,
895                                  const struct ec_point *b,
896                                  const int aminus3)
897 {
898     Bignum U1, U2, S1, S2, outx, outy, outz;
899
900     /* U1 = X1*Z2^2 */
901     /* S1 = Y1*Z2^3 */
902     if (b->z) {
903         Bignum Z2, Z3;
904
905         Z2 = ecf_square(b->z, a->curve);
906         U1 = modmul(a->x, Z2, a->curve->p);
907         Z3 = modmul(Z2, b->z, a->curve->p);
908         freebn(Z2);
909         S1 = modmul(a->y, Z3, a->curve->p);
910         freebn(Z3);
911     } else {
912         U1 = copybn(a->x);
913         S1 = copybn(a->y);
914     }
915
916     /* U2 = X2*Z1^2 */
917     /* S2 = Y2*Z1^3 */
918     if (a->z) {
919         Bignum Z2, Z3;
920
921         Z2 = ecf_square(a->z, b->curve);
922         U2 = modmul(b->x, Z2, b->curve->p);
923         Z3 = modmul(Z2, a->z, b->curve->p);
924         freebn(Z2);
925         S2 = modmul(b->y, Z3, b->curve->p);
926         freebn(Z3);
927     } else {
928         U2 = copybn(b->x);
929         S2 = copybn(b->y);
930     }
931
932     /* Check if multiplying by self */
933     if (bignum_cmp(U1, U2) == 0)
934     {
935         freebn(U1);
936         freebn(U2);
937         if (bignum_cmp(S1, S2) == 0)
938         {
939             freebn(S1);
940             freebn(S2);
941             return ecp_double(a, aminus3);
942         }
943         else
944         {
945             freebn(S1);
946             freebn(S2);
947             /* Infinity */
948             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
949         }
950     }
951
952     {
953         Bignum H, R, UH2, H3;
954
955         /* H = U2 - U1 */
956         H = modsub(U2, U1, a->curve->p);
957         freebn(U2);
958
959         /* R = S2 - S1 */
960         R = modsub(S2, S1, a->curve->p);
961         freebn(S2);
962
963         /* X3 = R^2 - H^3 - 2*U1*H^2 */
964         {
965             Bignum R2, H2, _2UH2, first;
966
967             H2 = ecf_square(H, a->curve);
968             UH2 = modmul(U1, H2, a->curve->p);
969             freebn(U1);
970             H3 = modmul(H2, H, a->curve->p);
971             freebn(H2);
972             R2 = ecf_square(R, a->curve);
973             _2UH2 = ecf_double(UH2, a->curve);
974             first = modsub(R2, H3, a->curve->p);
975             freebn(R2);
976             outx = modsub(first, _2UH2, a->curve->p);
977             freebn(first);
978             freebn(_2UH2);
979         }
980
981         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
982         {
983             Bignum RUH2mX, UH2mX, SH3;
984
985             UH2mX = modsub(UH2, outx, a->curve->p);
986             freebn(UH2);
987             RUH2mX = modmul(R, UH2mX, a->curve->p);
988             freebn(UH2mX);
989             freebn(R);
990             SH3 = modmul(S1, H3, a->curve->p);
991             freebn(S1);
992             freebn(H3);
993
994             outy = modsub(RUH2mX, SH3, a->curve->p);
995             freebn(RUH2mX);
996             freebn(SH3);
997         }
998
999         /* Z3 = H*Z1*Z2 */
1000         if (a->z && b->z) {
1001             Bignum ZZ;
1002
1003             ZZ = modmul(a->z, b->z, a->curve->p);
1004             outz = modmul(H, ZZ, a->curve->p);
1005             freebn(H);
1006             freebn(ZZ);
1007         } else if (a->z) {
1008             outz = modmul(H, a->z, a->curve->p);
1009             freebn(H);
1010         } else if (b->z) {
1011             outz = modmul(H, b->z, a->curve->p);
1012             freebn(H);
1013         } else {
1014             outz = H;
1015         }
1016     }
1017
1018     return ec_point_new(a->curve, outx, outy, outz, 0);
1019 }
1020
1021 static struct ec_point *ecp_addm(const struct ec_point *a,
1022                                  const struct ec_point *b,
1023                                  const struct ec_point *base)
1024 {
1025     Bignum outx, outz, az, bz;
1026
1027     az = a->z;
1028     if (!az) {
1029         az = One;
1030     }
1031     bz = b->z;
1032     if (!bz) {
1033         bz = One;
1034     }
1035
1036     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1037     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1038     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1039     {
1040         Bignum tmp, tmp2, tmp3, tmp4;
1041
1042         /* (Xa + Za) * (Xb - Zb) */
1043         tmp = ecf_add(a->x, az, a->curve);
1044         tmp2 = modsub(b->x, bz, a->curve->p);
1045         tmp3 = modmul(tmp, tmp2, a->curve->p);
1046         freebn(tmp);
1047         freebn(tmp2);
1048
1049         /* (Xa - Za) * (Xb + Zb) */
1050         tmp = modsub(a->x, az, a->curve->p);
1051         tmp2 = ecf_add(b->x, bz, a->curve);
1052         tmp4 = modmul(tmp, tmp2, a->curve->p);
1053         freebn(tmp);
1054         freebn(tmp2);
1055
1056         tmp = ecf_add(tmp3, tmp4, a->curve);
1057         outx = ecf_square(tmp, a->curve);
1058         freebn(tmp);
1059
1060         tmp = modsub(tmp3, tmp4, a->curve->p);
1061         freebn(tmp3);
1062         freebn(tmp4);
1063         tmp2 = ecf_square(tmp, a->curve);
1064         freebn(tmp);
1065         outz = modmul(base->x, tmp2, a->curve->p);
1066         freebn(tmp2);
1067     }
1068
1069     return ec_point_new(a->curve, outx, NULL, outz, 0);
1070 }
1071
1072 static struct ec_point *ecp_adde(const struct ec_point *a,
1073                                  const struct ec_point *b)
1074 {
1075     Bignum outx, outy, dmul;
1076
1077     /* outx = (a->x * b->y + b->x * a->y) /
1078      *        (1 + a->curve->e.d * a->x * b->x * a->y * b->y) */
1079     {
1080         Bignum tmp, tmp2, tmp3, tmp4;
1081
1082         tmp = modmul(a->x, b->y, a->curve->p);
1083         tmp2 = modmul(b->x, a->y, a->curve->p);
1084         tmp3 = ecf_add(tmp, tmp2, a->curve);
1085
1086         tmp4 = modmul(tmp, tmp2, a->curve->p);
1087         freebn(tmp);
1088         freebn(tmp2);
1089         dmul = modmul(a->curve->e.d, tmp4, a->curve->p);
1090         freebn(tmp4);
1091
1092         tmp = ecf_add(One, dmul, a->curve);
1093         tmp2 = modinv(tmp, a->curve->p);
1094         freebn(tmp);
1095         if (!tmp2)
1096         {
1097             freebn(tmp3);
1098             freebn(dmul);
1099             return NULL;
1100         }
1101
1102         outx = modmul(tmp3, tmp2, a->curve->p);
1103         freebn(tmp3);
1104         freebn(tmp2);
1105     }
1106
1107     /* outy = (a->y * b->y + a->x * b->x) /
1108      *        (1 - a->curve->e.d * a->x * b->x * a->y * b->y) */
1109     {
1110         Bignum tmp, tmp2, tmp3, tmp4;
1111
1112         tmp = modsub(One, dmul, a->curve->p);
1113         freebn(dmul);
1114
1115         tmp2 = modinv(tmp, a->curve->p);
1116         freebn(tmp);
1117         if (!tmp2)
1118         {
1119             freebn(outx);
1120             return NULL;
1121         }
1122
1123         tmp = modmul(a->y, b->y, a->curve->p);
1124         tmp3 = modmul(a->x, b->x, a->curve->p);
1125         tmp4 = ecf_add(tmp, tmp3, a->curve);
1126         freebn(tmp);
1127         freebn(tmp3);
1128
1129         outy = modmul(tmp4, tmp2, a->curve->p);
1130         freebn(tmp4);
1131         freebn(tmp2);
1132     }
1133
1134     return ec_point_new(a->curve, outx, outy, NULL, 0);
1135 }
1136
1137 static struct ec_point *ecp_add(const struct ec_point *a,
1138                                 const struct ec_point *b,
1139                                 const int aminus3)
1140 {
1141     if (a->curve != b->curve) {
1142         return NULL;
1143     }
1144
1145     /* Check if multiplying by infinity */
1146     if (a->infinity) return ec_point_copy(b);
1147     if (b->infinity) return ec_point_copy(a);
1148
1149     if (a->curve->type == EC_EDWARDS)
1150     {
1151         return ecp_adde(a, b);
1152     }
1153
1154     if (a->curve->type == EC_WEIERSTRASS)
1155     {
1156         return ecp_addw(a, b, aminus3);
1157     }
1158
1159     return NULL;
1160 }
1161
1162 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1163 {
1164     struct ec_point *A, *ret;
1165     int bits, i;
1166
1167     A = ec_point_copy(a);
1168     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1169
1170     bits = bignum_bitcount(b);
1171     for (i = 0; i < bits; ++i)
1172     {
1173         if (bignum_bit(b, i))
1174         {
1175             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1176             ec_point_free(ret);
1177             ret = tmp;
1178         }
1179         if (i+1 != bits)
1180         {
1181             struct ec_point *tmp = ecp_double(A, aminus3);
1182             ec_point_free(A);
1183             A = tmp;
1184         }
1185     }
1186
1187     ec_point_free(A);
1188     return ret;
1189 }
1190
1191 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1192 {
1193     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1194
1195     if (!ecp_normalise(ret)) {
1196         ec_point_free(ret);
1197         return NULL;
1198     }
1199
1200     return ret;
1201 }
1202
1203 static struct ec_point *ecp_mule(const struct ec_point *a, const Bignum b)
1204 {
1205     int i;
1206     struct ec_point *ret;
1207
1208     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1209
1210     for (i = bignum_bitcount(b); i >= 0 && ret; --i)
1211     {
1212         {
1213             struct ec_point *tmp = ecp_double(ret, 0);
1214             ec_point_free(ret);
1215             ret = tmp;
1216         }
1217         if (ret && bignum_bit(b, i))
1218         {
1219             struct ec_point *tmp = ecp_add(ret, a, 0);
1220             ec_point_free(ret);
1221             ret = tmp;
1222         }
1223     }
1224
1225     return ret;
1226 }
1227
1228 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1229 {
1230     struct ec_point *P1, *P2;
1231     int bits, i;
1232
1233     /* P1 <- P and P2 <- [2]P */
1234     P2 = ecp_double(p, 0);
1235     P1 = ec_point_copy(p);
1236
1237     /* for i = bits âˆ’ 2 down to 0 */
1238     bits = bignum_bitcount(n);
1239     for (i = bits - 2; i >= 0; --i)
1240     {
1241         if (!bignum_bit(n, i))
1242         {
1243             /* P2 <- P1 + P2 */
1244             struct ec_point *tmp = ecp_addm(P1, P2, p);
1245             ec_point_free(P2);
1246             P2 = tmp;
1247
1248             /* P1 <- [2]P1 */
1249             tmp = ecp_double(P1, 0);
1250             ec_point_free(P1);
1251             P1 = tmp;
1252         }
1253         else
1254         {
1255             /* P1 <- P1 + P2 */
1256             struct ec_point *tmp = ecp_addm(P1, P2, p);
1257             ec_point_free(P1);
1258             P1 = tmp;
1259
1260             /* P2 <- [2]P2 */
1261             tmp = ecp_double(P2, 0);
1262             ec_point_free(P2);
1263             P2 = tmp;
1264         }
1265     }
1266
1267     ec_point_free(P2);
1268
1269     if (!ecp_normalise(P1)) {
1270         ec_point_free(P1);
1271         return NULL;
1272     }
1273
1274     return P1;
1275 }
1276
1277 /* Not static because it is used by sshecdsag.c to generate a new key */
1278 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1279 {
1280     if (a->curve->type == EC_WEIERSTRASS) {
1281         return ecp_mulw(a, b);
1282     } else if (a->curve->type == EC_EDWARDS) {
1283         return ecp_mule(a, b);
1284     } else {
1285         return ecp_mulm(a, b);
1286     }
1287 }
1288
1289 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1290                                    const struct ec_point *point)
1291 {
1292     struct ec_point *aG, *bP, *ret;
1293     int aminus3;
1294
1295     if (point->curve->type != EC_WEIERSTRASS) {
1296         return NULL;
1297     }
1298
1299     aminus3 = ec_aminus3(point->curve);
1300
1301     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1302     if (!aG) return NULL;
1303     bP = ecp_mul_(point, b, aminus3);
1304     if (!bP) {
1305         ec_point_free(aG);
1306         return NULL;
1307     }
1308
1309     ret = ecp_add(aG, bP, aminus3);
1310
1311     ec_point_free(aG);
1312     ec_point_free(bP);
1313
1314     if (!ecp_normalise(ret)) {
1315         ec_point_free(ret);
1316         return NULL;
1317     }
1318
1319     return ret;
1320 }
1321 static Bignum *ecp_edx(const struct ec_curve *curve, const Bignum y)
1322 {
1323     /* Get the x value on the given Edwards curve for a given y */
1324     Bignum x, xx;
1325
1326     /* xx = (y^2 - 1) / (d * y^2 + 1) */
1327     {
1328         Bignum tmp, tmp2, tmp3;
1329
1330         tmp = ecf_square(y, curve);
1331         tmp2 = modmul(curve->e.d, tmp, curve->p);
1332         tmp3 = ecf_add(tmp2, One, curve);
1333         freebn(tmp2);
1334         tmp2 = modinv(tmp3, curve->p);
1335         freebn(tmp3);
1336         if (!tmp2) {
1337             freebn(tmp);
1338             return NULL;
1339         }
1340
1341         tmp3 = modsub(tmp, One, curve->p);
1342         freebn(tmp);
1343         xx = modmul(tmp3, tmp2, curve->p);
1344         freebn(tmp3);
1345         freebn(tmp2);
1346     }
1347
1348     /* x = xx^((p + 3) / 8) */
1349     {
1350         Bignum tmp, tmp2;
1351
1352         tmp = bignum_add_long(curve->p, 3);
1353         tmp2 = bignum_rshift(tmp, 3);
1354         freebn(tmp);
1355         x = modpow(xx, tmp2, curve->p);
1356         freebn(tmp2);
1357     }
1358
1359     /* if x^2 - xx != 0 then x = x*(2^((p - 1) / 4)) */
1360     {
1361         Bignum tmp, tmp2;
1362
1363         tmp = ecf_square(x, curve);
1364         tmp2 = modsub(tmp, xx, curve->p);
1365         freebn(tmp);
1366         freebn(xx);
1367         if (bignum_cmp(tmp2, Zero)) {
1368             Bignum tmp3;
1369
1370             freebn(tmp2);
1371
1372             tmp = modsub(curve->p, One, curve->p);
1373             tmp2 = bignum_rshift(tmp, 2);
1374             freebn(tmp);
1375             tmp = bignum_from_long(2);
1376             tmp3 = modpow(tmp, tmp2, curve->p);
1377             freebn(tmp);
1378             freebn(tmp2);
1379
1380             tmp = modmul(x, tmp3, curve->p);
1381             freebn(x);
1382             freebn(tmp3);
1383             x = tmp;
1384         } else {
1385             freebn(tmp2);
1386         }
1387     }
1388
1389     /* if x % 2 != 0 then x = p - x */
1390     if (bignum_bit(x, 0)) {
1391         Bignum tmp = modsub(curve->p, x, curve->p);
1392         freebn(x);
1393         x = tmp;
1394     }
1395
1396     return x;
1397 }
1398
1399 /* ----------------------------------------------------------------------
1400  * Public point from private
1401  */
1402
1403 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1404 {
1405     if (curve->type == EC_WEIERSTRASS) {
1406         return ecp_mul(&curve->w.G, privateKey);
1407     } else if (curve->type == EC_EDWARDS) {
1408         /* hash = H(sk) (where hash creates 2 * fieldBits)
1409          * b = fieldBits
1410          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
1411          * publicKey = aB */
1412         struct ec_point *ret;
1413         unsigned char hash[512/8];
1414         Bignum a;
1415         int i, keylen;
1416         SHA512_State s;
1417         SHA512_Init(&s);
1418
1419         keylen = curve->fieldBits / 8;
1420         for (i = 0; i < keylen; ++i) {
1421             unsigned char b = bignum_byte(privateKey, i);
1422             SHA512_Bytes(&s, &b, 1);
1423         }
1424         SHA512_Final(&s, hash);
1425
1426         /* The second part is simply turning the hash into a Bignum,
1427          * however the 2^(b-2) bit *must* be set, and the bottom 3
1428          * bits *must* not be */
1429         hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
1430         hash[31] &= 0x7f; /* Unset above (b-2) */
1431         hash[31] |= 0x40; /* Set 2^(b-2) */
1432         /* Chop off the top part and convert to int */
1433         a = bignum_from_bytes_le(hash, 32);
1434
1435         ret = ecp_mul(&curve->e.B, a);
1436         freebn(a);
1437         return ret;
1438     } else {
1439         return NULL;
1440     }
1441 }
1442
1443 /* ----------------------------------------------------------------------
1444  * Basic sign and verify routines
1445  */
1446
1447 static int _ecdsa_verify(const struct ec_point *publicKey,
1448                          const unsigned char *data, const int dataLen,
1449                          const Bignum r, const Bignum s)
1450 {
1451     int z_bits, n_bits;
1452     Bignum z;
1453     int valid = 0;
1454
1455     if (publicKey->curve->type != EC_WEIERSTRASS) {
1456         return 0;
1457     }
1458
1459     /* Sanity checks */
1460     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1461         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1462     {
1463         return 0;
1464     }
1465
1466     /* z = left most bitlen(curve->n) of data */
1467     z = bignum_from_bytes(data, dataLen);
1468     n_bits = bignum_bitcount(publicKey->curve->w.n);
1469     z_bits = bignum_bitcount(z);
1470     if (z_bits > n_bits)
1471     {
1472         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1473         freebn(z);
1474         z = tmp;
1475     }
1476
1477     /* Ensure z in range of n */
1478     {
1479         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1480         freebn(z);
1481         z = tmp;
1482     }
1483
1484     /* Calculate signature */
1485     {
1486         Bignum w, x, u1, u2;
1487         struct ec_point *tmp;
1488
1489         w = modinv(s, publicKey->curve->w.n);
1490         if (!w) {
1491             freebn(z);
1492             return 0;
1493         }
1494         u1 = modmul(z, w, publicKey->curve->w.n);
1495         u2 = modmul(r, w, publicKey->curve->w.n);
1496         freebn(w);
1497
1498         tmp = ecp_summul(u1, u2, publicKey);
1499         freebn(u1);
1500         freebn(u2);
1501         if (!tmp) {
1502             freebn(z);
1503             return 0;
1504         }
1505
1506         x = bigmod(tmp->x, publicKey->curve->w.n);
1507         ec_point_free(tmp);
1508
1509         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1510         freebn(x);
1511     }
1512
1513     freebn(z);
1514
1515     return valid;
1516 }
1517
1518 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1519                         const unsigned char *data, const int dataLen,
1520                         Bignum *r, Bignum *s)
1521 {
1522     unsigned char digest[20];
1523     int z_bits, n_bits;
1524     Bignum z, k;
1525     struct ec_point *kG;
1526
1527     *r = NULL;
1528     *s = NULL;
1529
1530     if (curve->type != EC_WEIERSTRASS) {
1531         return;
1532     }
1533
1534     /* z = left most bitlen(curve->n) of data */
1535     z = bignum_from_bytes(data, dataLen);
1536     n_bits = bignum_bitcount(curve->w.n);
1537     z_bits = bignum_bitcount(z);
1538     if (z_bits > n_bits)
1539     {
1540         Bignum tmp;
1541         tmp = bignum_rshift(z, z_bits - n_bits);
1542         freebn(z);
1543         z = tmp;
1544     }
1545
1546     /* Generate k between 1 and curve->n, using the same deterministic
1547      * k generation system we use for conventional DSA. */
1548     SHA_Simple(data, dataLen, digest);
1549     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1550                   digest, sizeof(digest));
1551
1552     kG = ecp_mul(&curve->w.G, k);
1553     if (!kG) {
1554         freebn(z);
1555         freebn(k);
1556         return;
1557     }
1558
1559     /* r = kG.x mod n */
1560     *r = bigmod(kG->x, curve->w.n);
1561     ec_point_free(kG);
1562
1563     /* s = (z + r * priv)/k mod n */
1564     {
1565         Bignum rPriv, zMod, first, firstMod, kInv;
1566         rPriv = modmul(*r, privateKey, curve->w.n);
1567         zMod = bigmod(z, curve->w.n);
1568         freebn(z);
1569         first = bigadd(rPriv, zMod);
1570         freebn(rPriv);
1571         freebn(zMod);
1572         firstMod = bigmod(first, curve->w.n);
1573         freebn(first);
1574         kInv = modinv(k, curve->w.n);
1575         freebn(k);
1576         if (!kInv) {
1577             freebn(firstMod);
1578             freebn(*r);
1579             return;
1580         }
1581         *s = modmul(firstMod, kInv, curve->w.n);
1582         freebn(firstMod);
1583         freebn(kInv);
1584     }
1585 }
1586
1587 /* ----------------------------------------------------------------------
1588  * Misc functions
1589  */
1590
1591 static void getstring(const char **data, int *datalen,
1592                       const char **p, int *length)
1593 {
1594     *p = NULL;
1595     if (*datalen < 4)
1596         return;
1597     *length = toint(GET_32BIT(*data));
1598     if (*length < 0)
1599         return;
1600     *datalen -= 4;
1601     *data += 4;
1602     if (*datalen < *length)
1603         return;
1604     *p = *data;
1605     *data += *length;
1606     *datalen -= *length;
1607 }
1608
1609 static Bignum getmp(const char **data, int *datalen)
1610 {
1611     const char *p;
1612     int length;
1613
1614     getstring(data, datalen, &p, &length);
1615     if (!p)
1616         return NULL;
1617     if (p[0] & 0x80)
1618         return NULL;                   /* negative mp */
1619     return bignum_from_bytes((unsigned char *)p, length);
1620 }
1621
1622 static Bignum getmp_le(const char **data, int *datalen)
1623 {
1624     const char *p;
1625     int length;
1626
1627     getstring(data, datalen, &p, &length);
1628     if (!p)
1629         return NULL;
1630     return bignum_from_bytes_le((const unsigned char *)p, length);
1631 }
1632
1633 static int decodepoint_ed(const char *p, int length, struct ec_point *point)
1634 {
1635     /* Got some conversion to do, first read in the y co-ord */
1636     int negative;
1637
1638     point->y = bignum_from_bytes_le((const unsigned char*)p, length);
1639     if ((unsigned)bignum_bitcount(point->y) > point->curve->fieldBits) {
1640         freebn(point->y);
1641         point->y = NULL;
1642         return 0;
1643     }
1644     /* Read x bit and then reset it */
1645     negative = bignum_bit(point->y, point->curve->fieldBits - 1);
1646     bignum_set_bit(point->y, point->curve->fieldBits - 1, 0);
1647
1648     /* Get the x from the y */
1649     point->x = ecp_edx(point->curve, point->y);
1650     if (!point->x) {
1651         freebn(point->y);
1652         point->y = NULL;
1653         return 0;
1654     }
1655     if (negative) {
1656         Bignum tmp = modsub(point->curve->p, point->x, point->curve->p);
1657         freebn(point->x);
1658         point->x = tmp;
1659     }
1660
1661     /* Verify the point is on the curve */
1662     if (!ec_point_verify(point)) {
1663         freebn(point->x);
1664         point->x = NULL;
1665         freebn(point->y);
1666         point->y = NULL;
1667         return 0;
1668     }
1669
1670     return 1;
1671 }
1672
1673 static int decodepoint(const char *p, int length, struct ec_point *point)
1674 {
1675     if (point->curve->type == EC_EDWARDS) {
1676         return decodepoint_ed(p, length, point);
1677     }
1678
1679     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1680         return 0;
1681     /* Skip compression flag */
1682     ++p;
1683     --length;
1684     /* The two values must be equal length */
1685     if (length % 2 != 0) {
1686         point->x = NULL;
1687         point->y = NULL;
1688         point->z = NULL;
1689         return 0;
1690     }
1691     length = length / 2;
1692     point->x = bignum_from_bytes((const unsigned char *)p, length);
1693     p += length;
1694     point->y = bignum_from_bytes((const unsigned char *)p, length);
1695     point->z = NULL;
1696
1697     /* Verify the point is on the curve */
1698     if (!ec_point_verify(point)) {
1699         freebn(point->x);
1700         point->x = NULL;
1701         freebn(point->y);
1702         point->y = NULL;
1703         return 0;
1704     }
1705
1706     return 1;
1707 }
1708
1709 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1710 {
1711     const char *p;
1712     int length;
1713
1714     getstring(data, datalen, &p, &length);
1715     if (!p) return 0;
1716     return decodepoint(p, length, point);
1717 }
1718
1719 /* ----------------------------------------------------------------------
1720  * Exposed ECDSA interface
1721  */
1722
1723 struct ecsign_extra {
1724     struct ec_curve *(*curve)(void);
1725     const struct ssh_hash *hash;
1726
1727     /* These fields are used by the OpenSSH PEM format importer/exporter */
1728     const unsigned char *oid;
1729     int oidlen;
1730 };
1731
1732 static void ecdsa_freekey(void *key)
1733 {
1734     struct ec_key *ec = (struct ec_key *) key;
1735     if (!ec) return;
1736
1737     if (ec->publicKey.x)
1738         freebn(ec->publicKey.x);
1739     if (ec->publicKey.y)
1740         freebn(ec->publicKey.y);
1741     if (ec->publicKey.z)
1742         freebn(ec->publicKey.z);
1743     if (ec->privateKey)
1744         freebn(ec->privateKey);
1745     sfree(ec);
1746 }
1747
1748 static void *ecdsa_newkey(const struct ssh_signkey *self,
1749                           const char *data, int len)
1750 {
1751     const struct ecsign_extra *extra =
1752         (const struct ecsign_extra *)self->extra;
1753     const char *p;
1754     int slen;
1755     struct ec_key *ec;
1756     struct ec_curve *curve;
1757
1758     getstring(&data, &len, &p, &slen);
1759
1760     if (!p) {
1761         return NULL;
1762     }
1763     curve = extra->curve();
1764     assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
1765
1766     /* Curve name is duplicated for Weierstrass form */
1767     if (curve->type == EC_WEIERSTRASS) {
1768         getstring(&data, &len, &p, &slen);
1769         if (!match_ssh_id(slen, p, curve->name)) return NULL;
1770     }
1771
1772     ec = snew(struct ec_key);
1773
1774     ec->signalg = self;
1775     ec->publicKey.curve = curve;
1776     ec->publicKey.infinity = 0;
1777     ec->publicKey.x = NULL;
1778     ec->publicKey.y = NULL;
1779     ec->publicKey.z = NULL;
1780     if (!getmppoint(&data, &len, &ec->publicKey)) {
1781         ecdsa_freekey(ec);
1782         return NULL;
1783     }
1784     ec->privateKey = NULL;
1785
1786     if (!ec->publicKey.x || !ec->publicKey.y ||
1787         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
1788         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
1789     {
1790         ecdsa_freekey(ec);
1791         ec = NULL;
1792     }
1793
1794     return ec;
1795 }
1796
1797 static char *ecdsa_fmtkey(void *key)
1798 {
1799     struct ec_key *ec = (struct ec_key *) key;
1800     char *p;
1801     int len, i, pos, nibbles;
1802     static const char hex[] = "0123456789abcdef";
1803     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
1804         return NULL;
1805
1806     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
1807     if (ec->publicKey.curve->name)
1808         len += strlen(ec->publicKey.curve->name); /* Curve name */
1809     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
1810     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
1811     p = snewn(len, char);
1812
1813     pos = 0;
1814     if (ec->publicKey.curve->name)
1815         pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
1816     pos += sprintf(p + pos, "0x");
1817     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
1818     if (nibbles < 1)
1819         nibbles = 1;
1820     for (i = nibbles; i--;) {
1821         p[pos++] =
1822             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
1823     }
1824     pos += sprintf(p + pos, ",0x");
1825     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
1826     if (nibbles < 1)
1827         nibbles = 1;
1828     for (i = nibbles; i--;) {
1829         p[pos++] =
1830             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
1831     }
1832     p[pos] = '\0';
1833     return p;
1834 }
1835
1836 static unsigned char *ecdsa_public_blob(void *key, int *len)
1837 {
1838     struct ec_key *ec = (struct ec_key *) key;
1839     int pointlen, bloblen, fullnamelen, namelen;
1840     int i;
1841     unsigned char *blob, *p;
1842
1843     fullnamelen = strlen(ec->signalg->name);
1844
1845     if (ec->publicKey.curve->type == EC_EDWARDS) {
1846         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
1847
1848         pointlen = ec->publicKey.curve->fieldBits / 8;
1849
1850         /* Can't handle this in our loop */
1851         if (pointlen < 2) return NULL;
1852
1853         bloblen = 4 + fullnamelen + 4 + pointlen;
1854         blob = snewn(bloblen, unsigned char);
1855
1856         p = blob;
1857         PUT_32BIT(p, fullnamelen);
1858         p += 4;
1859         memcpy(p, ec->signalg->name, fullnamelen);
1860         p += fullnamelen;
1861         PUT_32BIT(p, pointlen);
1862         p += 4;
1863
1864         /* Unset last bit of y and set first bit of x in its place */
1865         for (i = 0; i < pointlen - 1; ++i) {
1866             *p++ = bignum_byte(ec->publicKey.y, i);
1867         }
1868         /* Unset last bit of y and set first bit of x in its place */
1869         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
1870         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
1871     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
1872         assert(ec->publicKey.curve->name);
1873         namelen = strlen(ec->publicKey.curve->name);
1874
1875         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
1876
1877         /*
1878          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
1879          */
1880         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
1881         blob = snewn(bloblen, unsigned char);
1882
1883         p = blob;
1884         PUT_32BIT(p, fullnamelen);
1885         p += 4;
1886         memcpy(p, ec->signalg->name, fullnamelen);
1887         p += fullnamelen;
1888         PUT_32BIT(p, namelen);
1889         p += 4;
1890         memcpy(p, ec->publicKey.curve->name, namelen);
1891         p += namelen;
1892         PUT_32BIT(p, (2 * pointlen) + 1);
1893         p += 4;
1894         *p++ = 0x04;
1895         for (i = pointlen; i--;) {
1896             *p++ = bignum_byte(ec->publicKey.x, i);
1897         }
1898         for (i = pointlen; i--;) {
1899             *p++ = bignum_byte(ec->publicKey.y, i);
1900         }
1901     } else {
1902         return NULL;
1903     }
1904
1905     assert(p == blob + bloblen);
1906     *len = bloblen;
1907
1908     return blob;
1909 }
1910
1911 static unsigned char *ecdsa_private_blob(void *key, int *len)
1912 {
1913     struct ec_key *ec = (struct ec_key *) key;
1914     int keylen, bloblen;
1915     int i;
1916     unsigned char *blob, *p;
1917
1918     if (!ec->privateKey) return NULL;
1919
1920     if (ec->publicKey.curve->type == EC_EDWARDS) {
1921         /* Unsigned */
1922         keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
1923     } else {
1924         /* Signed */
1925         keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
1926     }
1927
1928     /*
1929      * mpint privateKey. Total 4 + keylen.
1930      */
1931     bloblen = 4 + keylen;
1932     blob = snewn(bloblen, unsigned char);
1933
1934     p = blob;
1935     PUT_32BIT(p, keylen);
1936     p += 4;
1937     if (ec->publicKey.curve->type == EC_EDWARDS) {
1938         /* Little endian */
1939         for (i = 0; i < keylen; ++i)
1940             *p++ = bignum_byte(ec->privateKey, i);
1941     } else {
1942         for (i = keylen; i--;)
1943             *p++ = bignum_byte(ec->privateKey, i);
1944     }
1945
1946     assert(p == blob + bloblen);
1947     *len = bloblen;
1948     return blob;
1949 }
1950
1951 static void *ecdsa_createkey(const struct ssh_signkey *self,
1952                              const unsigned char *pub_blob, int pub_len,
1953                              const unsigned char *priv_blob, int priv_len)
1954 {
1955     struct ec_key *ec;
1956     struct ec_point *publicKey;
1957     const char *pb = (const char *) priv_blob;
1958
1959     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) pub_blob, pub_len);
1960     if (!ec) {
1961         return NULL;
1962     }
1963
1964     if (ec->publicKey.curve->type != EC_WEIERSTRASS
1965         && ec->publicKey.curve->type != EC_EDWARDS) {
1966         ecdsa_freekey(ec);
1967         return NULL;
1968     }
1969
1970     if (ec->publicKey.curve->type == EC_EDWARDS) {
1971         ec->privateKey = getmp_le(&pb, &priv_len);
1972     } else {
1973         ec->privateKey = getmp(&pb, &priv_len);
1974     }
1975     if (!ec->privateKey) {
1976         ecdsa_freekey(ec);
1977         return NULL;
1978     }
1979
1980     /* Check that private key generates public key */
1981     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
1982
1983     if (!publicKey ||
1984         bignum_cmp(publicKey->x, ec->publicKey.x) ||
1985         bignum_cmp(publicKey->y, ec->publicKey.y))
1986     {
1987         ecdsa_freekey(ec);
1988         ec = NULL;
1989     }
1990     ec_point_free(publicKey);
1991
1992     return ec;
1993 }
1994
1995 static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
1996                                        const unsigned char **blob, int *len)
1997 {
1998     struct ec_key *ec;
1999     struct ec_point *publicKey;
2000     const char *p, *q;
2001     int plen, qlen;
2002
2003     getstring((const char**)blob, len, &p, &plen);
2004     if (!p)
2005     {
2006         return NULL;
2007     }
2008
2009     ec = snew(struct ec_key);
2010
2011     ec->signalg = self;
2012     ec->publicKey.curve = ec_ed25519();
2013     ec->publicKey.infinity = 0;
2014     ec->privateKey = NULL;
2015     ec->publicKey.x = NULL;
2016     ec->publicKey.z = NULL;
2017     ec->publicKey.y = NULL;
2018
2019     if (!decodepoint_ed(p, plen, &ec->publicKey))
2020     {
2021         ecdsa_freekey(ec);
2022         return NULL;
2023     }
2024
2025     getstring((const char**)blob, len, &q, &qlen);
2026     if (!q)
2027         return NULL;
2028     if (qlen != 64)
2029         return NULL;
2030
2031     ec->privateKey = bignum_from_bytes_le((const unsigned char *)q, 32);
2032
2033     /* Check that private key generates public key */
2034     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2035
2036     if (!publicKey ||
2037         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2038         bignum_cmp(publicKey->y, ec->publicKey.y))
2039     {
2040         ecdsa_freekey(ec);
2041         ec = NULL;
2042     }
2043     ec_point_free(publicKey);
2044
2045     /* The OpenSSH format for ed25519 private keys also for some
2046      * reason encodes an extra copy of the public key in the second
2047      * half of the secret-key string. Check that that's present and
2048      * correct as well, otherwise the key we think we've imported
2049      * won't behave identically to the way OpenSSH would have treated
2050      * it. */
2051     if (plen != 32 || 0 != memcmp(q + 32, p, 32)) {
2052         ecdsa_freekey(ec);
2053         return NULL;
2054     }
2055
2056     return ec;
2057 }
2058
2059 static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
2060 {
2061     struct ec_key *ec = (struct ec_key *) key;
2062
2063     int pointlen;
2064     int keylen;
2065     int bloblen;
2066     int i;
2067
2068     if (ec->publicKey.curve->type != EC_EDWARDS) {
2069         return 0;
2070     }
2071
2072     pointlen = (bignum_bitcount(ec->publicKey.y) + 7) / 8;
2073     keylen = (bignum_bitcount(ec->privateKey) + 7) / 8;
2074     bloblen = 4 + pointlen + 4 + keylen + pointlen;
2075
2076     if (bloblen > len)
2077         return bloblen;
2078
2079     /* Encode the public point */
2080     PUT_32BIT(blob, pointlen);
2081     blob += 4;
2082
2083     for (i = 0; i < pointlen - 1; ++i) {
2084          *blob++ = bignum_byte(ec->publicKey.y, i);
2085     }
2086     /* Unset last bit of y and set first bit of x in its place */
2087     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2088     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2089
2090     PUT_32BIT(blob, keylen + pointlen);
2091     blob += 4;
2092     for (i = 0; i < keylen; ++i) {
2093          *blob++ = bignum_byte(ec->privateKey, i);
2094     }
2095     /* Now encode an extra copy of the public point as the second half
2096      * of the private key string, as the OpenSSH format for some
2097      * reason requires */
2098     for (i = 0; i < pointlen - 1; ++i) {
2099          *blob++ = bignum_byte(ec->publicKey.y, i);
2100     }
2101     /* Unset last bit of y and set first bit of x in its place */
2102     *blob = bignum_byte(ec->publicKey.y, i) & 0x7f;
2103     *blob++ |= bignum_bit(ec->publicKey.x, 0) << 7;
2104
2105     return bloblen;
2106 }
2107
2108 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
2109                                      const unsigned char **blob, int *len)
2110 {
2111     const struct ecsign_extra *extra =
2112         (const struct ecsign_extra *)self->extra;
2113     const char **b = (const char **) blob;
2114     const char *p;
2115     int slen;
2116     struct ec_key *ec;
2117     struct ec_curve *curve;
2118     struct ec_point *publicKey;
2119
2120     getstring(b, len, &p, &slen);
2121
2122     if (!p) {
2123         return NULL;
2124     }
2125     curve = extra->curve();
2126     assert(curve->type == EC_WEIERSTRASS);
2127
2128     ec = snew(struct ec_key);
2129
2130     ec->signalg = self;
2131     ec->publicKey.curve = curve;
2132     ec->publicKey.infinity = 0;
2133     ec->publicKey.x = NULL;
2134     ec->publicKey.y = NULL;
2135     ec->publicKey.z = NULL;
2136     if (!getmppoint(b, len, &ec->publicKey)) {
2137         ecdsa_freekey(ec);
2138         return NULL;
2139     }
2140     ec->privateKey = NULL;
2141
2142     if (!ec->publicKey.x || !ec->publicKey.y ||
2143         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2144         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2145     {
2146         ecdsa_freekey(ec);
2147         return NULL;
2148     }
2149
2150     ec->privateKey = getmp(b, len);
2151     if (ec->privateKey == NULL)
2152     {
2153         ecdsa_freekey(ec);
2154         return NULL;
2155     }
2156
2157     /* Now check that the private key makes the public key */
2158     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2159     if (!publicKey)
2160     {
2161         ecdsa_freekey(ec);
2162         return NULL;
2163     }
2164
2165     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2166         bignum_cmp(ec->publicKey.y, publicKey->y))
2167     {
2168         /* Private key doesn't make the public key on the given curve */
2169         ecdsa_freekey(ec);
2170         ec_point_free(publicKey);
2171         return NULL;
2172     }
2173
2174     ec_point_free(publicKey);
2175
2176     return ec;
2177 }
2178
2179 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2180 {
2181     struct ec_key *ec = (struct ec_key *) key;
2182
2183     int pointlen;
2184     int namelen;
2185     int bloblen;
2186     int i;
2187
2188     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2189         return 0;
2190     }
2191
2192     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2193     namelen = strlen(ec->publicKey.curve->name);
2194     bloblen =
2195         4 + namelen /* <LEN> nistpXXX */
2196         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2197         + ssh2_bignum_length(ec->privateKey);
2198
2199     if (bloblen > len)
2200         return bloblen;
2201
2202     bloblen = 0;
2203
2204     PUT_32BIT(blob+bloblen, namelen);
2205     bloblen += 4;
2206     memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
2207     bloblen += namelen;
2208
2209     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2210     bloblen += 4;
2211     blob[bloblen++] = 0x04;
2212     for (i = pointlen; i--; )
2213         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2214     for (i = pointlen; i--; )
2215         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2216
2217     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2218     PUT_32BIT(blob+bloblen, pointlen);
2219     bloblen += 4;
2220     for (i = pointlen; i--; )
2221         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2222
2223     return bloblen;
2224 }
2225
2226 static int ecdsa_pubkey_bits(const struct ssh_signkey *self,
2227                              const void *blob, int len)
2228 {
2229     struct ec_key *ec;
2230     int ret;
2231
2232     ec = (struct ec_key*)ecdsa_newkey(self, (const char *) blob, len);
2233     if (!ec)
2234         return -1;
2235     ret = ec->publicKey.curve->fieldBits;
2236     ecdsa_freekey(ec);
2237
2238     return ret;
2239 }
2240
2241 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2242                            const char *data, int datalen)
2243 {
2244     struct ec_key *ec = (struct ec_key *) key;
2245     const struct ecsign_extra *extra =
2246         (const struct ecsign_extra *)ec->signalg->extra;
2247     const char *p;
2248     int slen;
2249     int digestLen;
2250     int ret;
2251
2252     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2253         return 0;
2254
2255     /* Check the signature starts with the algorithm name */
2256     getstring(&sig, &siglen, &p, &slen);
2257     if (!p) {
2258         return 0;
2259     }
2260     if (!match_ssh_id(slen, p, ec->signalg->name)) {
2261         return 0;
2262     }
2263
2264     getstring(&sig, &siglen, &p, &slen);
2265     if (ec->publicKey.curve->type == EC_EDWARDS) {
2266         struct ec_point *r;
2267         Bignum s, h;
2268
2269         /* Check that the signature is two times the length of a point */
2270         if (slen != (ec->publicKey.curve->fieldBits / 8) * 2) {
2271             return 0;
2272         }
2273
2274         /* Check it's the 256 bit field so that SHA512 is the correct hash */
2275         if (ec->publicKey.curve->fieldBits != 256) {
2276             return 0;
2277         }
2278
2279         /* Get the signature */
2280         r = ec_point_new(ec->publicKey.curve, NULL, NULL, NULL, 0);
2281         if (!r) {
2282             return 0;
2283         }
2284         if (!decodepoint(p, ec->publicKey.curve->fieldBits / 8, r)) {
2285             ec_point_free(r);
2286             return 0;
2287         }
2288         s = bignum_from_bytes_le((unsigned char*)p + (ec->publicKey.curve->fieldBits / 8),
2289                                  ec->publicKey.curve->fieldBits / 8);
2290
2291         /* Get the hash of the encoded value of R + encoded value of pk + message */
2292         {
2293             int i, pointlen;
2294             unsigned char b;
2295             unsigned char digest[512 / 8];
2296             SHA512_State hs;
2297             SHA512_Init(&hs);
2298
2299             /* Add encoded r (no need to encode it again, it was in the signature) */
2300             SHA512_Bytes(&hs, p, ec->publicKey.curve->fieldBits / 8);
2301
2302             /* Encode pk and add it */
2303             pointlen = ec->publicKey.curve->fieldBits / 8;
2304             for (i = 0; i < pointlen - 1; ++i) {
2305                 b = bignum_byte(ec->publicKey.y, i);
2306                 SHA512_Bytes(&hs, &b, 1);
2307             }
2308             /* Unset last bit of y and set first bit of x in its place */
2309             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2310             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2311             SHA512_Bytes(&hs, &b, 1);
2312
2313             /* Add the message itself */
2314             SHA512_Bytes(&hs, data, datalen);
2315
2316             /* Get the hash */
2317             SHA512_Final(&hs, digest);
2318
2319             /* Convert to Bignum */
2320             h = bignum_from_bytes_le(digest, sizeof(digest));
2321         }
2322
2323         /* Verify sB == r + h*publicKey */
2324         {
2325             struct ec_point *lhs, *rhs, *tmp;
2326
2327             /* lhs = sB */
2328             lhs = ecp_mul(&ec->publicKey.curve->e.B, s);
2329             freebn(s);
2330             if (!lhs) {
2331                 ec_point_free(r);
2332                 freebn(h);
2333                 return 0;
2334             }
2335
2336             /* rhs = r + h*publicKey */
2337             tmp = ecp_mul(&ec->publicKey, h);
2338             freebn(h);
2339             if (!tmp) {
2340                 ec_point_free(lhs);
2341                 ec_point_free(r);
2342                 return 0;
2343             }
2344             rhs = ecp_add(r, tmp, 0);
2345             ec_point_free(r);
2346             ec_point_free(tmp);
2347             if (!rhs) {
2348                 ec_point_free(lhs);
2349                 return 0;
2350             }
2351
2352             /* Check the point is the same */
2353             ret = !bignum_cmp(lhs->x, rhs->x);
2354             if (ret) {
2355                 ret = !bignum_cmp(lhs->y, rhs->y);
2356                 if (ret) {
2357                     ret = 1;
2358                 }
2359             }
2360             ec_point_free(lhs);
2361             ec_point_free(rhs);
2362         }
2363     } else {
2364         Bignum r, s;
2365         unsigned char digest[512 / 8];
2366         void *hashctx;
2367
2368         r = getmp(&p, &slen);
2369         if (!r) return 0;
2370         s = getmp(&p, &slen);
2371         if (!s) {
2372             freebn(r);
2373             return 0;
2374         }
2375
2376         digestLen = extra->hash->hlen;
2377         assert(digestLen <= sizeof(digest));
2378         hashctx = extra->hash->init();
2379         extra->hash->bytes(hashctx, data, datalen);
2380         extra->hash->final(hashctx, digest);
2381
2382         /* Verify the signature */
2383         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2384
2385         freebn(r);
2386         freebn(s);
2387     }
2388
2389     return ret;
2390 }
2391
2392 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2393                                  int *siglen)
2394 {
2395     struct ec_key *ec = (struct ec_key *) key;
2396     const struct ecsign_extra *extra =
2397         (const struct ecsign_extra *)ec->signalg->extra;
2398     unsigned char digest[512 / 8];
2399     int digestLen;
2400     Bignum r = NULL, s = NULL;
2401     unsigned char *buf, *p;
2402     int rlen, slen, namelen;
2403     int i;
2404
2405     if (!ec->privateKey || !ec->publicKey.curve) {
2406         return NULL;
2407     }
2408
2409     if (ec->publicKey.curve->type == EC_EDWARDS) {
2410         struct ec_point *rp;
2411         int pointlen = ec->publicKey.curve->fieldBits / 8;
2412
2413         /* hash = H(sk) (where hash creates 2 * fieldBits)
2414          * b = fieldBits
2415          * a = 2^(b-2) + SUM(2^i * h_i) for i = 2 -> b-2
2416          * r = H(h[b/8:b/4] + m)
2417          * R = rB
2418          * S = (r + H(encodepoint(R) + encodepoint(pk) + m) * a) % l */
2419         {
2420             unsigned char hash[512/8];
2421             unsigned char b;
2422             Bignum a;
2423             SHA512_State hs;
2424             SHA512_Init(&hs);
2425
2426             for (i = 0; i < pointlen; ++i) {
2427                 unsigned char b = (unsigned char)bignum_byte(ec->privateKey, i);
2428                 SHA512_Bytes(&hs, &b, 1);
2429             }
2430
2431             SHA512_Final(&hs, hash);
2432
2433             /* The second part is simply turning the hash into a
2434              * Bignum, however the 2^(b-2) bit *must* be set, and the
2435              * bottom 3 bits *must* not be */
2436             hash[0] &= 0xf8; /* Unset bottom 3 bits (if set) */
2437             hash[31] &= 0x7f; /* Unset above (b-2) */
2438             hash[31] |= 0x40; /* Set 2^(b-2) */
2439             /* Chop off the top part and convert to int */
2440             a = bignum_from_bytes_le(hash, 32);
2441
2442             SHA512_Init(&hs);
2443             SHA512_Bytes(&hs,
2444                          hash+(ec->publicKey.curve->fieldBits / 8),
2445                          (ec->publicKey.curve->fieldBits / 4)
2446                          - (ec->publicKey.curve->fieldBits / 8));
2447             SHA512_Bytes(&hs, data, datalen);
2448             SHA512_Final(&hs, hash);
2449
2450             r = bignum_from_bytes_le(hash, 512/8);
2451             rp = ecp_mul(&ec->publicKey.curve->e.B, r);
2452             if (!rp) {
2453                 freebn(r);
2454                 freebn(a);
2455                 return NULL;
2456             }
2457
2458             /* Now calculate s */
2459             SHA512_Init(&hs);
2460             /* Encode the point R */
2461             for (i = 0; i < pointlen - 1; ++i) {
2462                 b = bignum_byte(rp->y, i);
2463                 SHA512_Bytes(&hs, &b, 1);
2464             }
2465             /* Unset last bit of y and set first bit of x in its place */
2466             b = bignum_byte(rp->y, i) & 0x7f;
2467             b |= bignum_bit(rp->x, 0) << 7;
2468             SHA512_Bytes(&hs, &b, 1);
2469
2470             /* Encode the point pk */
2471             for (i = 0; i < pointlen - 1; ++i) {
2472                 b = bignum_byte(ec->publicKey.y, i);
2473                 SHA512_Bytes(&hs, &b, 1);
2474             }
2475             /* Unset last bit of y and set first bit of x in its place */
2476             b = bignum_byte(ec->publicKey.y, i) & 0x7f;
2477             b |= bignum_bit(ec->publicKey.x, 0) << 7;
2478             SHA512_Bytes(&hs, &b, 1);
2479
2480             /* Add the message */
2481             SHA512_Bytes(&hs, data, datalen);
2482             SHA512_Final(&hs, hash);
2483
2484             {
2485                 Bignum tmp, tmp2;
2486
2487                 tmp = bignum_from_bytes_le(hash, 512/8);
2488                 tmp2 = modmul(tmp, a, ec->publicKey.curve->e.l);
2489                 freebn(a);
2490                 freebn(tmp);
2491                 tmp = bigadd(r, tmp2);
2492                 freebn(r);
2493                 freebn(tmp2);
2494                 s = bigmod(tmp, ec->publicKey.curve->e.l);
2495                 freebn(tmp);
2496             }
2497         }
2498
2499         /* Format the output */
2500         namelen = strlen(ec->signalg->name);
2501         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
2502         buf = snewn(*siglen, unsigned char);
2503         p = buf;
2504         PUT_32BIT(p, namelen);
2505         p += 4;
2506         memcpy(p, ec->signalg->name, namelen);
2507         p += namelen;
2508         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
2509         p += 4;
2510
2511         /* Encode the point */
2512         pointlen = ec->publicKey.curve->fieldBits / 8;
2513         for (i = 0; i < pointlen - 1; ++i) {
2514             *p++ = bignum_byte(rp->y, i);
2515         }
2516         /* Unset last bit of y and set first bit of x in its place */
2517         *p = bignum_byte(rp->y, i) & 0x7f;
2518         *p++ |= bignum_bit(rp->x, 0) << 7;
2519         ec_point_free(rp);
2520
2521         /* Encode the int */
2522         for (i = 0; i < pointlen; ++i) {
2523             *p++ = bignum_byte(s, i);
2524         }
2525         freebn(s);
2526     } else {
2527         void *hashctx;
2528
2529         digestLen = extra->hash->hlen;
2530         assert(digestLen <= sizeof(digest));
2531         hashctx = extra->hash->init();
2532         extra->hash->bytes(hashctx, data, datalen);
2533         extra->hash->final(hashctx, digest);
2534
2535         /* Do the signature */
2536         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2537         if (!r || !s) {
2538             if (r) freebn(r);
2539             if (s) freebn(s);
2540             return NULL;
2541         }
2542
2543         rlen = (bignum_bitcount(r) + 8) / 8;
2544         slen = (bignum_bitcount(s) + 8) / 8;
2545
2546         namelen = strlen(ec->signalg->name);
2547
2548         /* Format the output */
2549         *siglen = 8+namelen+rlen+slen+8;
2550         buf = snewn(*siglen, unsigned char);
2551         p = buf;
2552         PUT_32BIT(p, namelen);
2553         p += 4;
2554         memcpy(p, ec->signalg->name, namelen);
2555         p += namelen;
2556         PUT_32BIT(p, rlen + slen + 8);
2557         p += 4;
2558         PUT_32BIT(p, rlen);
2559         p += 4;
2560         for (i = rlen; i--;)
2561             *p++ = bignum_byte(r, i);
2562         PUT_32BIT(p, slen);
2563         p += 4;
2564         for (i = slen; i--;)
2565             *p++ = bignum_byte(s, i);
2566
2567         freebn(r);
2568         freebn(s);
2569     }
2570
2571     return buf;
2572 }
2573
2574 const struct ecsign_extra sign_extra_ed25519 = {
2575     ec_ed25519, NULL,
2576     NULL, 0,
2577 };
2578 const struct ssh_signkey ssh_ecdsa_ed25519 = {
2579     ecdsa_newkey,
2580     ecdsa_freekey,
2581     ecdsa_fmtkey,
2582     ecdsa_public_blob,
2583     ecdsa_private_blob,
2584     ecdsa_createkey,
2585     ed25519_openssh_createkey,
2586     ed25519_openssh_fmtkey,
2587     2 /* point, private exponent */,
2588     ecdsa_pubkey_bits,
2589     ecdsa_verifysig,
2590     ecdsa_sign,
2591     "ssh-ed25519",
2592     "ssh-ed25519",
2593     &sign_extra_ed25519,
2594 };
2595
2596 /* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
2597 static const unsigned char nistp256_oid[] = {
2598     0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
2599 };
2600 const struct ecsign_extra sign_extra_nistp256 = {
2601     ec_p256, &ssh_sha256,
2602     nistp256_oid, lenof(nistp256_oid),
2603 };
2604 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2605     ecdsa_newkey,
2606     ecdsa_freekey,
2607     ecdsa_fmtkey,
2608     ecdsa_public_blob,
2609     ecdsa_private_blob,
2610     ecdsa_createkey,
2611     ecdsa_openssh_createkey,
2612     ecdsa_openssh_fmtkey,
2613     3 /* curve name, point, private exponent */,
2614     ecdsa_pubkey_bits,
2615     ecdsa_verifysig,
2616     ecdsa_sign,
2617     "ecdsa-sha2-nistp256",
2618     "ecdsa-sha2-nistp256",
2619     &sign_extra_nistp256,
2620 };
2621
2622 /* OID: 1.3.132.0.34 (secp384r1) */
2623 static const unsigned char nistp384_oid[] = {
2624     0x2b, 0x81, 0x04, 0x00, 0x22
2625 };
2626 const struct ecsign_extra sign_extra_nistp384 = {
2627     ec_p384, &ssh_sha384,
2628     nistp384_oid, lenof(nistp384_oid),
2629 };
2630 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2631     ecdsa_newkey,
2632     ecdsa_freekey,
2633     ecdsa_fmtkey,
2634     ecdsa_public_blob,
2635     ecdsa_private_blob,
2636     ecdsa_createkey,
2637     ecdsa_openssh_createkey,
2638     ecdsa_openssh_fmtkey,
2639     3 /* curve name, point, private exponent */,
2640     ecdsa_pubkey_bits,
2641     ecdsa_verifysig,
2642     ecdsa_sign,
2643     "ecdsa-sha2-nistp384",
2644     "ecdsa-sha2-nistp384",
2645     &sign_extra_nistp384,
2646 };
2647
2648 /* OID: 1.3.132.0.35 (secp521r1) */
2649 static const unsigned char nistp521_oid[] = {
2650     0x2b, 0x81, 0x04, 0x00, 0x23
2651 };
2652 const struct ecsign_extra sign_extra_nistp521 = {
2653     ec_p521, &ssh_sha512,
2654     nistp521_oid, lenof(nistp521_oid),
2655 };
2656 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2657     ecdsa_newkey,
2658     ecdsa_freekey,
2659     ecdsa_fmtkey,
2660     ecdsa_public_blob,
2661     ecdsa_private_blob,
2662     ecdsa_createkey,
2663     ecdsa_openssh_createkey,
2664     ecdsa_openssh_fmtkey,
2665     3 /* curve name, point, private exponent */,
2666     ecdsa_pubkey_bits,
2667     ecdsa_verifysig,
2668     ecdsa_sign,
2669     "ecdsa-sha2-nistp521",
2670     "ecdsa-sha2-nistp521",
2671     &sign_extra_nistp521,
2672 };
2673
2674 /* ----------------------------------------------------------------------
2675  * Exposed ECDH interface
2676  */
2677
2678 struct eckex_extra {
2679     struct ec_curve *(*curve)(void);
2680 };
2681
2682 static Bignum ecdh_calculate(const Bignum private,
2683                              const struct ec_point *public)
2684 {
2685     struct ec_point *p;
2686     Bignum ret;
2687     p = ecp_mul(public, private);
2688     if (!p) return NULL;
2689     ret = p->x;
2690     p->x = NULL;
2691
2692     if (p->curve->type == EC_MONTGOMERY) {
2693         /*
2694          * Endianness-swap. The Curve25519 algorithm definition
2695          * assumes you were doing your computation in arrays of 32
2696          * little-endian bytes, and now specifies that you take your
2697          * final one of those and convert it into a bignum in
2698          * _network_ byte order, i.e. big-endian.
2699          *
2700          * In particular, the spec says, you convert the _whole_ 32
2701          * bytes into a bignum. That is, on the rare occasions that
2702          * p->x has come out with the most significant 8 bits zero, we
2703          * have to imagine that being represented by a 32-byte string
2704          * with the last byte being zero, so that has to be converted
2705          * into an SSH-2 bignum with the _low_ byte zero, i.e. a
2706          * multiple of 256.
2707          */
2708         int i;
2709         int bytes = (p->curve->fieldBits+7) / 8;
2710         unsigned char *byteorder = snewn(bytes, unsigned char);
2711         for (i = 0; i < bytes; ++i) {
2712             byteorder[i] = bignum_byte(ret, i);
2713         }
2714         freebn(ret);
2715         ret = bignum_from_bytes(byteorder, bytes);
2716         smemclr(byteorder, bytes);
2717         sfree(byteorder);
2718     }
2719
2720     ec_point_free(p);
2721     return ret;
2722 }
2723
2724 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex)
2725 {
2726     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2727     struct ec_curve *curve = extra->curve();
2728     return curve->textname;
2729 }
2730
2731 void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
2732 {
2733     const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
2734     struct ec_curve *curve;
2735     struct ec_key *key;
2736     struct ec_point *publicKey;
2737
2738     curve = extra->curve();
2739
2740     key = snew(struct ec_key);
2741
2742     key->signalg = NULL;
2743     key->publicKey.curve = curve;
2744
2745     if (curve->type == EC_MONTGOMERY) {
2746         unsigned char bytes[32] = {0};
2747         int i;
2748
2749         for (i = 0; i < sizeof(bytes); ++i)
2750         {
2751             bytes[i] = (unsigned char)random_byte();
2752         }
2753         bytes[0] &= 248;
2754         bytes[31] &= 127;
2755         bytes[31] |= 64;
2756         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2757         for (i = 0; i < sizeof(bytes); ++i)
2758         {
2759             ((volatile char*)bytes)[i] = 0;
2760         }
2761         if (!key->privateKey) {
2762             sfree(key);
2763             return NULL;
2764         }
2765         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2766         if (!publicKey) {
2767             freebn(key->privateKey);
2768             sfree(key);
2769             return NULL;
2770         }
2771         key->publicKey.x = publicKey->x;
2772         key->publicKey.y = publicKey->y;
2773         key->publicKey.z = NULL;
2774         sfree(publicKey);
2775     } else {
2776         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2777         if (!key->privateKey) {
2778             sfree(key);
2779             return NULL;
2780         }
2781         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2782         if (!publicKey) {
2783             freebn(key->privateKey);
2784             sfree(key);
2785             return NULL;
2786         }
2787         key->publicKey.x = publicKey->x;
2788         key->publicKey.y = publicKey->y;
2789         key->publicKey.z = NULL;
2790         sfree(publicKey);
2791     }
2792     return key;
2793 }
2794
2795 char *ssh_ecdhkex_getpublic(void *key, int *len)
2796 {
2797     struct ec_key *ec = (struct ec_key*)key;
2798     char *point, *p;
2799     int i;
2800     int pointlen;
2801
2802     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2803
2804     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2805         *len = 1 + pointlen * 2;
2806     } else {
2807         *len = pointlen;
2808     }
2809     point = (char*)snewn(*len, char);
2810
2811     p = point;
2812     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2813         *p++ = 0x04;
2814         for (i = pointlen; i--;) {
2815             *p++ = bignum_byte(ec->publicKey.x, i);
2816         }
2817         for (i = pointlen; i--;) {
2818             *p++ = bignum_byte(ec->publicKey.y, i);
2819         }
2820     } else {
2821         for (i = 0; i < pointlen; ++i) {
2822             *p++ = bignum_byte(ec->publicKey.x, i);
2823         }
2824     }
2825
2826     return point;
2827 }
2828
2829 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2830 {
2831     struct ec_key *ec = (struct ec_key*) key;
2832     struct ec_point remote;
2833     Bignum ret;
2834
2835     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2836         remote.curve = ec->publicKey.curve;
2837         remote.infinity = 0;
2838         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2839             return NULL;
2840         }
2841     } else {
2842         /* Point length has to be the same length */
2843         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2844             return NULL;
2845         }
2846
2847         remote.curve = ec->publicKey.curve;
2848         remote.infinity = 0;
2849         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2850         remote.y = NULL;
2851         remote.z = NULL;
2852     }
2853
2854     ret = ecdh_calculate(ec->privateKey, &remote);
2855     if (remote.x) freebn(remote.x);
2856     if (remote.y) freebn(remote.y);
2857     return ret;
2858 }
2859
2860 void ssh_ecdhkex_freekey(void *key)
2861 {
2862     ecdsa_freekey(key);
2863 }
2864
2865 static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
2866 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2867     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
2868     &ssh_sha256, &kex_extra_curve25519,
2869 };
2870
2871 const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
2872 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2873     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
2874     &ssh_sha256, &kex_extra_nistp256,
2875 };
2876
2877 const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
2878 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2879     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
2880     &ssh_sha384, &kex_extra_nistp384,
2881 };
2882
2883 const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
2884 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2885     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
2886     &ssh_sha512, &kex_extra_nistp521,
2887 };
2888
2889 static const struct ssh_kex *const ec_kex_list[] = {
2890     &ssh_ec_kex_curve25519,
2891     &ssh_ec_kex_nistp256,
2892     &ssh_ec_kex_nistp384,
2893     &ssh_ec_kex_nistp521,
2894 };
2895
2896 const struct ssh_kexes ssh_ecdh_kex = {
2897     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2898     ec_kex_list
2899 };
2900
2901 /* ----------------------------------------------------------------------
2902  * Helper functions for finding key algorithms and returning auxiliary
2903  * data.
2904  */
2905
2906 const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
2907                                         const struct ec_curve **curve)
2908 {
2909     static const struct ssh_signkey *algs_with_oid[] = {
2910         &ssh_ecdsa_nistp256,
2911         &ssh_ecdsa_nistp384,
2912         &ssh_ecdsa_nistp521,
2913     };
2914     int i;
2915
2916     for (i = 0; i < lenof(algs_with_oid); i++) {
2917         const struct ssh_signkey *alg = algs_with_oid[i];
2918         const struct ecsign_extra *extra =
2919             (const struct ecsign_extra *)alg->extra;
2920         if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
2921             *curve = extra->curve();
2922             return alg;
2923         }
2924     }
2925     return NULL;
2926 }
2927
2928 const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
2929                                 int *oidlen)
2930 {
2931     const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
2932     *oidlen = extra->oidlen;
2933     return extra->oid;
2934 }
2935
2936 const int ec_nist_alg_and_curve_by_bits(int bits,
2937                                         const struct ec_curve **curve,
2938                                         const struct ssh_signkey **alg)
2939 {
2940     switch (bits) {
2941       case 256: *alg = &ssh_ecdsa_nistp256; break;
2942       case 384: *alg = &ssh_ecdsa_nistp384; break;
2943       case 521: *alg = &ssh_ecdsa_nistp521; break;
2944       default: return FALSE;
2945     }
2946     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2947     return TRUE;
2948 }
2949
2950 const int ec_ed_alg_and_curve_by_bits(int bits,
2951                                       const struct ec_curve **curve,
2952                                       const struct ssh_signkey **alg)
2953 {
2954     switch (bits) {
2955       case 256: *alg = &ssh_ecdsa_ed25519; break;
2956       default: return FALSE;
2957     }
2958     *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
2959     return TRUE;
2960 }