]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshecc.c
Support ECDH key exchange using the 'curve25519' curve.
[PuTTY.git] / sshecc.c
1 /*
2  * Elliptic-curve crypto module for PuTTY
3  * Implements the three required curves, no optional curves
4  *
5  * NOTE: Only curves on prime field are handled by the maths functions
6  *       in Weierstrass form using Jacobian co-ordinates.
7  *
8  *       Montgomery form curves are supported for DH. (Curve25519)
9  */
10
11 /*
12  * References:
13  *
14  * Elliptic curves in SSH are specified in RFC 5656:
15  *   http://tools.ietf.org/html/rfc5656
16  *
17  * That specification delegates details of public key formatting and a
18  * lot of underlying mechanism to SEC 1:
19  *   http://www.secg.org/sec1-v2.pdf
20  *
21  * Montgomery maths from:
22  * Handbook of elliptic and hyperelliptic curve cryptography, Chapter 13
23  *   http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf
24  */
25
26 #include <stdlib.h>
27 #include <assert.h>
28
29 #include "ssh.h"
30
31 /* ----------------------------------------------------------------------
32  * Elliptic curve definitions
33  */
34
35 static int initialise_wcurve(struct ec_curve *curve, int bits, unsigned char *p,
36                              unsigned char *a, unsigned char *b,
37                              unsigned char *n, unsigned char *Gx,
38                              unsigned char *Gy)
39 {
40     int length = bits / 8;
41     if (bits % 8) ++length;
42
43     curve->type = EC_WEIERSTRASS;
44
45     curve->fieldBits = bits;
46     curve->p = bignum_from_bytes(p, length);
47     if (!curve->p) goto error;
48
49     /* Curve co-efficients */
50     curve->w.a = bignum_from_bytes(a, length);
51     if (!curve->w.a) goto error;
52     curve->w.b = bignum_from_bytes(b, length);
53     if (!curve->w.b) goto error;
54
55     /* Group order and generator */
56     curve->w.n = bignum_from_bytes(n, length);
57     if (!curve->w.n) goto error;
58     curve->w.G.x = bignum_from_bytes(Gx, length);
59     if (!curve->w.G.x) goto error;
60     curve->w.G.y = bignum_from_bytes(Gy, length);
61     if (!curve->w.G.y) goto error;
62     curve->w.G.curve = curve;
63     curve->w.G.infinity = 0;
64
65     return 1;
66   error:
67     if (curve->p) freebn(curve->p);
68     if (curve->w.a) freebn(curve->w.a);
69     if (curve->w.b) freebn(curve->w.b);
70     if (curve->w.n) freebn(curve->w.n);
71     if (curve->w.G.x) freebn(curve->w.G.x);
72     return 0;
73 }
74
75 static int initialise_mcurve(struct ec_curve *curve, int bits, unsigned char *p,
76                              unsigned char *a, unsigned char *b,
77                              unsigned char *Gx)
78 {
79     int length = bits / 8;
80     if (bits % 8) ++length;
81
82     curve->type = EC_MONTGOMERY;
83
84     curve->fieldBits = bits;
85     curve->p = bignum_from_bytes(p, length);
86     if (!curve->p) goto error;
87
88     /* Curve co-efficients */
89     curve->m.a = bignum_from_bytes(a, length);
90     if (!curve->m.a) goto error;
91     curve->m.b = bignum_from_bytes(b, length);
92     if (!curve->m.b) goto error;
93
94     /* Generator */
95     curve->m.G.x = bignum_from_bytes(Gx, length);
96     if (!curve->m.G.x) goto error;
97     curve->m.G.y = NULL;
98     curve->m.G.z = NULL;
99     curve->m.G.curve = curve;
100     curve->m.G.infinity = 0;
101
102     return 1;
103   error:
104     if (curve->p) freebn(curve->p);
105     if (curve->m.a) freebn(curve->m.a);
106     if (curve->m.b) freebn(curve->m.b);
107     return 0;
108 }
109
110 unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
111 int nistp256_oid_len = 8;
112 unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
113 int nistp384_oid_len = 5;
114 unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
115 int nistp521_oid_len = 5;
116 unsigned char curve25519_oid[] = {0x06, 0x0A, 0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01};
117 int curve25519_oid_len = 12;
118
119 struct ec_curve *ec_p256(void)
120 {
121     static struct ec_curve curve = { 0 };
122     static unsigned char initialised = 0;
123
124     if (!initialised)
125     {
126         unsigned char p[] = {
127             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
128             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
129             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
130             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
131         };
132         unsigned char a[] = {
133             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01,
134             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
135             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff,
136             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc
137         };
138         unsigned char b[] = {
139             0x5a, 0xc6, 0x35, 0xd8, 0xaa, 0x3a, 0x93, 0xe7,
140             0xb3, 0xeb, 0xbd, 0x55, 0x76, 0x98, 0x86, 0xbc,
141             0x65, 0x1d, 0x06, 0xb0, 0xcc, 0x53, 0xb0, 0xf6,
142             0x3b, 0xce, 0x3c, 0x3e, 0x27, 0xd2, 0x60, 0x4b
143         };
144         unsigned char n[] = {
145             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
146             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
147             0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
148             0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51
149         };
150         unsigned char Gx[] = {
151             0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47,
152             0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2,
153             0x77, 0x03, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0,
154             0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96
155         };
156         unsigned char Gy[] = {
157             0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b,
158             0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0x0f, 0x9e, 0x16,
159             0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce,
160             0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5
161         };
162
163         if (!initialise_wcurve(&curve, 256, p, a, b, n, Gx, Gy)) {
164             return NULL;
165         }
166
167         /* Now initialised, no need to do it again */
168         initialised = 1;
169     }
170
171     return &curve;
172 }
173
174 struct ec_curve *ec_p384(void)
175 {
176     static struct ec_curve curve = { 0 };
177     static unsigned char initialised = 0;
178
179     if (!initialised)
180     {
181         unsigned char p[] = {
182             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
183             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
184             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
185             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
186             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
187             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff
188         };
189         unsigned char a[] = {
190             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
191             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
192             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
193             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
194             0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
195             0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xfc
196         };
197         unsigned char b[] = {
198             0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4,
199             0x98, 0x8e, 0x05, 0x6b, 0xe3, 0xf8, 0x2d, 0x19,
200             0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12,
201             0x03, 0x14, 0x08, 0x8f, 0x50, 0x13, 0x87, 0x5a,
202             0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d,
203             0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef
204         };
205         unsigned char n[] = {
206             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
207             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
208             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
209             0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
210             0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
211             0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73
212         };
213         unsigned char Gx[] = {
214             0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x05, 0x37,
215             0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74,
216             0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98,
217             0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38,
218             0x55, 0x02, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c,
219             0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0x0a, 0xb7
220         };
221         unsigned char Gy[] = {
222             0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f,
223             0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29,
224             0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c,
225             0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0,
226             0x0a, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d,
227             0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0x0e, 0x5f
228         };
229
230         if (!initialise_wcurve(&curve, 384, p, a, b, n, Gx, Gy)) {
231             return NULL;
232         }
233
234         /* Now initialised, no need to do it again */
235         initialised = 1;
236     }
237
238     return &curve;
239 }
240
241 struct ec_curve *ec_p521(void)
242 {
243     static struct ec_curve curve = { 0 };
244     static unsigned char initialised = 0;
245
246     if (!initialised)
247     {
248         unsigned char p[] = {
249             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
250             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
251             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
252             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
253             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
254             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
255             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
256             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
257             0xff, 0xff
258         };
259         unsigned char a[] = {
260             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
261             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
262             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
263             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
264             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
265             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
266             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
267             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
268             0xff, 0xfc
269         };
270         unsigned char b[] = {
271             0x00, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c,
272             0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85,
273             0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3,
274             0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1,
275             0x09, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e,
276             0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1,
277             0xbf, 0x07, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c,
278             0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50,
279             0x3f, 0x00
280         };
281         unsigned char n[] = {
282             0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
283             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
285             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
286             0xff, 0xfa, 0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f,
287             0x96, 0x6b, 0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09,
288             0xa5, 0xd0, 0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c,
289             0x47, 0xae, 0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38,
290             0x64, 0x09
291         };
292         unsigned char Gx[] = {
293             0x00, 0xc6, 0x85, 0x8e, 0x06, 0xb7, 0x04, 0x04,
294             0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95,
295             0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x05, 0x3f,
296             0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d,
297             0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7,
298             0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff,
299             0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a,
300             0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5,
301             0xbd, 0x66
302         };
303         unsigned char Gy[] = {
304             0x01, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b,
305             0xc0, 0x04, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d,
306             0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b,
307             0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e,
308             0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4,
309             0x26, 0x40, 0xc5, 0x50, 0xb9, 0x01, 0x3f, 0xad,
310             0x07, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72,
311             0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1,
312             0x66, 0x50
313         };
314
315         if (!initialise_wcurve(&curve, 521, p, a, b, n, Gx, Gy)) {
316             return NULL;
317         }
318
319         /* Now initialised, no need to do it again */
320         initialised = 1;
321     }
322
323     return &curve;
324 }
325
326 struct ec_curve *ec_curve25519(void)
327 {
328     static struct ec_curve curve = { 0 };
329     static unsigned char initialised = 0;
330
331     if (!initialised)
332     {
333         unsigned char p[] = {
334             0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
335             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
336             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
337             0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xed
338         };
339         unsigned char a[] = {
340             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
341             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
342             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
343             0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x6d, 0x06
344         };
345         unsigned char b[] = {
346             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
347             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
348             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
349             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01
350         };
351         unsigned char gx[32] = {
352             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
353             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
354             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
355             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09
356         };
357
358         if (!initialise_mcurve(&curve, 256, p, a, b, gx)) {
359             return NULL;
360         }
361
362         /* Now initialised, no need to do it again */
363         initialised = 1;
364     }
365
366     return &curve;
367 }
368
369 static struct ec_curve *ec_name_to_curve(const char *name, int len) {
370     if (len > 11 && !memcmp(name, "ecdsa-sha2-", 11)) {
371         name += 11;
372         len -= 11;
373     } else if (len > 10 && !memcmp(name, "ecdh-sha2-", 10)) {
374         name += 10;
375         len -= 10;
376     }
377
378     if (len == 8 && !memcmp(name, "nistp", 5)) {
379         name += 5;
380         if (!memcmp(name, "256", 3)) {
381             return ec_p256();
382         } else if (!memcmp(name, "384", 3)) {
383             return ec_p384();
384         } else if (!memcmp(name, "521", 3)) {
385             return ec_p521();
386         }
387     }
388
389     if (len == 28 && !memcmp(name, "curve25519-sha256@libssh.org", 28)) {
390         return ec_curve25519();
391     }
392
393     return NULL;
394 }
395
396 /* Type enumeration for specifying the curve name */
397 enum ec_name_type { EC_TYPE_DSA, EC_TYPE_DH, EC_TYPE_CURVE };
398
399 static int ec_curve_to_name(enum ec_name_type type, const struct ec_curve *curve,
400                             unsigned char *name, int len) {
401     if (curve->type == EC_WEIERSTRASS) {
402         int length, loc;
403         if (type == EC_TYPE_DSA) {
404             length = 19;
405             loc = 16;
406         } else if (type == EC_TYPE_DH) {
407             length = 18;
408             loc = 15;
409         } else {
410             length = 8;
411             loc = 5;
412         }
413
414         /* Return length of string */
415         if (name == NULL) return length;
416
417         /* Not enough space for the name */
418         if (len < length) return 0;
419
420         /* Put the name in the buffer */
421         switch (curve->fieldBits) {
422           case 256:
423             memcpy(name+loc, "256", 3);
424             break;
425           case 384:
426             memcpy(name+loc, "384", 3);
427             break;
428           case 521:
429             memcpy(name+loc, "521", 3);
430             break;
431           default:
432             return 0;
433         }
434
435         if (type == EC_TYPE_DSA) {
436             memcpy(name, "ecdsa-sha2-nistp", 16);
437         } else if (type == EC_TYPE_DH) {
438             memcpy(name, "ecdh-sha2-nistp", 15);
439         } else {
440             memcpy(name, "nistp", 5);
441         }
442
443         return length;
444     } else {
445         /* No DSA for curve25519 */
446         if (type == EC_TYPE_DSA || type == EC_TYPE_CURVE) return 0;
447
448         /* Return length of string */
449         if (name == NULL) return 28;
450
451         /* Not enough space for the name */
452         if (len < 28) return 0;
453
454         /* Unknown curve field */
455         if (curve->fieldBits != 256) return 0;
456
457         memcpy(name, "curve25519-sha256@libssh.org", 28);
458         return 28;
459     }
460 }
461
462 /* Return 1 if a is -3 % p, otherwise return 0
463  * This is used because there are some maths optimisations */
464 static int ec_aminus3(const struct ec_curve *curve)
465 {
466     int ret;
467     Bignum _p;
468
469     if (curve->type != EC_WEIERSTRASS) {
470         return 0;
471     }
472
473     _p = bignum_add_long(curve->w.a, 3);
474     if (!_p) return 0;
475
476     ret = !bignum_cmp(curve->p, _p);
477     freebn(_p);
478     return ret;
479 }
480
481 /* ----------------------------------------------------------------------
482  * Elliptic curve field maths
483  */
484
485 static Bignum ecf_add(const Bignum a, const Bignum b,
486                       const struct ec_curve *curve)
487 {
488     Bignum a1, b1, ab, ret;
489
490     a1 = bigmod(a, curve->p);
491     if (!a1) return NULL;
492     b1 = bigmod(b, curve->p);
493     if (!b1)
494     {
495         freebn(a1);
496         return NULL;
497     }
498
499     ab = bigadd(a1, b1);
500     freebn(a1);
501     freebn(b1);
502     if (!ab) return NULL;
503
504     ret = bigmod(ab, curve->p);
505     freebn(ab);
506
507     return ret;
508 }
509
510 static Bignum ecf_square(const Bignum a, const struct ec_curve *curve)
511 {
512     return modmul(a, a, curve->p);
513 }
514
515 static Bignum ecf_treble(const Bignum a, const struct ec_curve *curve)
516 {
517     Bignum ret, tmp;
518
519     /* Double */
520     tmp = bignum_lshift(a, 1);
521     if (!tmp) return NULL;
522
523     /* Add itself (i.e. treble) */
524     ret = bigadd(tmp, a);
525     freebn(tmp);
526
527     /* Normalise */
528     while (ret != NULL && bignum_cmp(ret, curve->p) >= 0)
529     {
530         tmp = bigsub(ret, curve->p);
531         freebn(ret);
532         ret = tmp;
533     }
534
535     return ret;
536 }
537
538 static Bignum ecf_double(const Bignum a, const struct ec_curve *curve)
539 {
540     Bignum ret = bignum_lshift(a, 1);
541     if (!ret) return NULL;
542     if (bignum_cmp(ret, curve->p) >= 0)
543     {
544         Bignum tmp = bigsub(ret, curve->p);
545         freebn(ret);
546         return tmp;
547     }
548     else
549     {
550         return ret;
551     }
552 }
553
554 /* ----------------------------------------------------------------------
555  * Memory functions
556  */
557
558 void ec_point_free(struct ec_point *point)
559 {
560     if (point == NULL) return;
561     point->curve = 0;
562     if (point->x) freebn(point->x);
563     if (point->y) freebn(point->y);
564     if (point->z) freebn(point->z);
565     point->infinity = 0;
566     sfree(point);
567 }
568
569 static struct ec_point *ec_point_new(const struct ec_curve *curve,
570                                      const Bignum x, const Bignum y, const Bignum z,
571                                      unsigned char infinity)
572 {
573     struct ec_point *point = snewn(1, struct ec_point);
574     point->curve = curve;
575     point->x = x;
576     point->y = y;
577     point->z = z;
578     point->infinity = infinity ? 1 : 0;
579     return point;
580 }
581
582 static struct ec_point *ec_point_copy(const struct ec_point *a)
583 {
584     if (a == NULL) return NULL;
585     return ec_point_new(a->curve,
586                         a->x ? copybn(a->x) : NULL,
587                         a->y ? copybn(a->y) : NULL,
588                         a->z ? copybn(a->z) : NULL,
589                         a->infinity);
590 }
591
592 static int ec_point_verify(const struct ec_point *a)
593 {
594     if (a->infinity) {
595         return 1;
596     } else if (a->curve->type == EC_WEIERSTRASS) {
597         /* Verify y^2 = x^3 + ax + b */
598         int ret = 0;
599
600         Bignum lhs = NULL, x3 = NULL, ax = NULL, x3ax = NULL, x3axm = NULL, x3axb = NULL, rhs = NULL;
601
602         Bignum Three = bignum_from_long(3);
603         if (!Three) return 0;
604
605         lhs = modmul(a->y, a->y, a->curve->p);
606         if (!lhs) goto error;
607
608         /* This uses montgomery multiplication to optimise */
609         x3 = modpow(a->x, Three, a->curve->p);
610         freebn(Three);
611         if (!x3) goto error;
612         ax = modmul(a->curve->w.a, a->x, a->curve->p);
613         if (!ax) goto error;
614         x3ax = bigadd(x3, ax);
615         if (!x3ax) goto error;
616         freebn(x3); x3 = NULL;
617         freebn(ax); ax = NULL;
618         x3axm = bigmod(x3ax, a->curve->p);
619         if (!x3axm) goto error;
620         freebn(x3ax); x3ax = NULL;
621         x3axb = bigadd(x3axm, a->curve->w.b);
622         if (!x3axb) goto error;
623         freebn(x3axm); x3axm = NULL;
624         rhs = bigmod(x3axb, a->curve->p);
625         if (!rhs) goto error;
626         freebn(x3axb);
627
628         ret = bignum_cmp(lhs, rhs) ? 0 : 1;
629         freebn(lhs);
630         freebn(rhs);
631
632         return ret;
633
634       error:
635         if (x3) freebn(x3);
636         if (ax) freebn(ax);
637         if (x3ax) freebn(x3ax);
638         if (x3axm) freebn(x3axm);
639         if (x3axb) freebn(x3axb);
640         if (lhs) freebn(lhs);
641         return 0;
642     } else {
643         return 0;
644     }
645 }
646
647 /* ----------------------------------------------------------------------
648  * Elliptic curve point maths
649  */
650
651 /* Returns 1 on success and 0 on memory error */
652 static int ecp_normalise(struct ec_point *a)
653 {
654     if (!a) {
655         /* No point */
656         return 0;
657     }
658
659     if (a->infinity) {
660         /* Point is at infinity - i.e. normalised */
661         return 1;
662     }
663
664     if (a->curve->type == EC_WEIERSTRASS) {
665         /* In Jacobian Coordinates the triple (X, Y, Z) represents
666            the affine point (X / Z^2, Y / Z^3) */
667
668         Bignum Z2, Z2inv, Z3, Z3inv, tx, ty;
669
670         if (!a->x || !a->y) {
671             /* No point defined */
672             return 0;
673         } else if (!a->z) {
674             /* Already normalised */
675             return 1;
676         }
677
678         Z2 = ecf_square(a->z, a->curve);
679         if (!Z2) {
680             return 0;
681         }
682         Z2inv = modinv(Z2, a->curve->p);
683         if (!Z2inv) {
684             freebn(Z2);
685             return 0;
686         }
687         tx = modmul(a->x, Z2inv, a->curve->p);
688         freebn(Z2inv);
689         if (!tx) {
690             freebn(Z2);
691             return 0;
692         }
693
694         Z3 = modmul(Z2, a->z, a->curve->p);
695         freebn(Z2);
696         if (!Z3) {
697             freebn(tx);
698             return 0;
699         }
700         Z3inv = modinv(Z3, a->curve->p);
701         freebn(Z3);
702         if (!Z3inv) {
703             freebn(tx);
704             return 0;
705         }
706         ty = modmul(a->y, Z3inv, a->curve->p);
707         freebn(Z3inv);
708         if (!ty) {
709             freebn(tx);
710             return 0;
711         }
712
713         freebn(a->x);
714         a->x = tx;
715         freebn(a->y);
716         a->y = ty;
717         freebn(a->z);
718         a->z = NULL;
719         return 1;
720     } else if (a->curve->type == EC_MONTGOMERY) {
721         /* In Montgomery (X : Z) represents the x co-ord (X / Z, ?) */
722
723         Bignum tmp, tmp2;
724
725         if (!a->x) {
726             /* No point defined */
727             return 0;
728         } else if (!a->z) {
729             /* Already normalised */
730             return 1;
731         }
732
733         tmp = modinv(a->z, a->curve->p);
734         if (!tmp) {
735             return 0;
736         }
737         tmp2 = modmul(a->x, tmp, a->curve->p);
738         freebn(tmp);
739         if (!tmp2) {
740             return 0;
741         }
742
743         freebn(a->z);
744         a->z = NULL;
745         freebn(a->x);
746         a->x = tmp2;
747         return 1;
748     } else {
749         return 0;
750     }
751 }
752
753 static struct ec_point *ecp_doublew(const struct ec_point *a, const int aminus3)
754 {
755     Bignum S, M, outx, outy, outz;
756
757     if (bignum_cmp(a->y, Zero) == 0)
758     {
759         /* Identity */
760         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
761     }
762
763     /* S = 4*X*Y^2 */
764     {
765         Bignum Y2, XY2, _2XY2;
766
767         Y2 = ecf_square(a->y, a->curve);
768         if (!Y2) {
769             return NULL;
770         }
771         XY2 = modmul(a->x, Y2, a->curve->p);
772         freebn(Y2);
773         if (!XY2) {
774             return NULL;
775         }
776
777         _2XY2 = ecf_double(XY2, a->curve);
778         freebn(XY2);
779         if (!_2XY2) {
780             return NULL;
781         }
782         S = ecf_double(_2XY2, a->curve);
783         freebn(_2XY2);
784         if (!S) {
785             return NULL;
786         }
787     }
788
789     /* Faster calculation if a = -3 */
790     if (aminus3) {
791         /* if a = -3, then M can also be calculated as M = 3*(X + Z^2)*(X - Z^2) */
792         Bignum Z2, XpZ2, XmZ2, second;
793
794         if (a->z == NULL) {
795             Z2 = copybn(One);
796         } else {
797             Z2 = ecf_square(a->z, a->curve);
798         }
799         if (!Z2) {
800             freebn(S);
801             return NULL;
802         }
803
804         XpZ2 = ecf_add(a->x, Z2, a->curve);
805         if (!XpZ2) {
806             freebn(S);
807             freebn(Z2);
808             return NULL;
809         }
810         XmZ2 = modsub(a->x, Z2, a->curve->p);
811         freebn(Z2);
812         if (!XmZ2) {
813             freebn(S);
814             freebn(XpZ2);
815             return NULL;
816         }
817
818         second = modmul(XpZ2, XmZ2, a->curve->p);
819         freebn(XpZ2);
820         freebn(XmZ2);
821         if (!second) {
822             freebn(S);
823             return NULL;
824         }
825
826         M = ecf_treble(second, a->curve);
827         freebn(second);
828         if (!M) {
829             freebn(S);
830             return NULL;
831         }
832     } else {
833         /* M = 3*X^2 + a*Z^4 */
834         Bignum _3X2, X2, aZ4;
835
836         if (a->z == NULL) {
837             aZ4 = copybn(a->curve->w.a);
838         } else {
839             Bignum Z2, Z4;
840
841             Z2 = ecf_square(a->z, a->curve);
842             if (!Z2) {
843                 freebn(S);
844                 return NULL;
845             }
846             Z4 = ecf_square(Z2, a->curve);
847             freebn(Z2);
848             if (!Z4) {
849                 freebn(S);
850                 return NULL;
851             }
852             aZ4 = modmul(a->curve->w.a, Z4, a->curve->p);
853             freebn(Z4);
854         }
855         if (!aZ4) {
856             freebn(S);
857             return NULL;
858         }
859
860         X2 = modmul(a->x, a->x, a->curve->p);
861         if (!X2) {
862             freebn(S);
863             freebn(aZ4);
864             return NULL;
865         }
866         _3X2 = ecf_treble(X2, a->curve);
867         freebn(X2);
868         if (!_3X2) {
869             freebn(S);
870             freebn(aZ4);
871             return NULL;
872         }
873         M = ecf_add(_3X2, aZ4, a->curve);
874         freebn(_3X2);
875         freebn(aZ4);
876         if (!M) {
877             freebn(S);
878             return NULL;
879         }
880     }
881
882     /* X' = M^2 - 2*S */
883     {
884         Bignum M2, _2S;
885
886         M2 = ecf_square(M, a->curve);
887         if (!M2) {
888             freebn(S);
889             freebn(M);
890             return NULL;
891         }
892
893         _2S = ecf_double(S, a->curve);
894         if (!_2S) {
895             freebn(M2);
896             freebn(S);
897             freebn(M);
898             return NULL;
899         }
900
901         outx = modsub(M2, _2S, a->curve->p);
902         freebn(M2);
903         freebn(_2S);
904         if (!outx) {
905             freebn(S);
906             freebn(M);
907             return NULL;
908         }
909     }
910
911     /* Y' = M*(S - X') - 8*Y^4 */
912     {
913         Bignum SX, MSX, Eight, Y2, Y4, _8Y4;
914
915         SX = modsub(S, outx, a->curve->p);
916         freebn(S);
917         if (!SX) {
918             freebn(M);
919             freebn(outx);
920             return NULL;
921         }
922         MSX = modmul(M, SX, a->curve->p);
923         freebn(SX);
924         freebn(M);
925         if (!MSX) {
926             freebn(outx);
927             return NULL;
928         }
929         Y2 = ecf_square(a->y, a->curve);
930         if (!Y2) {
931             freebn(outx);
932             freebn(MSX);
933             return NULL;
934         }
935         Y4 = ecf_square(Y2, a->curve);
936         freebn(Y2);
937         if (!Y4) {
938             freebn(outx);
939             freebn(MSX);
940             return NULL;
941         }
942         Eight = bignum_from_long(8);
943         if (!Eight) {
944             freebn(outx);
945             freebn(MSX);
946             freebn(Y4);
947             return NULL;
948         }
949         _8Y4 = modmul(Eight, Y4, a->curve->p);
950         freebn(Eight);
951         freebn(Y4);
952         if (!_8Y4) {
953             freebn(outx);
954             freebn(MSX);
955             return NULL;
956         }
957         outy = modsub(MSX, _8Y4, a->curve->p);
958         freebn(MSX);
959         freebn(_8Y4);
960         if (!outy) {
961             freebn(outx);
962             return NULL;
963         }
964     }
965
966     /* Z' = 2*Y*Z */
967     {
968         Bignum YZ;
969
970         if (a->z == NULL) {
971             YZ = copybn(a->y);
972         } else {
973             YZ = modmul(a->y, a->z, a->curve->p);
974         }
975         if (!YZ) {
976             freebn(outx);
977             freebn(outy);
978             return NULL;
979         }
980
981         outz = ecf_double(YZ, a->curve);
982         freebn(YZ);
983         if (!outz) {
984             freebn(outx);
985             freebn(outy);
986             return NULL;
987         }
988     }
989
990     return ec_point_new(a->curve, outx, outy, outz, 0);
991 }
992
993 static struct ec_point *ecp_doublem(const struct ec_point *a)
994 {
995     Bignum z, outx, outz, xpz, xmz;
996
997     z = a->z;
998     if (!z) {
999         z = One;
1000     }
1001
1002     /* 4xz = (x + z)^2 - (x - z)^2 */
1003     {
1004         Bignum tmp;
1005
1006         tmp = ecf_add(a->x, z, a->curve);
1007         if (!tmp) {
1008             return NULL;
1009         }
1010         xpz = ecf_square(tmp, a->curve);
1011         freebn(tmp);
1012         if (!xpz) {
1013             return NULL;
1014         }
1015
1016         tmp = modsub(a->x, z, a->curve->p);
1017         if (!tmp) {
1018             freebn(xpz);
1019             return NULL;
1020         }
1021         xmz = ecf_square(tmp, a->curve);
1022         freebn(tmp);
1023         if (!xmz) {
1024             freebn(xpz);
1025             return NULL;
1026         }
1027     }
1028
1029     /* outx = (x + z)^2 * (x - z)^2 */
1030     outx = modmul(xpz, xmz, a->curve->p);
1031     if (!outx) {
1032         freebn(xpz);
1033         freebn(xmz);
1034         return NULL;
1035     }
1036
1037     /* outz = 4xz * ((x - z)^2 + ((A + 2) / 4)*4xz) */
1038     {
1039         Bignum _4xz, tmp, tmp2, tmp3;
1040
1041         tmp = bignum_from_long(2);
1042         if (!tmp) {
1043             freebn(xpz);
1044             freebn(outx);
1045             freebn(xmz);
1046             return NULL;
1047         }
1048         tmp2 = ecf_add(a->curve->m.a, tmp, a->curve);
1049         freebn(tmp);
1050         if (!tmp2) {
1051             freebn(xpz);
1052             freebn(outx);
1053             freebn(xmz);
1054             return NULL;
1055         }
1056
1057         _4xz = modsub(xpz, xmz, a->curve->p);
1058         freebn(xpz);
1059         if (!_4xz) {
1060             freebn(tmp2);
1061             freebn(outx);
1062             freebn(xmz);
1063             return NULL;
1064         }
1065         tmp = modmul(tmp2, _4xz, a->curve->p);
1066         freebn(tmp2);
1067         if (!tmp) {
1068             freebn(_4xz);
1069             freebn(outx);
1070             freebn(xmz);
1071             return NULL;
1072         }
1073
1074         tmp2 = bignum_from_long(4);
1075         if (!tmp2) {
1076             freebn(tmp);
1077             freebn(_4xz);
1078             freebn(outx);
1079             freebn(xmz);
1080             return NULL;
1081         }
1082         tmp3 = modinv(tmp2, a->curve->p);
1083         freebn(tmp2);
1084         if (!tmp3) {
1085             freebn(tmp);
1086             freebn(_4xz);
1087             freebn(outx);
1088             freebn(xmz);
1089             return NULL;
1090         }
1091         tmp2 = modmul(tmp, tmp3, a->curve->p);
1092         freebn(tmp);
1093         freebn(tmp3);
1094         if (!tmp2) {
1095             freebn(_4xz);
1096             freebn(outx);
1097             freebn(xmz);
1098             return NULL;
1099         }
1100
1101         tmp = ecf_add(xmz, tmp2, a->curve);
1102         freebn(xmz);
1103         freebn(tmp2);
1104         if (!tmp) {
1105             freebn(_4xz);
1106             freebn(outx);
1107             return NULL;
1108         }
1109         outz = modmul(_4xz, tmp, a->curve->p);
1110         freebn(_4xz);
1111         freebn(tmp);
1112         if (!outz) {
1113             freebn(outx);
1114             return NULL;
1115         }
1116     }
1117
1118     return ec_point_new(a->curve, outx, NULL, outz, 0);
1119 }
1120
1121 static struct ec_point *ecp_double(const struct ec_point *a, const int aminus3)
1122 {
1123     if (a->infinity)
1124     {
1125         /* Identity */
1126         return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1127     }
1128
1129     if (a->curve->type == EC_WEIERSTRASS)
1130     {
1131         return ecp_doublew(a, aminus3);
1132     }
1133     else
1134     {
1135         return ecp_doublem(a);
1136     }
1137 }
1138
1139 static struct ec_point *ecp_addw(const struct ec_point *a,
1140                                  const struct ec_point *b,
1141                                  const int aminus3)
1142 {
1143     Bignum U1, U2, S1, S2, outx, outy, outz;
1144
1145     /* U1 = X1*Z2^2 */
1146     /* S1 = Y1*Z2^3 */
1147     if (b->z) {
1148         Bignum Z2, Z3;
1149
1150         Z2 = ecf_square(b->z, a->curve);
1151         if (!Z2) {
1152             return NULL;
1153         }
1154         U1 = modmul(a->x, Z2, a->curve->p);
1155         if (!U1) {
1156             freebn(Z2);
1157             return NULL;
1158         }
1159         Z3 = modmul(Z2, b->z, a->curve->p);
1160         freebn(Z2);
1161         if (!Z3) {
1162             freebn(U1);
1163             return NULL;
1164         }
1165         S1 = modmul(a->y, Z3, a->curve->p);
1166         freebn(Z3);
1167         if (!S1) {
1168             freebn(U1);
1169             return NULL;
1170         }
1171     } else {
1172         U1 = copybn(a->x);
1173         if (!U1) {
1174             return NULL;
1175         }
1176         S1 = copybn(a->y);
1177         if (!S1) {
1178             freebn(U1);
1179             return NULL;
1180         }
1181     }
1182
1183     /* U2 = X2*Z1^2 */
1184     /* S2 = Y2*Z1^3 */
1185     if (a->z) {
1186         Bignum Z2, Z3;
1187
1188         Z2 = ecf_square(a->z, b->curve);
1189         if (!Z2) {
1190             freebn(U1);
1191             freebn(S1);
1192             return NULL;
1193         }
1194         U2 = modmul(b->x, Z2, b->curve->p);
1195         if (!U2) {
1196             freebn(U1);
1197             freebn(S1);
1198             freebn(Z2);
1199             return NULL;
1200         }
1201         Z3 = modmul(Z2, a->z, b->curve->p);
1202         freebn(Z2);
1203         if (!Z3) {
1204             freebn(U1);
1205             freebn(S1);
1206             freebn(U2);
1207             return NULL;
1208         }
1209         S2 = modmul(b->y, Z3, b->curve->p);
1210         freebn(Z3);
1211         if (!S2) {
1212             freebn(U1);
1213             freebn(S1);
1214             freebn(U2);
1215             return NULL;
1216         }
1217     } else {
1218         U2 = copybn(b->x);
1219         if (!U2) {
1220             freebn(U1);
1221             freebn(S1);
1222             return NULL;
1223         }
1224         S2 = copybn(b->y);
1225         if (!S2) {
1226             freebn(U1);
1227             freebn(S1);
1228             freebn(U2);
1229             return NULL;
1230         }
1231     }
1232
1233     /* Check if multiplying by self */
1234     if (bignum_cmp(U1, U2) == 0)
1235     {
1236         freebn(U1);
1237         freebn(U2);
1238         if (bignum_cmp(S1, S2) == 0)
1239         {
1240             freebn(S1);
1241             freebn(S2);
1242             return ecp_double(a, aminus3);
1243         }
1244         else
1245         {
1246             freebn(S1);
1247             freebn(S2);
1248             /* Infinity */
1249             return ec_point_new(a->curve, NULL, NULL, NULL, 1);
1250         }
1251     }
1252
1253     {
1254         Bignum H, R, UH2, H3;
1255
1256         /* H = U2 - U1 */
1257         H = modsub(U2, U1, a->curve->p);
1258         freebn(U2);
1259         if (!H) {
1260             freebn(U1);
1261             freebn(S1);
1262             freebn(S2);
1263             return NULL;
1264         }
1265
1266         /* R = S2 - S1 */
1267         R = modsub(S2, S1, a->curve->p);
1268         freebn(S2);
1269         if (!R) {
1270             freebn(H);
1271             freebn(S1);
1272             freebn(U1);
1273             return NULL;
1274         }
1275
1276         /* X3 = R^2 - H^3 - 2*U1*H^2 */
1277         {
1278             Bignum R2, H2, _2UH2, first;
1279
1280             H2 = ecf_square(H, a->curve);
1281             if (!H2) {
1282                 freebn(U1);
1283                 freebn(S1);
1284                 freebn(H);
1285                 freebn(R);
1286                 return NULL;
1287             }
1288             UH2 = modmul(U1, H2, a->curve->p);
1289             freebn(U1);
1290             if (!UH2) {
1291                 freebn(H2);
1292                 freebn(S1);
1293                 freebn(H);
1294                 freebn(R);
1295                 return NULL;
1296             }
1297             H3 = modmul(H2, H, a->curve->p);
1298             freebn(H2);
1299             if (!H3) {
1300                 freebn(UH2);
1301                 freebn(S1);
1302                 freebn(H);
1303                 freebn(R);
1304                 return NULL;
1305             }
1306             R2 = ecf_square(R, a->curve);
1307             if (!R2) {
1308                 freebn(H3);
1309                 freebn(UH2);
1310                 freebn(S1);
1311                 freebn(H);
1312                 freebn(R);
1313                 return NULL;
1314             }
1315             _2UH2 = ecf_double(UH2, a->curve);
1316             if (!_2UH2) {
1317                 freebn(R2);
1318                 freebn(H3);
1319                 freebn(UH2);
1320                 freebn(S1);
1321                 freebn(H);
1322                 freebn(R);
1323                 return NULL;
1324             }
1325             first = modsub(R2, H3, a->curve->p);
1326             freebn(R2);
1327             if (!first) {
1328                 freebn(H3);
1329                 freebn(_2UH2);
1330                 freebn(UH2);
1331                 freebn(S1);
1332                 freebn(H);
1333                 freebn(R);
1334                 return NULL;
1335             }
1336             outx = modsub(first, _2UH2, a->curve->p);
1337             freebn(first);
1338             freebn(_2UH2);
1339             if (!outx) {
1340                 freebn(H3);
1341                 freebn(UH2);
1342                 freebn(S1);
1343                 freebn(H);
1344                 freebn(R);
1345                 return NULL;
1346             }
1347         }
1348
1349         /* Y3 = R*(U1*H^2 - X3) - S1*H^3 */
1350         {
1351             Bignum RUH2mX, UH2mX, SH3;
1352
1353             UH2mX = modsub(UH2, outx, a->curve->p);
1354             freebn(UH2);
1355             if (!UH2mX) {
1356                 freebn(outx);
1357                 freebn(H3);
1358                 freebn(S1);
1359                 freebn(H);
1360                 freebn(R);
1361                 return NULL;
1362             }
1363             RUH2mX = modmul(R, UH2mX, a->curve->p);
1364             freebn(UH2mX);
1365             freebn(R);
1366             if (!RUH2mX) {
1367                 freebn(outx);
1368                 freebn(H3);
1369                 freebn(S1);
1370                 freebn(H);
1371                 return NULL;
1372             }
1373             SH3 = modmul(S1, H3, a->curve->p);
1374             freebn(S1);
1375             freebn(H3);
1376             if (!SH3) {
1377                 freebn(RUH2mX);
1378                 freebn(outx);
1379                 freebn(H);
1380                 return NULL;
1381             }
1382
1383             outy = modsub(RUH2mX, SH3, a->curve->p);
1384             freebn(RUH2mX);
1385             freebn(SH3);
1386             if (!outy) {
1387                 freebn(outx);
1388                 freebn(H);
1389                 return NULL;
1390             }
1391         }
1392
1393         /* Z3 = H*Z1*Z2 */
1394         if (a->z && b->z) {
1395             Bignum ZZ;
1396
1397             ZZ = modmul(a->z, b->z, a->curve->p);
1398             if (!ZZ) {
1399                 freebn(outx);
1400                 freebn(outy);
1401                 freebn(H);
1402                 return NULL;
1403             }
1404             outz = modmul(H, ZZ, a->curve->p);
1405             freebn(H);
1406             freebn(ZZ);
1407             if (!outz) {
1408                 freebn(outx);
1409                 freebn(outy);
1410                 return NULL;
1411             }
1412         } else if (a->z) {
1413             outz = modmul(H, a->z, a->curve->p);
1414             freebn(H);
1415             if (!outz) {
1416                 freebn(outx);
1417                 freebn(outy);
1418                 return NULL;
1419             }
1420         } else if (b->z) {
1421             outz = modmul(H, b->z, a->curve->p);
1422             freebn(H);
1423             if (!outz) {
1424                 freebn(outx);
1425                 freebn(outy);
1426                 return NULL;
1427             }
1428         } else {
1429             outz = H;
1430         }
1431     }
1432
1433     return ec_point_new(a->curve, outx, outy, outz, 0);
1434 }
1435
1436 static struct ec_point *ecp_addm(const struct ec_point *a,
1437                                  const struct ec_point *b,
1438                                  const struct ec_point *base)
1439 {
1440     Bignum outx, outz, az, bz;
1441
1442     az = a->z;
1443     if (!az) {
1444         az = One;
1445     }
1446     bz = b->z;
1447     if (!bz) {
1448         bz = One;
1449     }
1450
1451     /* a-b is maintained at 1 due to Montgomery ladder implementation */
1452     /* Xa+b = Za-b * ((Xa - Za)*(Xb + Zb) + (Xa + Za)*(Xb - Zb))^2 */
1453     /* Za+b = Xa-b * ((Xa - Za)*(Xb + Zb) - (Xa + Za)*(Xb - Zb))^2 */
1454     {
1455         Bignum tmp, tmp2, tmp3, tmp4;
1456
1457         /* (Xa + Za) * (Xb - Zb) */
1458         tmp = ecf_add(a->x, az, a->curve);
1459         if (!tmp) {
1460             return NULL;
1461         }
1462         tmp2 = modsub(b->x, bz, a->curve->p);
1463         if (!tmp2) {
1464             freebn(tmp);
1465             return NULL;
1466         }
1467         tmp3 = modmul(tmp, tmp2, a->curve->p);
1468         freebn(tmp);
1469         freebn(tmp2);
1470         if (!tmp3) {
1471             return NULL;
1472         }
1473
1474         /* (Xa - Za) * (Xb + Zb) */
1475         tmp = modsub(a->x, az, a->curve->p);
1476         if (!tmp) {
1477             freebn(tmp3);
1478             return NULL;
1479         }
1480         tmp2 = ecf_add(b->x, bz, a->curve);
1481         if (!tmp2) {
1482             freebn(tmp);
1483             freebn(tmp3);
1484             return NULL;
1485         }
1486         tmp4 = modmul(tmp, tmp2, a->curve->p);
1487         freebn(tmp);
1488         freebn(tmp2);
1489         if (!tmp4) {
1490             freebn(tmp3);
1491             return NULL;
1492         }
1493
1494         tmp = ecf_add(tmp3, tmp4, a->curve);
1495         if (!tmp) {
1496             freebn(tmp3);
1497             freebn(tmp4);
1498             return NULL;
1499         }
1500         outx = ecf_square(tmp, a->curve);
1501         freebn(tmp);
1502         if (!outx) {
1503             freebn(tmp3);
1504             freebn(tmp4);
1505             return NULL;
1506         }
1507
1508         tmp = modsub(tmp3, tmp4, a->curve->p);
1509         freebn(tmp3);
1510         freebn(tmp4);
1511         if (!tmp) {
1512             freebn(outx);
1513             return NULL;
1514         }
1515         tmp2 = ecf_square(tmp, a->curve);
1516         freebn(tmp);
1517         if (!tmp2) {
1518             freebn(outx);
1519             return NULL;
1520         }
1521         outz = modmul(base->x, tmp2, a->curve->p);
1522         freebn(tmp2);
1523         if (!outz) {
1524             freebn(outx);
1525             return NULL;
1526         }
1527     }
1528
1529     return ec_point_new(a->curve, outx, NULL, outz, 0);
1530 }
1531
1532 static struct ec_point *ecp_add(const struct ec_point *a,
1533                                 const struct ec_point *b,
1534                                 const int aminus3)
1535 {
1536     if (a->curve != b->curve) {
1537         return NULL;
1538     }
1539
1540     /* Check if multiplying by infinity */
1541     if (a->infinity) return ec_point_copy(b);
1542     if (b->infinity) return ec_point_copy(a);
1543
1544     if (a->curve->type == EC_WEIERSTRASS)
1545     {
1546         return ecp_addw(a, b, aminus3);
1547     }
1548
1549     return NULL;
1550 }
1551
1552 static struct ec_point *ecp_mul_(const struct ec_point *a, const Bignum b, int aminus3)
1553 {
1554     struct ec_point *A, *ret;
1555     int bits, i;
1556
1557     A = ec_point_copy(a);
1558     ret = ec_point_new(a->curve, NULL, NULL, NULL, 1);
1559
1560     bits = bignum_bitcount(b);
1561     for (i = 0; ret != NULL && A != NULL && i < bits; ++i)
1562     {
1563         if (bignum_bit(b, i))
1564         {
1565             struct ec_point *tmp = ecp_add(ret, A, aminus3);
1566             ec_point_free(ret);
1567             ret = tmp;
1568         }
1569         if (i+1 != bits)
1570         {
1571             struct ec_point *tmp = ecp_double(A, aminus3);
1572             ec_point_free(A);
1573             A = tmp;
1574         }
1575     }
1576
1577     if (!A) {
1578         ec_point_free(ret);
1579         ret = NULL;
1580     } else {
1581         ec_point_free(A);
1582     }
1583
1584     return ret;
1585 }
1586
1587 static struct ec_point *ecp_mulw(const struct ec_point *a, const Bignum b)
1588 {
1589     struct ec_point *ret = ecp_mul_(a, b, ec_aminus3(a->curve));
1590
1591     if (!ecp_normalise(ret)) {
1592         ec_point_free(ret);
1593         return NULL;
1594     }
1595
1596     return ret;
1597 }
1598
1599 static struct ec_point *ecp_mulm(const struct ec_point *p, const Bignum n)
1600 {
1601     struct ec_point *P1, *P2;
1602     int bits, i;
1603
1604     /* P1 <- P and P2 <- [2]P */
1605     P2 = ecp_double(p, 0);
1606     if (!P2) {
1607         return NULL;
1608     }
1609     P1 = ec_point_copy(p);
1610     if (!P1) {
1611         ec_point_free(P2);
1612         return NULL;
1613     }
1614
1615     /* for i = bits âˆ’ 2 down to 0 */
1616     bits = bignum_bitcount(n);
1617     for (i = bits - 2; P1 != NULL && P2 != NULL && i >= 0; --i)
1618     {
1619         if (!bignum_bit(n, i))
1620         {
1621             /* P2 <- P1 + P2 */
1622             struct ec_point *tmp = ecp_addm(P1, P2, p);
1623             ec_point_free(P2);
1624             P2 = tmp;
1625
1626             /* P1 <- [2]P1 */
1627             tmp = ecp_double(P1, 0);
1628             ec_point_free(P1);
1629             P1 = tmp;
1630         }
1631         else
1632         {
1633             /* P1 <- P1 + P2 */
1634             struct ec_point *tmp = ecp_addm(P1, P2, p);
1635             ec_point_free(P1);
1636             P1 = tmp;
1637
1638             /* P2 <- [2]P2 */
1639             tmp = ecp_double(P2, 0);
1640             ec_point_free(P2);
1641             P2 = tmp;
1642         }
1643     }
1644
1645     if (!P2) {
1646         if (P1) ec_point_free(P1);
1647         P1 = NULL;
1648     } else {
1649         ec_point_free(P2);
1650     }
1651
1652     if (!ecp_normalise(P1)) {
1653         ec_point_free(P1);
1654         return NULL;
1655     }
1656
1657     return P1;
1658 }
1659
1660 /* Not static because it is used by sshecdsag.c to generate a new key */
1661 struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b)
1662 {
1663     if (a->curve->type == EC_WEIERSTRASS) {
1664         return ecp_mulw(a, b);
1665     } else {
1666         return ecp_mulm(a, b);
1667     }
1668 }
1669
1670 static struct ec_point *ecp_summul(const Bignum a, const Bignum b,
1671                                    const struct ec_point *point)
1672 {
1673     struct ec_point *aG, *bP, *ret;
1674     int aminus3;
1675
1676     if (point->curve->type != EC_WEIERSTRASS) {
1677         return NULL;
1678     }
1679
1680     aminus3 = ec_aminus3(point->curve);
1681
1682     aG = ecp_mul_(&point->curve->w.G, a, aminus3);
1683     if (!aG) return NULL;
1684     bP = ecp_mul_(point, b, aminus3);
1685     if (!bP) {
1686         ec_point_free(aG);
1687         return NULL;
1688     }
1689
1690     ret = ecp_add(aG, bP, aminus3);
1691
1692     ec_point_free(aG);
1693     ec_point_free(bP);
1694
1695     if (!ecp_normalise(ret)) {
1696         ec_point_free(ret);
1697         return NULL;
1698     }
1699
1700     return ret;
1701 }
1702
1703 /* ----------------------------------------------------------------------
1704  * Public point from private
1705  */
1706
1707 struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve)
1708 {
1709     if (curve->type == EC_WEIERSTRASS) {
1710         return ecp_mul(&curve->w.G, privateKey);
1711     } else {
1712         return NULL;
1713     }
1714 }
1715
1716 /* ----------------------------------------------------------------------
1717  * Basic sign and verify routines
1718  */
1719
1720 static int _ecdsa_verify(const struct ec_point *publicKey,
1721                          const unsigned char *data, const int dataLen,
1722                          const Bignum r, const Bignum s)
1723 {
1724     int z_bits, n_bits;
1725     Bignum z;
1726     int valid = 0;
1727
1728     if (publicKey->curve->type != EC_WEIERSTRASS) {
1729         return 0;
1730     }
1731
1732     /* Sanity checks */
1733     if (bignum_cmp(r, Zero) == 0 || bignum_cmp(r, publicKey->curve->w.n) >= 0
1734         || bignum_cmp(s, Zero) == 0 || bignum_cmp(s, publicKey->curve->w.n) >= 0)
1735     {
1736         return 0;
1737     }
1738
1739     /* z = left most bitlen(curve->n) of data */
1740     z = bignum_from_bytes(data, dataLen);
1741     if (!z) return 0;
1742     n_bits = bignum_bitcount(publicKey->curve->w.n);
1743     z_bits = bignum_bitcount(z);
1744     if (z_bits > n_bits)
1745     {
1746         Bignum tmp = bignum_rshift(z, z_bits - n_bits);
1747         freebn(z);
1748         z = tmp;
1749         if (!z) return 0;
1750     }
1751
1752     /* Ensure z in range of n */
1753     {
1754         Bignum tmp = bigmod(z, publicKey->curve->w.n);
1755         freebn(z);
1756         z = tmp;
1757         if (!z) return 0;
1758     }
1759
1760     /* Calculate signature */
1761     {
1762         Bignum w, x, u1, u2;
1763         struct ec_point *tmp;
1764
1765         w = modinv(s, publicKey->curve->w.n);
1766         if (!w) {
1767             freebn(z);
1768             return 0;
1769         }
1770         u1 = modmul(z, w, publicKey->curve->w.n);
1771         if (!u1) {
1772             freebn(z);
1773             freebn(w);
1774             return 0;
1775         }
1776         u2 = modmul(r, w, publicKey->curve->w.n);
1777         freebn(w);
1778         if (!u2) {
1779             freebn(z);
1780             freebn(u1);
1781             return 0;
1782         }
1783
1784         tmp = ecp_summul(u1, u2, publicKey);
1785         freebn(u1);
1786         freebn(u2);
1787         if (!tmp) {
1788             freebn(z);
1789             return 0;
1790         }
1791
1792         x = bigmod(tmp->x, publicKey->curve->w.n);
1793         ec_point_free(tmp);
1794         if (!x) {
1795             freebn(z);
1796             return 0;
1797         }
1798
1799         valid = (bignum_cmp(r, x) == 0) ? 1 : 0;
1800         freebn(x);
1801     }
1802
1803     freebn(z);
1804
1805     return valid;
1806 }
1807
1808 static void _ecdsa_sign(const Bignum privateKey, const struct ec_curve *curve,
1809                         const unsigned char *data, const int dataLen,
1810                         Bignum *r, Bignum *s)
1811 {
1812     unsigned char digest[20];
1813     int z_bits, n_bits;
1814     Bignum z, k;
1815     struct ec_point *kG;
1816
1817     *r = NULL;
1818     *s = NULL;
1819
1820     if (curve->type != EC_WEIERSTRASS) {
1821         return;
1822     }
1823
1824     /* z = left most bitlen(curve->n) of data */
1825     z = bignum_from_bytes(data, dataLen);
1826     if (!z) return;
1827     n_bits = bignum_bitcount(curve->w.n);
1828     z_bits = bignum_bitcount(z);
1829     if (z_bits > n_bits)
1830     {
1831         Bignum tmp;
1832         tmp = bignum_rshift(z, z_bits - n_bits);
1833         freebn(z);
1834         z = tmp;
1835         if (!z) return;
1836     }
1837
1838     /* Generate k between 1 and curve->n, using the same deterministic
1839      * k generation system we use for conventional DSA. */
1840     SHA_Simple(data, dataLen, digest);
1841     k = dss_gen_k("ECDSA deterministic k generator", curve->w.n, privateKey,
1842                   digest, sizeof(digest));
1843     if (!k) return;
1844
1845     kG = ecp_mul(&curve->w.G, k);
1846     if (!kG) {
1847         freebn(z);
1848         freebn(k);
1849         return;
1850     }
1851
1852     /* r = kG.x mod n */
1853     *r = bigmod(kG->x, curve->w.n);
1854     ec_point_free(kG);
1855     if (!*r) {
1856         freebn(z);
1857         freebn(k);
1858         return;
1859     }
1860
1861     /* s = (z + r * priv)/k mod n */
1862     {
1863         Bignum rPriv, zMod, first, firstMod, kInv;
1864         rPriv = modmul(*r, privateKey, curve->w.n);
1865         if (!rPriv) {
1866             freebn(*r);
1867             freebn(z);
1868             freebn(k);
1869             return;
1870         }
1871         zMod = bigmod(z, curve->w.n);
1872         freebn(z);
1873         if (!zMod) {
1874             freebn(rPriv);
1875             freebn(*r);
1876             freebn(k);
1877             return;
1878         }
1879         first = bigadd(rPriv, zMod);
1880         freebn(rPriv);
1881         freebn(zMod);
1882         if (!first) {
1883             freebn(*r);
1884             freebn(k);
1885             return;
1886         }
1887         firstMod = bigmod(first, curve->w.n);
1888         freebn(first);
1889         if (!firstMod) {
1890             freebn(*r);
1891             freebn(k);
1892             return;
1893         }
1894         kInv = modinv(k, curve->w.n);
1895         freebn(k);
1896         if (!kInv) {
1897             freebn(firstMod);
1898             freebn(*r);
1899             return;
1900         }
1901         *s = modmul(firstMod, kInv, curve->w.n);
1902         freebn(firstMod);
1903         freebn(kInv);
1904         if (!*s) {
1905             freebn(*r);
1906             return;
1907         }
1908     }
1909 }
1910
1911 /* ----------------------------------------------------------------------
1912  * Misc functions
1913  */
1914
1915 static void getstring(const char **data, int *datalen,
1916                       const char **p, int *length)
1917 {
1918     *p = NULL;
1919     if (*datalen < 4)
1920         return;
1921     *length = toint(GET_32BIT(*data));
1922     if (*length < 0)
1923         return;
1924     *datalen -= 4;
1925     *data += 4;
1926     if (*datalen < *length)
1927         return;
1928     *p = *data;
1929     *data += *length;
1930     *datalen -= *length;
1931 }
1932
1933 static Bignum getmp(const char **data, int *datalen)
1934 {
1935     const char *p;
1936     int length;
1937
1938     getstring(data, datalen, &p, &length);
1939     if (!p)
1940         return NULL;
1941     if (p[0] & 0x80)
1942         return NULL;                   /* negative mp */
1943     return bignum_from_bytes((unsigned char *)p, length);
1944 }
1945
1946 static int decodepoint(const char *p, int length, struct ec_point *point)
1947 {
1948     if (length < 1 || p[0] != 0x04) /* Only support uncompressed point */
1949         return 0;
1950     /* Skip compression flag */
1951     ++p;
1952     --length;
1953     /* The two values must be equal length */
1954     if (length % 2 != 0) {
1955         point->x = NULL;
1956         point->y = NULL;
1957         point->z = NULL;
1958         return 0;
1959     }
1960     length = length / 2;
1961     point->x = bignum_from_bytes((unsigned char *)p, length);
1962     if (!point->x) return 0;
1963     p += length;
1964     point->y = bignum_from_bytes((unsigned char *)p, length);
1965     if (!point->y) {
1966         freebn(point->x);
1967         point->x = NULL;
1968         return 0;
1969     }
1970     point->z = NULL;
1971
1972     /* Verify the point is on the curve */
1973     if (!ec_point_verify(point)) {
1974         freebn(point->x);
1975         point->x = NULL;
1976         freebn(point->y);
1977         point->y = NULL;
1978         return 0;
1979     }
1980
1981     return 1;
1982 }
1983
1984 static int getmppoint(const char **data, int *datalen, struct ec_point *point)
1985 {
1986     const char *p;
1987     int length;
1988
1989     getstring(data, datalen, &p, &length);
1990     if (!p) return 0;
1991     return decodepoint(p, length, point);
1992 }
1993
1994 /* ----------------------------------------------------------------------
1995  * Exposed ECDSA interface
1996  */
1997
1998 static void ecdsa_freekey(void *key)
1999 {
2000     struct ec_key *ec = (struct ec_key *) key;
2001     if (!ec) return;
2002
2003     if (ec->publicKey.x)
2004         freebn(ec->publicKey.x);
2005     if (ec->publicKey.y)
2006         freebn(ec->publicKey.y);
2007     if (ec->publicKey.z)
2008         freebn(ec->publicKey.z);
2009     if (ec->privateKey)
2010         freebn(ec->privateKey);
2011     sfree(ec);
2012 }
2013
2014 static void *ecdsa_newkey(const char *data, int len)
2015 {
2016     const char *p;
2017     int slen;
2018     struct ec_key *ec;
2019     struct ec_curve *curve;
2020
2021     getstring(&data, &len, &p, &slen);
2022
2023     if (!p) {
2024         return NULL;
2025     }
2026     curve = ec_name_to_curve(p, slen);
2027     if (!curve) return NULL;
2028
2029     if (curve->type != EC_WEIERSTRASS) {
2030         return NULL;
2031     }
2032
2033     /* Curve name is duplicated for Weierstrass form */
2034     if (curve->type == EC_WEIERSTRASS) {
2035         getstring(&data, &len, &p, &slen);
2036         if (curve != ec_name_to_curve(p, slen)) return NULL;
2037     }
2038
2039     ec = snew(struct ec_key);
2040
2041     ec->publicKey.curve = curve;
2042     ec->publicKey.infinity = 0;
2043     ec->publicKey.x = NULL;
2044     ec->publicKey.y = NULL;
2045     ec->publicKey.z = NULL;
2046     if (!getmppoint(&data, &len, &ec->publicKey)) {
2047         ecdsa_freekey(ec);
2048         return NULL;
2049     }
2050     ec->privateKey = NULL;
2051
2052     if (!ec->publicKey.x || !ec->publicKey.y ||
2053         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2054         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2055     {
2056         ecdsa_freekey(ec);
2057         ec = NULL;
2058     }
2059
2060     return ec;
2061 }
2062
2063 static char *ecdsa_fmtkey(void *key)
2064 {
2065     struct ec_key *ec = (struct ec_key *) key;
2066     char *p;
2067     int len, i, pos, nibbles;
2068     static const char hex[] = "0123456789abcdef";
2069     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2070         return NULL;
2071
2072     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2073     if (pos == 0) return NULL;
2074
2075     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
2076     len += pos; /* Curve name */
2077     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
2078     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
2079     p = snewn(len, char);
2080
2081     pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, (unsigned char*)p, pos);
2082     pos += sprintf(p + pos, ",0x");
2083     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
2084     if (nibbles < 1)
2085         nibbles = 1;
2086     for (i = nibbles; i--;) {
2087         p[pos++] =
2088             hex[(bignum_byte(ec->publicKey.x, i / 2) >> (4 * (i % 2))) & 0xF];
2089     }
2090     pos += sprintf(p + pos, ",0x");
2091     nibbles = (3 + bignum_bitcount(ec->publicKey.y)) / 4;
2092     if (nibbles < 1)
2093         nibbles = 1;
2094     for (i = nibbles; i--;) {
2095         p[pos++] =
2096             hex[(bignum_byte(ec->publicKey.y, i / 2) >> (4 * (i % 2))) & 0xF];
2097     }
2098     p[pos] = '\0';
2099     return p;
2100 }
2101
2102 static unsigned char *ecdsa_public_blob(void *key, int *len)
2103 {
2104     struct ec_key *ec = (struct ec_key *) key;
2105     int pointlen, bloblen, fullnamelen, namelen;
2106     int i;
2107     unsigned char *blob, *p;
2108
2109     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2110         fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2111         if (fullnamelen == 0) return NULL;
2112         namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2113         if (namelen == 0) return NULL;
2114
2115         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2116
2117         /*
2118          * string "ecdsa-sha2-<name>", string "<name>", 0x04 point x, y.
2119          */
2120         bloblen = 4 + fullnamelen + 4 + namelen + 4 + 1 + (pointlen * 2);
2121         blob = snewn(bloblen, unsigned char);
2122
2123         p = blob;
2124         PUT_32BIT(p, fullnamelen);
2125         p += 4;
2126         p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
2127         PUT_32BIT(p, namelen);
2128         p += 4;
2129         p += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, p, namelen);
2130         PUT_32BIT(p, (2 * pointlen) + 1);
2131         p += 4;
2132         *p++ = 0x04;
2133         for (i = pointlen; i--;) {
2134             *p++ = bignum_byte(ec->publicKey.x, i);
2135         }
2136         for (i = pointlen; i--;) {
2137             *p++ = bignum_byte(ec->publicKey.y, i);
2138         }
2139     } else {
2140         return NULL;
2141     }
2142
2143     assert(p == blob + bloblen);
2144     *len = bloblen;
2145
2146     return blob;
2147 }
2148
2149 static unsigned char *ecdsa_private_blob(void *key, int *len)
2150 {
2151     struct ec_key *ec = (struct ec_key *) key;
2152     int keylen, bloblen;
2153     int i;
2154     unsigned char *blob, *p;
2155
2156     if (!ec->privateKey) return NULL;
2157
2158     keylen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2159
2160     /*
2161      * mpint privateKey. Total 4 + keylen.
2162      */
2163     bloblen = 4 + keylen;
2164     blob = snewn(bloblen, unsigned char);
2165
2166     p = blob;
2167     PUT_32BIT(p, keylen);
2168     p += 4;
2169     for (i = keylen; i--;)
2170         *p++ = bignum_byte(ec->privateKey, i);
2171
2172     assert(p == blob + bloblen);
2173     *len = bloblen;
2174     return blob;
2175 }
2176
2177 static void *ecdsa_createkey(const unsigned char *pub_blob, int pub_len,
2178                              const unsigned char *priv_blob, int priv_len)
2179 {
2180     struct ec_key *ec;
2181     struct ec_point *publicKey;
2182     const char *pb = (const char *) priv_blob;
2183
2184     ec = (struct ec_key*)ecdsa_newkey((const char *) pub_blob, pub_len);
2185     if (!ec) {
2186         return NULL;
2187     }
2188
2189     ec->privateKey = getmp(&pb, &priv_len);
2190     if (!ec->privateKey) {
2191         ecdsa_freekey(ec);
2192         return NULL;
2193     }
2194
2195     /* Check that private key generates public key */
2196     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2197
2198     if (!publicKey ||
2199         bignum_cmp(publicKey->x, ec->publicKey.x) ||
2200         bignum_cmp(publicKey->y, ec->publicKey.y))
2201     {
2202         ecdsa_freekey(ec);
2203         ec = NULL;
2204     }
2205     ec_point_free(publicKey);
2206
2207     return ec;
2208 }
2209
2210 static void *ecdsa_openssh_createkey(const unsigned char **blob, int *len)
2211 {
2212     const char **b = (const char **) blob;
2213     const char *p;
2214     int slen;
2215     struct ec_key *ec;
2216     struct ec_curve *curve;
2217     struct ec_point *publicKey;
2218
2219     getstring(b, len, &p, &slen);
2220
2221     if (!p) {
2222         return NULL;
2223     }
2224     curve = ec_name_to_curve(p, slen);
2225     if (!curve) return NULL;
2226
2227     ec = snew(struct ec_key);
2228
2229     ec->publicKey.curve = curve;
2230     ec->publicKey.infinity = 0;
2231     ec->publicKey.x = NULL;
2232     ec->publicKey.y = NULL;
2233     ec->publicKey.z = NULL;
2234     if (!getmppoint(b, len, &ec->publicKey)) {
2235         ecdsa_freekey(ec);
2236         return NULL;
2237     }
2238     ec->privateKey = NULL;
2239
2240     if (!ec->publicKey.x || !ec->publicKey.y ||
2241         bignum_cmp(ec->publicKey.x, curve->p) >= 0 ||
2242         bignum_cmp(ec->publicKey.y, curve->p) >= 0)
2243     {
2244         ecdsa_freekey(ec);
2245         return NULL;
2246     }
2247
2248     ec->privateKey = getmp(b, len);
2249     if (ec->privateKey == NULL)
2250     {
2251         ecdsa_freekey(ec);
2252         return NULL;
2253     }
2254
2255     /* Now check that the private key makes the public key */
2256     publicKey = ec_public(ec->privateKey, ec->publicKey.curve);
2257     if (!publicKey)
2258     {
2259         ecdsa_freekey(ec);
2260         return NULL;
2261     }
2262
2263     if (bignum_cmp(ec->publicKey.x, publicKey->x) ||
2264         bignum_cmp(ec->publicKey.y, publicKey->y))
2265     {
2266         /* Private key doesn't make the public key on the given curve */
2267         ecdsa_freekey(ec);
2268         ec_point_free(publicKey);
2269         return NULL;
2270     }
2271
2272     ec_point_free(publicKey);
2273
2274     return ec;
2275 }
2276
2277 static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
2278 {
2279     struct ec_key *ec = (struct ec_key *) key;
2280
2281     int pointlen;
2282     int namelen;
2283     int bloblen;
2284     int i;
2285
2286
2287     if (ec->publicKey.curve->type != EC_WEIERSTRASS) {
2288         return 0;
2289     }
2290
2291     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2292     namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2293     bloblen =
2294         4 + namelen /* <LEN> nistpXXX */
2295         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
2296         + ssh2_bignum_length(ec->privateKey);
2297
2298     if (bloblen > len)
2299         return bloblen;
2300
2301     bloblen = 0;
2302
2303     PUT_32BIT(blob+bloblen, namelen);
2304     bloblen += 4;
2305
2306     bloblen += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, blob+bloblen, namelen);
2307
2308     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
2309     bloblen += 4;
2310     blob[bloblen++] = 0x04;
2311     for (i = pointlen; i--; )
2312         blob[bloblen++] = bignum_byte(ec->publicKey.x, i);
2313     for (i = pointlen; i--; )
2314         blob[bloblen++] = bignum_byte(ec->publicKey.y, i);
2315
2316     pointlen = (bignum_bitcount(ec->privateKey) + 8) / 8;
2317     PUT_32BIT(blob+bloblen, pointlen);
2318     bloblen += 4;
2319     for (i = pointlen; i--; )
2320         blob[bloblen++] = bignum_byte(ec->privateKey, i);
2321
2322     return bloblen;
2323 }
2324
2325 static int ecdsa_pubkey_bits(const void *blob, int len)
2326 {
2327     struct ec_key *ec;
2328     int ret;
2329
2330     ec = (struct ec_key*)ecdsa_newkey((const char *) blob, len);
2331     if (!ec)
2332         return -1;
2333     ret = ec->publicKey.curve->fieldBits;
2334     ecdsa_freekey(ec);
2335
2336     return ret;
2337 }
2338
2339 static char *ecdsa_fingerprint(void *key)
2340 {
2341     struct ec_key *ec = (struct ec_key *) key;
2342     struct MD5Context md5c;
2343     unsigned char digest[16], lenbuf[4];
2344     char *ret;
2345     unsigned char *name, *fullname;
2346     int pointlen, namelen, fullnamelen, i, j;
2347
2348     MD5Init(&md5c);
2349
2350     namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2351     name = snewn(namelen, unsigned char);
2352     ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, name, namelen);
2353
2354     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2355         fullnamelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
2356         fullname = snewn(namelen, unsigned char);
2357         ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, fullname, fullnamelen);
2358
2359         PUT_32BIT(lenbuf, fullnamelen);
2360         MD5Update(&md5c, lenbuf, 4);
2361         MD5Update(&md5c, fullname, fullnamelen);
2362         sfree(fullname);
2363
2364         PUT_32BIT(lenbuf, namelen);
2365         MD5Update(&md5c, lenbuf, 4);
2366         MD5Update(&md5c, name, namelen);
2367
2368         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2369         PUT_32BIT(lenbuf, 1 + (pointlen * 2));
2370         MD5Update(&md5c, lenbuf, 4);
2371         MD5Update(&md5c, (const unsigned char *)"\x04", 1);
2372         for (i = pointlen; i--; ) {
2373             unsigned char c = bignum_byte(ec->publicKey.x, i);
2374             MD5Update(&md5c, &c, 1);
2375         }
2376         for (i = pointlen; i--; ) {
2377             unsigned char c = bignum_byte(ec->publicKey.y, i);
2378             MD5Update(&md5c, &c, 1);
2379         }
2380     } else {
2381         sfree(name);
2382         return NULL;
2383     }
2384
2385     MD5Final(digest, &md5c);
2386
2387     ret = snewn(namelen + 1 + (16 * 3), char);
2388
2389     i = 0;
2390     memcpy(ret, name, namelen);
2391     i += namelen;
2392     sfree(name);
2393     ret[i++] = ' ';
2394     for (j = 0; j < 16; j++) {
2395         i += sprintf(ret + i, "%s%02x", j ? ":" : "", digest[j]);
2396     }
2397
2398     return ret;
2399 }
2400
2401 static int ecdsa_verifysig(void *key, const char *sig, int siglen,
2402                            const char *data, int datalen)
2403 {
2404     struct ec_key *ec = (struct ec_key *) key;
2405     const char *p;
2406     int slen;
2407     unsigned char digest[512 / 8];
2408     int digestLen;
2409     int ret;
2410     Bignum r, s;
2411
2412     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
2413         return 0;
2414
2415     /* Check the signature curve matches the key curve */
2416     getstring(&sig, &siglen, &p, &slen);
2417     if (!p) {
2418         return 0;
2419     }
2420     if (ec->publicKey.curve != ec_name_to_curve(p, slen)) {
2421         return 0;
2422     }
2423
2424     getstring(&sig, &siglen, &p, &slen);
2425
2426
2427     r = getmp(&p, &slen);
2428     if (!r) return 0;
2429     s = getmp(&p, &slen);
2430     if (!s) {
2431         freebn(r);
2432         return 0;
2433     }
2434
2435     /* Perform correct hash function depending on curve size */
2436     if (ec->publicKey.curve->fieldBits <= 256) {
2437         SHA256_Simple(data, datalen, digest);
2438         digestLen = 256 / 8;
2439     } else if (ec->publicKey.curve->fieldBits <= 384) {
2440         SHA384_Simple(data, datalen, digest);
2441         digestLen = 384 / 8;
2442     } else {
2443         SHA512_Simple(data, datalen, digest);
2444         digestLen = 512 / 8;
2445     }
2446
2447     /* Verify the signature */
2448     ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
2449
2450     freebn(r);
2451     freebn(s);
2452
2453     return ret;
2454 }
2455
2456 static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
2457                                  int *siglen)
2458 {
2459     struct ec_key *ec = (struct ec_key *) key;
2460     unsigned char digest[512 / 8];
2461     int digestLen;
2462     Bignum r = NULL, s = NULL;
2463     unsigned char *buf, *p;
2464     int rlen, slen, namelen;
2465     int i;
2466
2467     if (!ec->privateKey || !ec->publicKey.curve) {
2468         return NULL;
2469     }
2470
2471     /* Perform correct hash function depending on curve size */
2472     if (ec->publicKey.curve->fieldBits <= 256) {
2473         SHA256_Simple(data, datalen, digest);
2474         digestLen = 256 / 8;
2475     } else if (ec->publicKey.curve->fieldBits <= 384) {
2476         SHA384_Simple(data, datalen, digest);
2477         digestLen = 384 / 8;
2478     } else {
2479         SHA512_Simple(data, datalen, digest);
2480         digestLen = 512 / 8;
2481     }
2482
2483     /* Do the signature */
2484     _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
2485     if (!r || !s) {
2486         if (r) freebn(r);
2487         if (s) freebn(s);
2488         return NULL;
2489     }
2490
2491     rlen = (bignum_bitcount(r) + 8) / 8;
2492     slen = (bignum_bitcount(s) + 8) / 8;
2493
2494     namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
2495
2496     /* Format the output */
2497     *siglen = 8+namelen+rlen+slen+8;
2498     buf = snewn(*siglen, unsigned char);
2499     p = buf;
2500     PUT_32BIT(p, namelen);
2501     p += 4;
2502     p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
2503     PUT_32BIT(p, rlen + slen + 8);
2504     p += 4;
2505     PUT_32BIT(p, rlen);
2506     p += 4;
2507     for (i = rlen; i--;)
2508         *p++ = bignum_byte(r, i);
2509     PUT_32BIT(p, slen);
2510     p += 4;
2511     for (i = slen; i--;)
2512         *p++ = bignum_byte(s, i);
2513
2514     freebn(r);
2515     freebn(s);
2516
2517     return buf;
2518 }
2519
2520 const struct ssh_signkey ssh_ecdsa_nistp256 = {
2521     ecdsa_newkey,
2522     ecdsa_freekey,
2523     ecdsa_fmtkey,
2524     ecdsa_public_blob,
2525     ecdsa_private_blob,
2526     ecdsa_createkey,
2527     ecdsa_openssh_createkey,
2528     ecdsa_openssh_fmtkey,
2529     3 /* curve name, point, private exponent */,
2530     ecdsa_pubkey_bits,
2531     ecdsa_fingerprint,
2532     ecdsa_verifysig,
2533     ecdsa_sign,
2534     "ecdsa-sha2-nistp256",
2535     "ecdsa-sha2-nistp256",
2536 };
2537
2538 const struct ssh_signkey ssh_ecdsa_nistp384 = {
2539     ecdsa_newkey,
2540     ecdsa_freekey,
2541     ecdsa_fmtkey,
2542     ecdsa_public_blob,
2543     ecdsa_private_blob,
2544     ecdsa_createkey,
2545     ecdsa_openssh_createkey,
2546     ecdsa_openssh_fmtkey,
2547     3 /* curve name, point, private exponent */,
2548     ecdsa_pubkey_bits,
2549     ecdsa_fingerprint,
2550     ecdsa_verifysig,
2551     ecdsa_sign,
2552     "ecdsa-sha2-nistp384",
2553     "ecdsa-sha2-nistp384",
2554 };
2555
2556 const struct ssh_signkey ssh_ecdsa_nistp521 = {
2557     ecdsa_newkey,
2558     ecdsa_freekey,
2559     ecdsa_fmtkey,
2560     ecdsa_public_blob,
2561     ecdsa_private_blob,
2562     ecdsa_createkey,
2563     ecdsa_openssh_createkey,
2564     ecdsa_openssh_fmtkey,
2565     3 /* curve name, point, private exponent */,
2566     ecdsa_pubkey_bits,
2567     ecdsa_fingerprint,
2568     ecdsa_verifysig,
2569     ecdsa_sign,
2570     "ecdsa-sha2-nistp521",
2571     "ecdsa-sha2-nistp521",
2572 };
2573
2574 /* ----------------------------------------------------------------------
2575  * Exposed ECDH interface
2576  */
2577
2578 static Bignum ecdh_calculate(const Bignum private,
2579                              const struct ec_point *public)
2580 {
2581     struct ec_point *p;
2582     Bignum ret;
2583     p = ecp_mul(public, private);
2584     if (!p) return NULL;
2585     ret = p->x;
2586     p->x = NULL;
2587
2588     if (p->curve->type == EC_MONTGOMERY) {
2589         /* Do conversion in network byte order */
2590         int i;
2591         int bytes = (bignum_bitcount(ret)+7) / 8;
2592         unsigned char *byteorder = snewn(bytes, unsigned char);
2593         if (!byteorder) {
2594             ec_point_free(p);
2595             freebn(ret);
2596             return NULL;
2597         }
2598         for (i = 0; i < bytes; ++i) {
2599             byteorder[i] = bignum_byte(ret, i);
2600         }
2601         freebn(ret);
2602         ret = bignum_from_bytes(byteorder, bytes);
2603         sfree(byteorder);
2604     }
2605
2606     ec_point_free(p);
2607     return ret;
2608 }
2609
2610 void *ssh_ecdhkex_newkey(const char *name)
2611 {
2612     struct ec_curve *curve;
2613     struct ec_key *key;
2614     struct ec_point *publicKey;
2615
2616     curve = ec_name_to_curve(name, strlen(name));
2617
2618     key = snew(struct ec_key);
2619     if (!key) {
2620         return NULL;
2621     }
2622
2623     key->publicKey.curve = curve;
2624
2625     if (curve->type == EC_MONTGOMERY) {
2626         unsigned char bytes[32] = {0};
2627         int i;
2628
2629         for (i = 0; i < sizeof(bytes); ++i)
2630         {
2631             bytes[i] = (unsigned char)random_byte();
2632         }
2633         bytes[0] &= 248;
2634         bytes[31] &= 127;
2635         bytes[31] |= 64;
2636         key->privateKey = bignum_from_bytes(bytes, sizeof(bytes));
2637         for (i = 0; i < sizeof(bytes); ++i)
2638         {
2639             ((volatile char*)bytes)[i] = 0;
2640         }
2641         if (!key->privateKey) {
2642             sfree(key);
2643             return NULL;
2644         }
2645         publicKey = ecp_mul(&key->publicKey.curve->m.G, key->privateKey);
2646         if (!publicKey) {
2647             freebn(key->privateKey);
2648             sfree(key);
2649             return NULL;
2650         }
2651         key->publicKey.x = publicKey->x;
2652         key->publicKey.y = publicKey->y;
2653         key->publicKey.z = NULL;
2654         sfree(publicKey);
2655     } else {
2656         key->privateKey = bignum_random_in_range(One, key->publicKey.curve->w.n);
2657         if (!key->privateKey) {
2658             sfree(key);
2659             return NULL;
2660         }
2661         publicKey = ecp_mul(&key->publicKey.curve->w.G, key->privateKey);
2662         if (!publicKey) {
2663             freebn(key->privateKey);
2664             sfree(key);
2665             return NULL;
2666         }
2667         key->publicKey.x = publicKey->x;
2668         key->publicKey.y = publicKey->y;
2669         key->publicKey.z = NULL;
2670         sfree(publicKey);
2671     }
2672     return key;
2673 }
2674
2675 char *ssh_ecdhkex_getpublic(void *key, int *len)
2676 {
2677     struct ec_key *ec = (struct ec_key*)key;
2678     char *point, *p;
2679     int i;
2680     int pointlen;
2681
2682     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
2683
2684     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2685         *len = 1 + pointlen * 2;
2686     } else {
2687         *len = pointlen;
2688     }
2689     point = (char*)snewn(*len, char);
2690     if (!point) {
2691         return NULL;
2692     }
2693
2694     p = point;
2695     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2696         *p++ = 0x04;
2697         for (i = pointlen; i--;) {
2698             *p++ = bignum_byte(ec->publicKey.x, i);
2699         }
2700         for (i = pointlen; i--;) {
2701             *p++ = bignum_byte(ec->publicKey.y, i);
2702         }
2703     } else {
2704         for (i = 0; i < pointlen; ++i) {
2705             *p++ = bignum_byte(ec->publicKey.x, i);
2706         }
2707     }
2708
2709     return point;
2710 }
2711
2712 Bignum ssh_ecdhkex_getkey(void *key, char *remoteKey, int remoteKeyLen)
2713 {
2714     struct ec_key *ec = (struct ec_key*) key;
2715     struct ec_point remote;
2716     Bignum ret;
2717
2718     if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
2719         remote.curve = ec->publicKey.curve;
2720         remote.infinity = 0;
2721         if (!decodepoint(remoteKey, remoteKeyLen, &remote)) {
2722             return NULL;
2723         }
2724     } else {
2725         /* Point length has to be the same length */
2726         if (remoteKeyLen != (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8) {
2727             return NULL;
2728         }
2729
2730         remote.curve = ec->publicKey.curve;
2731         remote.infinity = 0;
2732         remote.x = bignum_from_bytes_le((unsigned char*)remoteKey, remoteKeyLen);
2733         remote.y = NULL;
2734         remote.z = NULL;
2735     }
2736
2737     ret = ecdh_calculate(ec->privateKey, &remote);
2738     if (remote.x) freebn(remote.x);
2739     if (remote.y) freebn(remote.y);
2740     return ret;
2741 }
2742
2743 void ssh_ecdhkex_freekey(void *key)
2744 {
2745     ecdsa_freekey(key);
2746 }
2747
2748 static const struct ssh_kex ssh_ec_kex_curve25519 = {
2749     "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
2750 };
2751
2752 static const struct ssh_kex ssh_ec_kex_nistp256 = {
2753     "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha256
2754 };
2755
2756 static const struct ssh_kex ssh_ec_kex_nistp384 = {
2757     "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha384
2758 };
2759
2760 static const struct ssh_kex ssh_ec_kex_nistp521 = {
2761     "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, NULL, NULL, 0, 0, &ssh_sha512
2762 };
2763
2764 static const struct ssh_kex *const ec_kex_list[] = {
2765     &ssh_ec_kex_curve25519,
2766     &ssh_ec_kex_nistp256,
2767     &ssh_ec_kex_nistp384,
2768     &ssh_ec_kex_nistp521
2769 };
2770
2771 const struct ssh_kexes ssh_ecdh_kex = {
2772     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
2773     ec_kex_list
2774 };